Merge branch 'master' of https://github.com/tensorflow/tensorflow into tensorrt to fix some of the failed tests.
commit 149fc8dbd6
@@ -4,7 +4,7 @@ https://stackoverflow.com/questions/tagged/tensorflow

 If you open a GitHub issue, here is our policy:

-1. It must be a bug or a feature request.
+1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
 2. The form below must be filled out.
 3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).
@@ -6,7 +6,7 @@

 | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
 |-----------------|---------------------|------------------|-------------------|---------------|
-| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) |
+| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |

 **TensorFlow** is an open source software library for numerical computation using
 data flow graphs. The graph nodes represent mathematical operations, while
RELEASE.md (25 changes)
@@ -1,18 +1,39 @@
 # Release 1.5.0

 ## Breaking Changes
-* Prebuilt binaries are now built against CUDA 9 and cuDNN 7.
+* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7.
 * Our Linux binaries are built using ubuntu 16 containers, potentially
   introducing glibc incompatibility issues with ubuntu 14.
 * Starting from 1.6 release, our prebuilt binaries will use AVX instructions.
   This may break TF on older CPUs.

+## Known Bugs
+* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or
+  `CUDA_ILLEGAL_ADDRESS` failures.
+
+  Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9
+  and CUDA 9.1 sometimes does not properly compute the carry bit when
+  decomposing 64-bit address calculations with large offsets (e.g. `load [x +
+  large_constant]`) into 32-bit arithmetic in SASS.
+
+  As a result, these versions of `ptxas` miscompile most XLA programs which use
+  more than 4GB of temp memory. This results in garbage results and/or
+  `CUDA_ERROR_ILLEGAL_ADDRESS` failures.
+
+  A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a
+  fix for CUDA 9.0.x. Until the fix is available, the only workaround is to
+  [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x
+  or disable XLA:GPU.
+
+  TensorFlow will print a warning if you use XLA:GPU with a known-bad version of
+  CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122.
+
 ## Major Features And Improvements
 * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
   preview version is now available.
 * [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite)
   dev preview is now available.
-* CUDA 9 and cuDNN 7 support.
+* CUDA 9.0 and cuDNN 7 support.
 * Accelerated Linear Algebra (XLA):
   * Add `complex64` support to XLA compiler.
   * `bfloat` support is now added to XLA infrastructure.
@@ -298,7 +298,7 @@ def get_var(environ_cp,
       System".
     enabled_by_default: boolean for default behavior.
     question: optional string for how to ask for user input.
-    yes_reply: optionanl string for reply when feature is enabled.
+    yes_reply: optional string for reply when feature is enabled.
     no_reply: optional string for reply when feature is disabled.

   Returns:
@@ -411,7 +411,7 @@ def set_action_env_var(environ_cp,
       System".
     enabled_by_default: boolean for default behavior.
     question: optional string for how to ask for user input.
-    yes_reply: optionanl string for reply when feature is enabled.
+    yes_reply: optional string for reply when feature is enabled.
     no_reply: optional string for reply when feature is disabled.
   """
   var = int(
@@ -1354,6 +1354,7 @@ def main():
     environ_cp['TF_NEED_GCP'] = '0'
     environ_cp['TF_NEED_HDFS'] = '0'
     environ_cp['TF_NEED_JEMALLOC'] = '0'
+    environ_cp['TF_NEED_KAFKA'] = '0'
     environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
     environ_cp['TF_NEED_COMPUTECPP'] = '0'
     environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1372,6 +1373,8 @@ def main():
                 'with_hdfs_support', True, 'hdfs')
   set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
                 'with_s3_support', True, 's3')
+  set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
+                'with_kafka_support', False, 'kafka')
   set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
                 False, 'xla')
   set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',
@@ -211,6 +211,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )

+config_setting(
+    name = "with_kafka_support",
+    define_values = {"with_kafka_support": "true"},
+    visibility = ["//visibility:public"],
+)
+
 # Crosses between platforms and file system libraries not supported on those
 # platforms due to limitations in nested select() statements.
 config_setting(
@@ -544,8 +550,10 @@ filegroup(
         "//tensorflow/contrib/predictor:all_files",
         "//tensorflow/contrib/py2tf:all_files",
         "//tensorflow/contrib/py2tf/converters:all_files",
         "//tensorflow/contrib/py2tf/impl:all_files",
         "//tensorflow/contrib/py2tf/pyct:all_files",
         "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
         "//tensorflow/contrib/py2tf/utils:all_files",
         "//tensorflow/contrib/quantize:all_files",
         "//tensorflow/contrib/receptive_field:all_files",
         "//tensorflow/contrib/reduce_slice_ops:all_files",
tensorflow/SECURITY.md (new file, 239 lines)
@@ -0,0 +1,239 @@
# Using TensorFlow Securely

This document discusses how to safely deal with untrusted programs (models or
model parameters), and input data. Below, we also provide guidelines on how to
report vulnerabilities in TensorFlow.

## TensorFlow models are programs

TensorFlow's runtime system interprets and executes programs. What machine
learning practitioners term
[**models**](https://developers.google.com/machine-learning/glossary/#model) are
expressed as programs that TensorFlow executes. TensorFlow programs are encoded
as computation
[**graphs**](https://developers.google.com/machine-learning/glossary/#graph).
The model's parameters are often stored separately in **checkpoints**.

At runtime, TensorFlow executes the computation graph using the parameters
provided. Note that the behavior of the computation graph may change
depending on the parameters provided. TensorFlow itself is not a sandbox. When
executing the computation graph, TensorFlow may read and write files, send and
receive data over the network, and even spawn additional processes. All these
tasks are performed with the permissions of the TensorFlow process. Allowing
for this flexibility makes for a powerful machine learning platform,
but it has implications for security.
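
To make this concrete, here is a minimal sketch (TF 1.x-style Python; the file path is hypothetical) of a graph whose execution has side effects:

```python
import tensorflow as tf

# Running this graph reads a file from disk with the permissions of the
# TensorFlow process -- the "model" is data-flow code, not inert data.
contents = tf.read_file("/tmp/some_file")  # hypothetical path

with tf.Session() as sess:
    print(sess.run(contents))
```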

The computation graph may also accept **inputs**. Those inputs are the
data you supply to TensorFlow to train a model, or to use a model to run
inference on the data.

**TensorFlow models are programs, and need to be treated as such from a security
perspective.**

## Running untrusted models

As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
[nsjail](https://github.com/google/nsjail)).

There are several ways in which a model could become untrusted. Obviously, if an
untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
The same is true if the untrusted party provides Python code, such as the
Python code that generates TensorFlow graphs.

Even if the untrusted party only supplies the serialized computation
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
set of computation primitives available to TensorFlow is powerful enough that
you should assume that the TensorFlow process effectively executes arbitrary
code. One common solution is to whitelist only a few safe Ops. While this is
possible in theory, we still recommend you sandbox the execution.

Whether a user-provided checkpoint is safe depends on the computation graph.
It is easily possible to create computation graphs in which malicious
checkpoints can trigger unsafe behavior. For example, consider a graph that
contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of
the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is
stored in the checkpoint, whoever provides the checkpoint now has the ability to
trigger unsafe behavior, even though the graph is not under their control.

In other words, graphs can contain vulnerabilities of their own. To allow users
to provide checkpoints to a model you run on their behalf (e.g., in order to
compare model quality for a fixed model architecture), you must carefully audit
your model, and we recommend you run the TensorFlow process in a sandbox.
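
A minimal sketch (TF 1.x-style Python; the variable name, file path, and checkpoint path are all hypothetical) of the `tf.cond` pattern described above:

```python
import tensorflow as tf

# Both branches were written by the graph author, but the restored
# checkpoint decides which one actually runs.
unsafe_switch = tf.Variable(False, dtype=tf.bool, name="unsafe_switch")

out = tf.cond(unsafe_switch,
              lambda: tf.read_file("/etc/passwd"),  # unsafe branch
              lambda: tf.constant("harmless"))      # harmless branch

saver = tf.train.Saver()
with tf.Session() as sess:
    # An attacker-supplied checkpoint can flip the switch.
    saver.restore(sess, "/tmp/untrusted.ckpt")
    print(sess.run(out))
```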

## Accepting untrusted Inputs

It is possible to write models that are secure in the sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.

In general, it is good practice to isolate the parts of any system which are
exposed to untrusted (e.g., user-provided) inputs in a sandbox.

A useful analogy for how any TensorFlow graph is executed is an interpreted
programming language, such as Python. While it is possible to write secure
Python code which can be exposed to user supplied inputs (by, e.g., carefully
quoting and sanitizing input strings, size-checking input blobs, etc.), it is
very easy to write Python programs which are insecure. Even secure Python code
could be rendered insecure by a bug in the Python interpreter, or by a bug in a
Python library it uses (e.g.,
[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)).
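
As a sketch of the size-checking idea mentioned above (plain Python; `EXPECTED_SHAPE` and `MAX_BYTES` are illustrative assumptions, not TensorFlow requirements):

```python
import numpy as np

MAX_BYTES = 16 * 1024 * 1024    # illustrative upper bound
EXPECTED_SHAPE = (224, 224, 3)  # hypothetical model input shape

def check_input(blob: np.ndarray) -> np.ndarray:
    # Reject anything the model was not written to handle before it
    # ever reaches the TensorFlow runtime.
    if blob.dtype != np.float32:
        raise ValueError("unexpected dtype: %s" % blob.dtype)
    if blob.shape != EXPECTED_SHAPE:
        raise ValueError("unexpected shape: %s" % (blob.shape,))
    if blob.nbytes > MAX_BYTES:
        raise ValueError("input too large: %d bytes" % blob.nbytes)
    return blob
```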

## Running a TensorFlow server

TensorFlow is a platform for distributed computing, and as such there is a
TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for
internal communication only. It is not built for use in an untrusted network.**

For performance reasons, the default TensorFlow server does not include any
authorization protocol and sends messages unencrypted. It accepts connections
from anywhere, and executes the graphs it is sent without performing any checks.
Therefore, if you run a `tf.train.Server` in your network, anybody with
access to the network can execute what you should consider arbitrary code with
the privileges of the process running the `tf.train.Server`.

When running distributed TensorFlow, you must isolate the network in which the
cluster lives. Cloud providers provide instructions for setting up isolated
networks, which are sometimes branded as "virtual private cloud." Refer to the
instructions for
[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and
[AWS](https://aws.amazon.com/vpc/) for details.

Note that `tf.train.Server` is different from the server created by
`tensorflow/serving` (the default binary for which is called `ModelServer`).
By default, `ModelServer` also has no built-in mechanism for authentication.
Connecting it to an untrusted network allows anyone on this network to run the
graphs known to the `ModelServer`. This means that an attacker may run
graphs using untrusted inputs as described above, but they would not be able to
execute arbitrary graphs. It is possible to safely expose a `ModelServer`
directly to an untrusted network, **but only if the graphs it is configured to
use have been carefully audited to be safe**.

Similar to best practices for other servers, we recommend running any
`ModelServer` with appropriate privileges (i.e., using a separate user with
reduced permissions). In the spirit of defense in depth, we recommend
authenticating requests to any TensorFlow server connected to an untrusted
network, as well as sandboxing the server to minimize the adverse effects of
any breach.
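
For reference, a minimal sketch of starting a `tf.train.Server` (TF 1.x API; the addresses are hypothetical). Any host that can reach these ports can submit arbitrary graphs, which is why the network must be isolated:

```python
import tensorflow as tf

# Hypothetical two-worker cluster on a private, isolated network.
cluster = tf.train.ClusterSpec({
    "worker": ["10.0.0.1:2222", "10.0.0.2:2222"],
})

# The server accepts unauthenticated, unencrypted connections: anyone who
# can reach 10.0.0.1:2222 can run graphs with this process's privileges.
server = tf.train.Server(cluster, job_name="worker", task_index=0)
server.join()
```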

## Vulnerabilities in TensorFlow

TensorFlow is a large and complex system. It also depends on a large set of
third party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`).
It is possible that TensorFlow or its dependent libraries contain
vulnerabilities that would allow triggering unexpected or dangerous behavior
with specially crafted inputs.

### What is a vulnerability?

Given TensorFlow's flexibility, it is possible to specify computation graphs
which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models
can perform arbitrary computations means that they may read and write files,
communicate via the network, produce deadlocks and infinite loops, or run out
of memory. It is only when these behaviors are outside the specifications of the
operations involved that such behavior is a vulnerability.

A `FileWriter` writing a file is not unexpected behavior and therefore is not a
vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
**is** a vulnerability.

This is more subtle from a system perspective. For example, it is easy to cause
a TensorFlow process to try to allocate more memory than available by specifying
a computation graph containing an ill-considered `tf.tile` operation. TensorFlow
should exit cleanly in this case (it would raise an exception in Python, or
return an error `Status` in C++). However, if the surrounding system is not
expecting the possibility, such behavior could be used in a denial of service
attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
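
A sketch of the `tf.tile` scenario (TF 1.x-style Python; the multiplier is deliberately absurd). The point is that TensorFlow raises a catchable error, and the surrounding system must be written to expect it:

```python
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1])
# An ill-considered tile: the result would need terabytes of memory.
huge = tf.tile(x, tf.constant([2**40], dtype=tf.int64))

with tf.Session() as sess:
    try:
        sess.run(huge, feed_dict={x: np.ones([1], np.float32)})
    except tf.errors.OpError as e:
        # TensorFlow fails cleanly with an error Status; a robust caller
        # treats this as an expected, input-dependent outcome.
        print("graph execution failed:", e)
```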

As a general rule, it is incorrect behavior for TensorFlow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.

One of the most critical parts of any system is input handling. If malicious
input can trigger side effects or incorrect behavior, this is a bug, and likely
a vulnerability.

### Reporting vulnerabilities

Please email reports about any security related issues you find to
`security@tensorflow.org`. This mail is delivered to a small security team. Your
email will be acknowledged within one business day, and you'll receive a more
detailed response to your email within 7 days indicating the next steps in
handling your report. For critical problems, you may encrypt your report (see
below).

Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
the progress being made towards a fix and announcement.

If you believe that an existing (public) issue is security-related, please send
an email to `security@tensorflow.org`. The email should include the issue ID and
a short description of why it should be handled according to this security
policy.

Once an issue is reported, TensorFlow uses the following disclosure process:

* When a report is received, we confirm the issue and determine its severity.
* If we know of specific third-party services or software based on TensorFlow
  that require mitigation before publication, those projects will be notified.
* An advisory is prepared (but not published) which details the problem and
  steps for mitigation.
* Wherever possible, fixes are prepared for the last minor release of the two
  latest major releases, as well as the master branch. We will attempt to
  commit these fixes as soon as possible, and as close together as
  possible.
* Patch releases are published for all fixed released versions, a
  notification is sent to discuss@tensorflow.org, and the advisory is published.

Past security advisories are listed below. We credit reporters for identifying
security issues, although we keep your name confidential if you request it.

#### Encryption key for `security@tensorflow.org`

If your disclosure is extremely sensitive, you may choose to encrypt your
report using the key below. Please only use this for critical security
reports.

```
-----BEGIN PGP PUBLIC KEY BLOCK-----

mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI
LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP
aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs
SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e
OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY
e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj
dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm
gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA
AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC
o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a
T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum
HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf
3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw
2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR
N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe
zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p
osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b
nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+
OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR
BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y
BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH
J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP
A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy
7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt
Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81
v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
=CDME
-----END PGP PUBLIC KEY BLOCK-----
```

### Known vulnerabilities

| Type | Versions affected | Reported by | Additional Information |
|------|:-----------------:|-------------|------------------------|
| out of bounds read | <=1.4 | @zhangbo5891001 | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |
@@ -195,10 +195,10 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
       reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
     // TF_STRING and TF_RESOURCE tensors have a different representation in
     // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
-    // (any alignement requirements will be taken care of by TF_TensorToTensor
+    // (any alignment requirements will be taken care of by TF_TensorToTensor
     // and TF_TensorFromTensor).
     //
-    // Other types have the same represntation, so copy only if it is safe to do
+    // Other types have the same representation, so copy only if it is safe to do
     // so.
     buf->data_ = allocate_tensor("TF_NewTensor", len);
     std::memcpy(buf->data_, data, len);
@@ -2144,7 +2144,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
     opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
   }

-  // TOOD(skyewm): change to OutputTensor
+  // TODO(skyewm): change to OutputTensor
   tensorflow::ImportGraphDefResults results;
   TF_RETURN_IF_ERROR(
       ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));
@@ -46,6 +46,7 @@ tf_cuda_library(
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_api_internal",
+        "//tensorflow/core:core_cpu_lib",
         "//tensorflow/core:framework",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib_internal",
@@ -85,15 +85,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
     return nullptr;
   }

-  TFE_Context* ret = new TFE_Context(session);
-  ret->policy = opts->policy;
-  ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
-      ret->session->device_mgr, opts->session_options.options.env,
-      TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
-  ret->rendezvous =
-      new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
-
-  return ret;
+  return new TFE_Context(*opts, session);
 }

 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
@@ -261,15 +253,6 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,

 void TFE_DeleteOp(TFE_Op* op) { delete op; }

-static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
-                                  TF_Status* status) {
-  // Questionable heuristic: Place the op on the same device as the first input
-  // placed outside of host memory?
-  if (IsCPU(op->device) && !IsCPU(device)) {
-    op->device = device;
-  }
-}
-
 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
   tensorflow::Device* d = nullptr;
   if (device_name != nullptr && strlen(device_name) > 0) {
@@ -277,11 +260,24 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
     op->ctx->session->device_mgr->LookupDevice(device_name, &d);
     if (!status->status.ok()) return;
   }
-  TFE_OpSetDeviceHelper(op, d, status);
+  op->device = d;
 }

+const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
+  tensorflow::Device* device =
+      (op->device == nullptr) ? op->ctx->devices()[0] : op->device;
+  return device->name().c_str();
+}
+
 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
-  TFE_OpSetDeviceHelper(op, h->d, status);
+  // Questionable heuristic ...
+  //
+  // Motivation: After an 'op' is placed on GPU because some of its earlier
+  // inputs are on GPU, we want to keep the 'op' there, even if some later
+  // inputs of it are not on GPU.
+  if (IsCPU(op->device) && !IsCPU(h->d)) {
+    op->device = h->d;
+  }
   if (!status->status.ok()) return;
   op->inputs.push_back(h->t);
   op->input_devices.push_back(h->d);
@@ -298,7 +294,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
     return TF_ATTR_INT;  // The compiler requires that we return something.
   }
   status->status =
-      tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
+      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
   return ret;
 }
@@ -154,6 +154,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);

 TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
                                            TF_Status* status);
+// The returned string remains valid throughout the lifetime of 'op'.
+TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
+                                                  TF_Status* status);

 TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/stl_util.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/version.h"

 struct TFE_ContextOptions {
   TF_SessionOptions session_options;
@@ -43,9 +44,15 @@ struct TFE_ContextOptions {
 };

 struct TFE_Context {
-  explicit TFE_Context(TF_Session* s) : session(s) {}
+  explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
+      : policy(opts.policy),
+        session(s),
+        rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
+        pflr(new tensorflow::ProcessFunctionLibraryRuntime(
+            session->device_mgr, opts.session_options.options.env,
+            TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}

-  TFE_ContextDevicePlacementPolicy policy;
+  const TFE_ContextDevicePlacementPolicy policy;

   // Note: we cannot use C++11 thread_local here as there is no concept of a
   // thread-local-object-local variable in C++11.
@@ -54,8 +61,8 @@ struct TFE_Context {
       thread_local_policies GUARDED_BY(policy_map_mu);

   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
-  TF_Session* session;
-  tensorflow::Rendezvous* rendezvous;
+  TF_Session* const session;
+  tensorflow::Rendezvous* const rendezvous;

   tensorflow::mutex functions_mu;
   tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
@@ -64,14 +71,14 @@ struct TFE_Context {
   // One FunctionLibraryRuntime per device.
   // func_libs[i] is the FunctionLibraryRuntime corresponding to
   // session->devices[i].
-  std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
+  const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;

   tensorflow::mutex cache_mu;
   std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                      tensorflow::Fprint128Hasher>
       kernel_cache GUARDED_BY(cache_mu);

-  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
+  tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
     return pflr->GetFLR(d->name());
   }
@@ -100,6 +107,8 @@ struct TFE_TensorHandle {
 };

 struct TFE_Op {
+  // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
+  // primitive operation.
   TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
       : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}
@@ -60,6 +60,31 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
   return op;
 }

+// If there is a GPU device, returns true and sets 'gpu_device_name'
+// accordingly.
+bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+  CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  const int num_devices = TF_DeviceListCount(devices);
+  for (int i = 0; i < num_devices; ++i) {
+    const string device_type(TF_DeviceListType(devices, i, status.get()));
+    CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+    const string device_name(TF_DeviceListName(devices, i, status.get()));
+    CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
+    if (device_type == "GPU") {
+      *gpu_device_name = device_name;
+      LOG(INFO) << "Found GPU device " << device_name;
+      TF_DeleteDeviceList(devices);
+      return true;
+    }
+  }
+  TF_DeleteDeviceList(devices);
+  return false;
+}
+
 void BM_InitOp(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
@@ -288,22 +313,15 @@ TEST(CAPI, TensorHandleSilentCopy) {
   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  const int num_devices = TF_DeviceListCount(devices);
-
   // Disable the test if no GPU is present.
-  if (num_devices > 1) {
-    const int device_to_use = 1;
-    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    TFE_TensorHandle* hgpu =
-        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+  string gpu_device_name;
+  if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

     TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
-    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
     TFE_TensorHandle* retvals[1];
     int num_retvals = 1;
@@ -314,7 +332,6 @@ TEST(CAPI, TensorHandleSilentCopy) {
     TFE_DeleteTensorHandle(hgpu);
   }

-  TF_DeleteDeviceList(devices);
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
   TFE_DeleteContext(ctx, status.get());
@@ -337,22 +354,15 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
   TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

-  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
-  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
-  const int num_devices = TF_DeviceListCount(devices);
-
   // Disable the test if no GPU is present.
-  if (num_devices > 1) {
-    const int device_to_use = 1;
-    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
-    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
-
-    TFE_TensorHandle* hgpu =
-        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+  string gpu_device_name;
+  if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+    TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
+        hcpu, ctx, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

     TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
-    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
     ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
     TFE_TensorHandle* retvals[1];
     int num_retvals = 1;
@@ -363,13 +373,43 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
     TFE_DeleteTensorHandle(hgpu);
   }

-  TF_DeleteDeviceList(devices);
   TF_DeleteTensor(t);
   TFE_DeleteTensorHandle(hcpu);
   TFE_DeleteContext(ctx, status.get());
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }

+TEST(CAPI, SetAndGetOpDevices) {
+  TF_Status* status = TF_NewStatus();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_Context* ctx = TFE_NewContext(opts, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TFE_DeleteContextOptions(opts);
+
+  TFE_TensorHandle* m = TestMatrixTensorHandle();
+  TFE_Op* matmul = MatMulOp(ctx, m, m);
+
+  // Disable the test if no GPU is present.
+  string gpu_device_name;
+  if (GetGPUDeviceName(ctx, &gpu_device_name)) {
+    TFE_OpSetDevice(matmul, "GPU:0", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    const char* device_name = TFE_OpGetDevice(matmul, status);
+    ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr);
+
+    TFE_OpSetDevice(matmul, "CPU:0", status);
+    ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
+    device_name = TFE_OpGetDevice(matmul, status);
+    ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr);
+  }
+
+  TFE_DeleteOp(matmul);
+  TFE_DeleteTensorHandle(m);
+  TFE_DeleteContext(ctx, status);
+  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteStatus(status);
+}
+
 TEST(CAPI, Execute) {
   TF_Status* status = TF_NewStatus();
   TFE_ContextOptions* opts = TFE_NewContextOptions();
@@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
   return Status::OK();
 }

-Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
                       TF_AttrType* out, unsigned char* is_list) {
-  CHECK(m);
-  auto* t = gtl::FindOrNull(*m, attr_name);
+  auto* t = gtl::FindOrNull(m, attr_name);
   if (t == nullptr) {
     return errors::InvalidArgument("Attribute '", attr_name,
                                    "' does not exist for this operation");
@@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a,
   b->high64 += a.high64;
 }

-inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
+inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
                                             const tensorflow::Fprint128& b) {
   // TODO(agarwal): avoid ToString().
   tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
   return FingerprintCat128(a, b);
 }

-inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
+inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
   return CacheKeyHelper(s, {b, b});
 }
@@ -43,7 +43,7 @@ typedef std::unordered_map<string, uint32> AttrTypeMap;
 Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);

 // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
-Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
+Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
                       TF_AttrType* out, unsigned char* is_list);

 // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.
@@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) {

   TF_AttrType t;
   unsigned char is_list = 1;
-  s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
+  s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
   EXPECT_FALSE(s.ok());
   EXPECT_NE(is_list, 0);
-  s = AttrTypeByName(m, "transpose_a", &t, &is_list);
+  s = AttrTypeByName(*m, "transpose_a", &t, &is_list);
   ASSERT_TRUE(s.ok()) << s;
   EXPECT_EQ(TF_ATTR_BOOL, t);
   EXPECT_EQ(is_list, 0);

   s = AttrTypeMapForOp("Squeeze", &m);
   ASSERT_TRUE(s.ok()) << s;
-  s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
+  s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list);
   ASSERT_TRUE(s.ok()) << s;
   EXPECT_EQ(TF_ATTR_INT, t);
   EXPECT_NE(is_list, 0);
@@ -18,12 +18,12 @@ limitations under the License.
 // Language-agnostic gradient tape. Does not perform backpropagation, just
 // maintains the data structures required to do so.

-#include <unordered_map>
-#include <unordered_set>
 #include <vector>
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/types.h"

 namespace tensorflow {
@@ -54,11 +54,11 @@ struct OpTapeEntry {
 // Map from tensor_id to internally-defined operation-id of the operation which
 // produced this tensor. A value of -1 means that the tensor was directly
 // watched and not the result of any operation in the tape.
-using TensorTape = std::unordered_map<int64, int64>;
+using TensorTape = gtl::FlatMap<int64, int64>;

 // Map from operation-id to tape entry.
 template <typename BackwardFunction>
-using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;

 // Operations the tape needs to perform on tensors to do backpropagation. Named
 // "vspace" because a subset of these are related to a vector space, such as
@@ -159,7 +159,7 @@ class GradientTape {

   // Map from tensor id to number of remaining usages (i.e. how many entries in
   // the tape refer to it); to aid in tape garbage collection.
-  std::unordered_map<int64, int64> tensor_usage_;
+  gtl::FlatMap<int64, int64> tensor_usage_;

   // If false, all activations are deleted in the first call to ComputeGradient.
   // Else, only when this is destructed.
@@ -286,11 +286,11 @@ struct BackpropInitialState {

   // Map from tensor ID to how many references still exist for this tensor in
   // the tape.
-  std::unordered_map<int64, int64> tensor_usage_counts;
+  gtl::FlatMap<int64, int64> tensor_usage_counts;

   // Maps from op ID to how many output tensors of this op still need to have
   // their gradients computed.
-  std::unordered_map<int64, int64> op_missing_tensor;
+  gtl::FlatMap<int64, int64> op_missing_tensor;
 };

 // If `persistent_tape` is true, op_tape is not changed and none of the
@@ -301,8 +301,8 @@ struct BackpropInitialState {
 template <typename BackwardFunction>
 BackpropInitialState<BackwardFunction> PrepareBackprop(
     gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
-    OpTape<BackwardFunction>* op_tape,
-    const std::unordered_set<int64>& sources_set, bool persistent_tape) {
+    OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
+    bool persistent_tape) {
   std::vector<int64> tensor_stack;
   tensor_stack.reserve(target.size());
   for (auto t : target) {
@@ -362,7 +362,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
 template <typename BackwardFunction>
 std::vector<int64> InitialStack(
     const OpTape<BackwardFunction>& op_tape,
-    const std::unordered_map<int64, int64>& op_missing_tensor) {
+    const gtl::FlatMap<int64, int64>& op_missing_tensor) {
   std::vector<int64> result;
   for (auto& op_entry : op_tape) {
     if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
@@ -373,13 +373,13 @@ std::vector<int64> InitialStack(
 }

 template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(
-    const VSpace<Gradient, BackwardFunction>& vspace,
-    gtl::ArraySlice<int64> target_tensor_ids,
-    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
-    const OpTape<BackwardFunction>& op_tape,
-    const std::unordered_map<int64, int64>& tensor_usage_counts,
-    std::unordered_map<int64, std::vector<Gradient*>>* result) {
+Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
+                        gtl::ArraySlice<int64> target_tensor_ids,
+                        gtl::ArraySlice<Gradient*> output_gradients,
+                        const TensorTape& tensor_tape,
+                        const OpTape<BackwardFunction>& op_tape,
+                        const gtl::FlatMap<int64, int64>& tensor_usage_counts,
+                        gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
   for (int i = 0; i < target_tensor_ids.size(); ++i) {
     const int64 id = target_tensor_ids[i];
     if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
@@ -441,13 +441,13 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     gtl::ArraySlice<int64> source_tensor_ids,
     gtl::ArraySlice<Gradient*> output_gradients,
     std::vector<Gradient*>* result) {
-  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
-                                        source_tensor_ids.end());
+  gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
+                                  source_tensor_ids.end());
   BackpropInitialState<BackwardFunction> state = PrepareBackprop(
       target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
   std::vector<int64> op_stack =
       InitialStack(state.op_tape, state.op_missing_tensor);
-  std::unordered_map<int64, std::vector<Gradient*>> gradients;
+  gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
   Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
                               tensor_tape_, state.op_tape,
                               state.tensor_usage_counts, &gradients);
@@ -463,7 +463,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
     cleanup();
     return s;
   }
-  std::unordered_map<int64, int64> gradients_size;
+  gtl::FlatMap<int64, int64> gradients_size;
   // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
   // time, for better CPU backprop performance.
   VLOG(1) << "Initial stack:";
@@ -472,11 +472,10 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
       VLOG(1) << "  " << t;
     }
   }
-  std::unordered_map<string, std::unordered_set<int>>
-      functions_accept_none_for_indices({
-          {"SoftmaxCrossEntropyWithLogits", {1}},
-          {"FusedBatchNorm", {1, 2, 3, 4}},
-      });
+  gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
+      {"SoftmaxCrossEntropyWithLogits", {1}},
+      {"FusedBatchNorm", {1, 2, 3, 4}},
+  });
   while (!op_stack.empty()) {
     const int64 op = op_stack.back();
     VLOG(1) << "Popped " << op;
@@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc(
         "linalg_ops",
         "logging_ops",
         "lookup_ops",
+        "manip_ops",
        "math_ops",
        "nn_ops",
        "no_op",
@@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
 Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
                                 const SessionOptions& session_options,
                                 std::unique_ptr<Session>* session) {
-  session->reset(NewSession(session_options));
+  Session* session_p = nullptr;
+  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
+  session->reset(session_p);
   return (*session)->Create(meta_graph_def.graph_def());
 }
@@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
       << st.error_message();
 }

+TEST_F(LoaderTest, SessionCreationFailure) {
+  SavedModelBundle bundle;
+  // Use invalid SessionOptions to cause session creation to fail. Default
+  // options work, so provide an invalid value for the target field.
+  SessionOptions session_options;
+  constexpr char kInvalidTarget[] = "invalid target";
+  session_options.target = kInvalidTarget;
+  RunOptions run_options;
+
+  const string export_dir =
+      io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
+  Status st = LoadSavedModel(session_options, run_options, export_dir,
+                             {kSavedModelTagServe}, &bundle);
+  EXPECT_FALSE(st.ok());
+  EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget))
+      << st.error_message();
+}
+
 TEST_F(LoaderTest, PbtxtFormat) {
   SavedModelBundle bundle;
   SessionOptions session_options;
@@ -23,7 +23,6 @@ cc_library(
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:tensorflow",
    ],
 )
@@ -4,7 +4,7 @@

 To use from your BUILD file, add the following line to load the macro:

-load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")

 Then call the macro like this:

@@ -16,14 +16,15 @@ tf_library(
 )
 """

-load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
+load("//tensorflow:tensorflow.bzl",
+     "if_android", "tf_cc_test", "tf_copts")

 def tf_library(name, graph, config,
                freeze_checkpoint=None, freeze_saver=None,
                cpp_class=None, gen_test=True, gen_benchmark=True,
                visibility=None, testonly=None,
                tfcompile_flags=None,
-               tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile",
+               tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
                include_standard_runtime_deps=True, deps=None, tags=None):
   """Runs tfcompile to compile a TensorFlow graph into executable code.

@@ -119,9 +120,9 @@ def tf_library(name, graph, config,
             out_nodes_file,
         ] + freeze_saver_srcs,
         outs=[freeze_file],
-        cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" +
+        cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
             freeze_args),
-        tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"],
+        tools=["//tensorflow/python/tools:freeze_graph"],
         tags=tags,
     )
     tfcompile_graph = freeze_file
@@ -213,22 +214,22 @@ def tf_library(name, graph, config,
          # These deps are required by all tf_library targets even if
          # include_standard_runtime_deps is False. Without them, the
          # generated code will fail to compile.
-         "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
-         "@org_tensorflow//tensorflow/core:framework_lite",
+         "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
+         "//tensorflow/core:framework_lite",
      ] + (need_xla_data_proto and [
          # If we're generating the program shape, we must depend on the proto.
-         "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
+         "//tensorflow/compiler/xla:xla_data_proto",
      ] or []) + (include_standard_runtime_deps and [
          # TODO(cwhipkey): only depend on kernel code that the model actually needed.
-         "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
-         "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
-         "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
+         "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
+         "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
+         "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
+         "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
+         "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
+         "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+         "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
+         "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
+         "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
          "//third_party/eigen3",
      ] or []) + (deps or []),
      tags=tags,
@@ -254,28 +255,32 @@ def tf_library(name, graph, config,
         name=("gen_" + test_name),
         testonly=1,
         srcs=[
-            "@org_tensorflow//tensorflow/compiler/aot:test.cc",
+            "//tensorflow/compiler/aot:test.cc",
             header_file,
         ],
         outs=[test_file],
         cmd=("sed " + sed_replace +
-             " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " +
+             " $(location //tensorflow/compiler/aot:test.cc) " +
              "> $(OUTS)"),
         tags=tags,
     )

-    # The cc_test rule for the generated code.
-    native.cc_test(
+    # The cc_test rule for the generated code. To ensure that this works
+    # reliably across build configurations, we must use tf_cc_test instead of
+    # native.cc_test. This is related to how we build
+    # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
+    # for more details.
+    tf_cc_test(
         name=test_name,
         srcs=[test_file],
         deps=[
             ":" + name,
-            "@org_tensorflow//tensorflow/compiler/aot:runtime",
-            "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main",
-            "@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
+            "//tensorflow/compiler/aot:runtime",
+            "//tensorflow/compiler/aot:tf_library_test_main",
+            "//tensorflow/compiler/xla:executable_run_options",
             "//third_party/eigen3",
-            "@org_tensorflow//tensorflow/core:lib",
-            "@org_tensorflow//tensorflow/core:test",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:test",
         ],
         tags=tags,
     )
@@ -283,7 +288,7 @@ def tf_library(name, graph, config,
   if gen_benchmark:
     benchmark_name = name + "_benchmark"
     benchmark_file = benchmark_name + ".cc"
-    benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" +
+    benchmark_main = ("//tensorflow/compiler/aot:" +
                       "benchmark_main.template")

     # Rule to rewrite benchmark.cc to produce the benchmark_file.
@@ -301,7 +306,9 @@ def tf_library(name, graph, config,
         tags=tags,
     )

-    # The cc_benchmark rule for the generated code.
+    # The cc_benchmark rule for the generated code. This does not need the
+    # tf_cc_binary since we (by deliberate design) do not depend on
+    # //tensorflow/core:lib.
     #
     # Note: to get smaller size on android for comparison, compile with:
     # --copt=-fvisibility=hidden
@@ -315,12 +322,12 @@ def tf_library(name, graph, config,
         linkopts = if_android(["-pie", "-s"]),
         deps=[
             ":" + name,
-            "@org_tensorflow//tensorflow/compiler/aot:benchmark",
-            "@org_tensorflow//tensorflow/compiler/aot:runtime",
-            "@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
+            "//tensorflow/compiler/aot:benchmark",
+            "//tensorflow/compiler/aot:runtime",
+            "//tensorflow/compiler/xla:executable_run_options",
             "//third_party/eigen3",
         ] + if_android([
-            "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android",
+            "//tensorflow/compiler/aot:benchmark_extra_android",
         ]),
         tags=tags,
     )
@@ -330,11 +337,11 @@ def target_llvm_triple():
   # TODO(toddw): Add target_triple for other targets. For details see:
   # http://llvm.org/docs/doxygen/html/Triple_8h_source.html
   return select({
-      "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android",
-      "@org_tensorflow//tensorflow:android_arm": "armv7-none-android",
-      "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android",
-      "@org_tensorflow//tensorflow:android_x86": "i686-none-android",
-      "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
-      "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin",
+      "//tensorflow:android_armeabi": "armv5-none-android",
+      "//tensorflow:android_arm": "armv7-none-android",
+      "//tensorflow:android_arm64": "aarch64-none-android",
+      "//tensorflow:android_x86": "i686-none-android",
+      "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
+      "//tensorflow:darwin": "x86_64-none-darwin",
       "//conditions:default": "x86_64-pc-linux",
  })
@ -30,12 +30,14 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_def_builder.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
@ -141,8 +143,7 @@ struct NodeSlot {
|
||||
// everything to use it.
|
||||
static const char* const kArgOp = "_Arg";
|
||||
static const char* const kRetValOp = "_Retval";
|
||||
static const char* const kSendToHostOp = "_XlaSendToHost";
|
||||
static const char* const kRecvFromHostOp = "_XlaRecvFromHost";
|
||||
static const char* const kHostComputeOp = "_XlaHostCompute";
|
||||
static const char* const kSendFromHostOp = "_XlaSendFromHost";
|
||||
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
|
||||
|
||||
@ -171,7 +172,8 @@ class Encapsulator {
|
||||
|
||||
// Write a copy of the input graph to 'graph_out', where the subgraphs are
|
||||
// replaced with calls to the new functions.
|
||||
Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
|
||||
Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
|
||||
FunctionLibraryDefinition* library);
|
||||
|
||||
private:
|
||||
// A subgraph of the input, all marked with a common 'group_attribute'
|
||||
@ -201,21 +203,29 @@ class Encapsulator {
|
||||
// .. .
|
||||
// RAH --> C --> SFH
|
||||
//
|
||||
// The compiled cluster is as follows. STH is a SendToHost node which is the
|
||||
// source of a channel to the RAH node above. RFH is a RecvFromHost node which
|
||||
// is the destination of a channel from the SFH node above. There is a control
|
||||
// edge that ensures RFH follows STH, which is used in shape inference to
|
||||
// ensure that the shapes on the STH host channel are known before the RFH
|
||||
// channel is compiled.
|
||||
// The compiled cluster is as follows. HC is a HostCompute node which is the
|
||||
// source of a channel to the RAH node above and the destination of a channel
|
||||
// from the SFH node above.
|
||||
//
|
||||
// Arg --> B --> STH ..> RFH --> D --> Retval
|
||||
// Arg --> B --> HC --> D --> Retval
|
||||
//
|
||||
// The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most
|
||||
// one RAH and SFH in each compiled cluster. This design is preferred over
|
||||
// adding separate Arg/Retval nodes for each transmitted value because it
|
||||
// simplifies the host code that would like to limit communication between
|
||||
// host and device and, e.g., raise only one interrupt per channel rather than
|
||||
// one per transmitted value.
|
||||
// The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
|
||||
// at most one RAH and SFH in each outside_compilation cluster. This design is
|
||||
// preferred over adding separate Arg/Retval nodes for each transmitted value
|
||||
// because it allows optimizations to the host code that would like to limit
|
||||
// communication between host and device and, e.g., raise only one interrupt
|
||||
// per channel rather than one per transmitted value.
|
||||
//
|
||||
// The shapes of the outputs from the HC node in general cannot be determined
|
||||
// until the shapes of its inputs are known at compile time, since e.g.,
// above, the shapes of C's outputs aren't known until the shapes of its inputs
// are known. If the shapes of the HC's outputs can be determined during the
// rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal
// graph is stored in the shape_inference_graph attr. This graph can be used
// when compiling the HC Op to determine the shape of the SFH inputs given
// the shapes of any ancestor RAH outputs. If it can be determined that the
// shape of the SFH inputs will not be inferrable even once the shapes of the
// RAH outputs are known, an error is returned by the rewriter.
class Subgraph {
public:
// Creates a graph to build the subgraph in, if it doesn't already exist,
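For reference, the interface described in the comment above reduces to a single op whose attrs carry the transmitted types and the channel key. A minimal C++ sketch of its registration, mirroring the test-only registration that appears later in this diff (the production registration lives elsewhere in the XLA bridge):

// Sketch only: the attrs list the input/output dtypes and "key" names the
// host channel. The actual output shapes come from the "shapes" or
// "shape_inference_graph" attrs that the rewriter adds.
REGISTER_OP("_XlaHostCompute")
    .Input("inputs: Tinputs")
    .Output("outputs: Toutputs")
    .Attr("Tinputs: list(type) >= 0")
    .Attr("Toutputs: list(type) >= 0")
    .Attr("key: string")
    .SetShapeFn(::tensorflow::shape_inference::UnknownShape);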
@@ -246,6 +256,10 @@ class Encapsulator {
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out);

// Returns the names of all the outside_compilation subgraphs in this
// Subgraph.
void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;

// Returns the Node that inputs to the function should be wired up to.
Node* GetCallNodeForInputs() const;

@@ -305,15 +319,9 @@ class Encapsulator {
void RecordOutsideCompilationOutputOrControl(
const string& outside_compilation_id, const Edge* edge);

// Adds the SendToHost nodes for each outside_compilation subgraph once the
// edges have all been recorded via RecordOutsideCompilationInputOrControl.
Status AddSendsToOutsideCompilation(
const std::unordered_map<const Node*, Node*>& node_images);

// Adds the RecvFromHost nodes for each outside_compilation subgraph once
// the edges have all been recorded via
// RecordOutsideCompilationOutputOrControl.
Status AddRecvsFromOutsideCompilation(
// Adds the HostCompute nodes for each outside_compilation subgraph.
Status AddHostComputes(
const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images);

// Creates the sequencer node if it doesn't exist, adding it to graph_out.
@@ -323,10 +331,16 @@ class Encapsulator {
// all the downstream nodes of call_node_outputs.
void ConnectSequencerToOutputs(Graph* graph_out);

Status AddShapeInferenceInfo(
const string& outside_compilation_subgraph_name,
const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);

Status ReplaceFunctionDef(FunctionLibraryDefinition* library);

private:
struct OutsideCompilationSubgraph {
// Map from source (producer node/slot) tensors in the original graph to
// input index (slot number in the SendToHost/RecvAtHost nodes that will
// input index (slot number in the HostCompute/RecvAtHost nodes that will
// be created) for the outside_compilation subgraph.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;

@@ -335,14 +349,14 @@ class Encapsulator {
// outside_compilation subgraph. These are recorded by
// RecordOutsideCompilationInputOrControl while walking all the subgraph
// edges, and lifted control edges within the subgraph are added by
// AddSendsToOutsideCompilation once the _SendToHost node has been
// AddSendsToOutsideCompilation once the _HostCompute node has been
// created. The matching control edge from _RecvAtHost to the
// destination is added by CopyEdgeToOutputGraph.
std::unordered_set<const Node*> control_inputs;

// Maps from source (producer node/slot) and destination (consumer
// node/slot) tensors in the original graph to output index (slot number
// in the SendFromHost/RecvFromHost nodes that will be created) for the
// in the SendFromHost/HostCompute nodes that will be created) for the
// outside_compilation subgraph.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
@@ -352,13 +366,13 @@ class Encapsulator {
// containing compiled subgraph. These are recorded by
// RecordOutsideCompilationOutputOrControl while walking all the subgraph
// edges, and lifted control edges within the subgraph are added by
// AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been
// AddRecvsFromToOutsideCompilation once the _HostCompute node has been
// created. The matching control edge from the source to _SendFromHost to
// the destination is added by CopyEdgeToOutputGraph.
std::unordered_set<const Node*> control_outputs;

// _SendToHost node in the subgraph. Not owned.
Node* send_to_host = nullptr;
// Name of the _HostCompute node in the subgraph.
string host_compute_name;

// _RecvAtHost node in the output graph. Not owned.
Node* recv_at_host = nullptr;
@@ -516,6 +530,59 @@ class Encapsulator {
const std::unordered_map<const Node*, Node*>& node_images,
bool parallel_checking, Graph* graph_out);

// Constructs a minimal shape inference graph that can be used to determine
// the shape of send_node at the time that the subgraph is compiled.
// recv_at_host_nodes contains the names of all the recv_at_host nodes that
// send_node might depend on. These recv_at_host nodes have shapes that are
// not known during the rewrite pass, but will be known at compile time.
//
// If the shapes of all the inputs to send_node can be determined during the
// rewrite pass, on exit graphdef_out is empty and the shapes are returned in
// static_shape_out. Otherwise graphdef_out contains a graph that can be used
// for shape inference at compile time, where all the source nodes of the
// graph are either constants with known shapes, or nodes named in
// recv_at_host_nodes.
//
// A non-OK status is returned if neither of the above conditions can be
// satisfied, e.g., because send_node depends on a node that doesn't have a
// registered shape inference function.
Status DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
std::unique_ptr<GraphDef>* graphdef_out);

// Makes a copy of graph containing only nodes that are ancestors of at least
// one node in send_from_host_nodes and stores it in pruned_graph. On exit
// node_images contains a mapping from nodes in graph to nodes in
// pruned_graph. All functions in the copied graph are inlined.
Status MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);

// Makes a copy of graph containing only nodes that are ancestors of a
// send_from_host node in an outside_compilation subgraph, and stores it in
// pruned_graph. Also performs shape inference on the pruned graph, using
// shape_refiner. On exit node_images contains a mapping from nodes in graph
// to nodes in pruned_graph.
Status MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);

// Performs static shape inference, as far as possible, for the send_from_host
// nodes in each outside_compilation subgraph. Where it is not possible to
// determine the shape statically, stores a serialized GraphDef in the
// HostCompute 'shape_inference_graph' attr, to be used at compile time for
// final inference. If the shapes are known statically they are stored in the
// HostCompute 'shapes' attr.
Status GetShapeInfoForOutsideCompilationSends(
Graph* graph_out, FunctionLibraryDefinition* library);

const string group_attribute_;
const string outside_compilation_attribute_;
const Graph* graph_in_;
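A minimal usage sketch of the shape-inference contract declared above, taken from its use in GetShapeInfoForOutsideCompilationSends later in this file: exactly one of the two outputs is populated, and the caller branches on which.

std::vector<TensorShapeProto> static_shape;
std::unique_ptr<GraphDef> graphdef;
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
    *pruned_graph, shape_refiner, recv_at_host_names, send_node, library,
    &static_shape, &graphdef));
if (graphdef == nullptr) {
  // Every input shape was statically known; static_shape holds them.
} else {
  // graphdef holds a minimal graph for shape inference at compile time.
}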
@@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
}
}

Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
Status Encapsulator::Subgraph::AddHostComputes(
const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images) {
for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
const string& oc_subgraph_name = oc_subgraph_iter.first;
OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
// Build a _SendToHost node sending all the args of the appropriate
// types.
std::vector<DataType> dtypes(oc_subgraph.inputs.size(), DT_INVALID);
if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
!oc_subgraph.outputs_by_src.empty() ||
!oc_subgraph.control_outputs.empty()) {
// Build a _HostCompute node.
std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
DT_INVALID);

for (const auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
@@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
int input_index = input_src.second;

DataType dtype = src_node->output_type(src_slot);
dtypes[input_index] = dtype;
inputs[input_index].Reset(src_image->name(), src_slot, dtype);
input_dtypes[input_index] = dtype;
}

NodeDef send_def;
NodeDefBuilder builder(
strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"),
kSendToHostOp);
builder.Attr("dtypes", dtypes);
for (const auto& output : oc_subgraph.outputs_by_src) {
DataType dtype = output.first.dtype;
int output_index = output.second;
output_dtypes[output_index] = dtype;
}

NodeDef host_compute_def;
NodeDefBuilder builder(strings::StrCat("outside_compilation_",
oc_subgraph_name, "_host_compute"),
kHostComputeOp);
builder.Input(inputs);
Status s = builder.Finalize(&send_def);
builder.Attr("Tinputs", input_dtypes);
builder.Attr("Toutputs", output_dtypes);
builder.Attr("key",
strings::StrCat("host_compute_channel_", subgraph_name, "_",
oc_subgraph_name));
Status s = builder.Finalize(&host_compute_def);
if (!s.ok()) return s;

oc_subgraph.send_to_host = graph_->AddNode(send_def, &s);
Node* host_compute = graph_->AddNode(host_compute_def, &s);
if (!s.ok()) return s;
oc_subgraph.host_compute_name = host_compute->name();

// Connect the _SendToHost node to its producers in the subgraph.
// Connect the _HostCompute node to its producers in the subgraph.
for (auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
int src_slot = input_src.first.slot;
int input_index = input_src.second;
graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host,
input_index);
graph_->AddEdge(src_image, src_slot, host_compute, input_index);
}

// Connect the _SendToHost node to its control edge producers in the
// Connect the _HostCompute node to its control edge producers in the
// subgraph.
for (const auto& src_node : oc_subgraph.control_inputs) {
Node* src_image = node_images.at(src_node);
graph_->AddControlEdge(src_image, oc_subgraph.send_to_host);
}
}
}

return Status::OK();
}

Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation(
const std::unordered_map<const Node*, Node*>& node_images) {
for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
const string& oc_subgraph_name = oc_subgraph_iter.first;
OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
if (!oc_subgraph.outputs_by_src.empty() ||
!oc_subgraph.control_outputs.empty()) {
// Build a _RecvFromHost node producing all the outputs of the appropriate
// types.
std::vector<DataType> dtypes(oc_subgraph.outputs_by_src.size(),
DT_INVALID);

for (const auto& output : oc_subgraph.outputs_by_src) {
DataType dtype = output.first.dtype;
int output_index = output.second;
dtypes[output_index] = dtype;
graph_->AddControlEdge(src_image, host_compute);
}

NodeDef recv_def;
NodeDefBuilder builder(
strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"),
kRecvFromHostOp);
builder.Attr("dtypes", dtypes);
Status s = builder.Finalize(&recv_def);
if (!s.ok()) return s;

Node* recv = graph_->AddNode(recv_def, &s);
if (!s.ok()) return s;

// Connect the consumers in the subgraph to the _RecvFromHost node.
// Connect the consumers in the subgraph to the _HostCompute node.
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
Node* dst_image = node_images.at(dst_node);
int dst_slot = output.first.slot;
int output_index = output.second;

graph_->AddEdge(recv, output_index, dst_image, dst_slot);
graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
}

// Connect the control edge consumers in the subgraph to the _RecvFromHost
// Connect the control edge consumers in the subgraph to the _HostCompute
// node.
for (const auto& dst_node : oc_subgraph.control_outputs) {
Node* dst_image = node_images.at(dst_node);
graph_->AddControlEdge(recv, dst_image);
}

// Add a control edge in the subgraph so that the _SendToHost node, if
// any, is compiled before the _RecvFromHost node.
if (oc_subgraph.send_to_host != nullptr) {
graph_->AddControlEdge(oc_subgraph.send_to_host, recv);
graph_->AddControlEdge(host_compute, dst_image);
}
}
}
@@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
return Status::OK();
}

Status Encapsulator::Subgraph::AddShapeInferenceInfo(
const string& outside_compilation_subgraph_name,
const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
OutsideCompilationSubgraph& oc_subgraph =
outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);

Node* host_compute = nullptr;
for (Node* n : graph_->nodes()) {
if (n->name() == oc_subgraph.host_compute_name) {
host_compute = n;
break;
}
}
if (host_compute == nullptr) {
return errors::InvalidArgument(
"After rewriting subgraph ", outside_compilation_subgraph_name,
" there is no HostCompute Op for outside compilation subgraph ",
oc_subgraph.host_compute_name);
}

if (inference_graph == nullptr) {
host_compute->AddAttr("shape_inference_graph", "");
host_compute->AddAttr("shapes", shapes);
} else {
string serialized_graph;
if (!inference_graph->SerializeToString(&serialized_graph)) {
return errors::Internal(
"Failed to serialize graph for outside compilation subgraph ",
oc_subgraph.host_compute_name);
}
host_compute->AddAttr("shape_inference_graph", serialized_graph);
host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
}
return Status::OK();
}
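As a consumer-side sketch of the attrs written above (an assumption rather than code from this change: it presumes GetNodeAttr overloads mirroring the AddAttr calls, and treats an empty "shape_inference_graph" attr as meaning the "shapes" attr is authoritative):

string serialized_graph;
TF_RETURN_IF_ERROR(GetNodeAttr(host_compute->attrs(), "shape_inference_graph",
                               &serialized_graph));
if (serialized_graph.empty()) {
  std::vector<TensorShapeProto> shapes;  // statically-known shapes
  TF_RETURN_IF_ERROR(GetNodeAttr(host_compute->attrs(), "shapes", &shapes));
} else {
  GraphDef shape_inference_graph;  // run shape inference over this later
  if (!shape_inference_graph.ParseFromString(serialized_graph)) {
    return errors::Internal("Failed to parse shape inference graph");
  }
}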

Status Encapsulator::Subgraph::ReplaceFunctionDef(
FunctionLibraryDefinition* library) {
const string& name = call_node_def_.name();

FunctionDef fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));

if (VLOG_IS_ON(1)) {
VLOG(2) << "Replace function def " << name;
dump_graph::DumpGraphToFile(
strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
library);
dump_graph::DumpFunctionDefToFile(
strings::StrCat("replace_encapsulate_fdef_", name), fdef);
}

TF_RETURN_IF_ERROR(library->RemoveFunction(name));
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
return Status::OK();
}

Status Encapsulator::Subgraph::BuildParallelCheckOp(
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out) {
@@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
builder.Attr("dtypes", dtypes);
builder.Attr("Toutputs", dtypes);
builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
"_", oc_subgraph_name));
Status s = builder.Finalize(&recv_def);
if (!s.ok()) return s;

@@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_send"),
kSendFromHostOp);
builder.Attr("dtypes", dtypes);
builder.Attr("Tinputs", dtypes);
builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
"_", oc_subgraph_name));
builder.Input(inputs);
Status s = builder.Finalize(&send_def);
if (!s.ok()) return s;
@@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
return Status::OK();
}

void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
std::vector<string>* names) const {
for (auto& entry : outside_compilation_subgraphs_) {
names->push_back(entry.first);
}
}

Status Encapsulator::GetFunctionNameAttr(
Node const* node, string* attr, string* outside_compilation_attr) const {
Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
@@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
// single input and output node for it.
for (auto& entry : subgraphs_) {
Subgraph& subgraph = entry.second;
TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images));
TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images));
TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
}

MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
@@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph(
return Status::OK();
}

Status Encapsulator::BuildOutputGraph(bool parallel_checking,
Graph* graph_out) {
namespace {

// Adds a dummy Const node to graph_out. The "constant" has the type of
// data_type and the shape indicated in 'shape'. The dummy node is not a valid
// Const node because it does not have any value defined, but this doesn't
// matter because it will only be used subsequently for shape inference. (It
// would be possible to add a switch statement over data_type to create a value
// for the constant, but that would entail maintaining the logic as new types
// are added, and is not necessary.)
Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
Graph* graph_out) {
TensorProto dummy_proto;
dummy_proto.set_dtype(data_type);
*dummy_proto.mutable_tensor_shape() = shape;
// Don't set any value field in the proto, since it is only going to be used
// for shape inference.

GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
options.op_registry());
node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
return options.FinalizeBuilder(&node_builder);
}
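Usage sketch for AddDummyShapedNode, taken from DoStaticShapeInferenceForOutsideCompilationSend below: when a traversed edge already has a fully defined shape, its producer is replaced in the shape inference graph by one of these dummy constants.

// 'context' is the producer's InferenceContext and 'shape' is fully defined.
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
dummy_node_images[src_node] = AddDummyShapedNode(
    src_node->output_type(src_port), proto, graph_out.get());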
// Adds a copy of node_in to graph_out and adds the mapping to
// copied_node_images.
Status CopyShapeInferenceNodeToGraph(
Node* node_in, const Node* send_node,
const std::unordered_map<Node*, Node*>& dummy_node_images,
FunctionLibraryDefinition* library,
std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
// Once all the ancestor nodes have been added to graph_out, add this node
// and connect it to its ancestors.
Node* node_out = graph_out->CopyNode(node_in);
(*copied_node_images)[node_in] = node_out;
// Don't bother to build the shape inference graph if there's a node with no
// shape inference function, since it would just result in an error later at
// compile time.
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(
"Shape inference is not possible for outside_compilation "
"SendFromHost node ",
send_node->name(), " because it depends on node ", node_in->name(),
" which does not have a shape inference function registered.");
}
// Add all the edges to the newly copied node.
for (const Edge* in_edge : node_in->in_edges()) {
if (!in_edge->IsControlEdge()) {
Node* src = in_edge->src();
const auto iter = dummy_node_images.find(src);
if (iter == dummy_node_images.end()) {
// The src is a copied node so use the original output port.
graph_out->AddEdge((*copied_node_images)[in_edge->src()],
in_edge->src_output(), node_out,
in_edge->dst_input());
} else {
// The src is a dummy node so use output port 0.
graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
}
}
}
return Status::OK();
}

} // namespace

Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
std::unique_ptr<GraphDef>* graphdef_out) {
// Maps from nodes in graph_in to nodes in graph_out.
//
// When an edge has fully defined shape the source node in graph_in is
// replaced in graph_out by a dummy constant node. The mapping from nodes
// in graph_in to dummy nodes is stored in dummy_node_images.
//
// When a node in graph_in has at least one ancestor that doesn't have fully
// defined shape, it is copied into graph_out. The mapping from nodes in
// graph_in to copied nodes is stored in copied_node_images.
//
// The two types of node are treated differently because, when adding edges to
// graph_out, an output from a dummy node always uses port 0, whereas an
// output from a copied node uses the same port that was used in graph_in.
std::unordered_map<Node*, Node*> dummy_node_images;
std::unordered_map<Node*, Node*> copied_node_images;

std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
graph_out->set_versions(graph_in.versions());
static_shape_out->resize(send_node->num_inputs());

// We don't use the standard ReverseDFS because we want to cut off traversal
// whenever we find an output with fully defined shape.
// TODO(misard) make this work properly in the presence of control flow.
struct Work {
Node* node;
bool leave; // Are we entering or leaving node?
};
std::vector<Work> stack({{send_node, false}});
std::vector<bool> visited(graph_in.num_node_ids(), false);
while (!stack.empty()) {
Work w = stack.back();
stack.pop_back();
Node* n = w.node;

if (w.leave) {
TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
n, send_node, dummy_node_images, library, &copied_node_images,
graph_out.get()));
} else {
if (visited[n->id()]) continue;
visited[n->id()] = true;

// Arrange to revisit when all done with all inputs.
stack.push_back(Work{n, true});

bool has_parent_with_unknown_shape = false;
for (const Edge* in_edge : n->in_edges()) {
if (!in_edge->IsControlEdge()) {
Node* src_node = in_edge->src();
int src_port = in_edge->src_output();
shape_inference::InferenceContext* context =
shape_refiner.GetContext(src_node);
shape_inference::ShapeHandle shape = context->output(src_port);
if (context->FullyDefined(shape)) {
// This ancestor has known shape, so instead of adding it to the
// stack, add a dummy node with that shape to graph_out and
// continue.
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
dummy_node_images[src_node] = AddDummyShapedNode(
src_node->output_type(src_port), proto, graph_out.get());
if (n == send_node) {
(*static_shape_out)[in_edge->dst_input()] = proto;
}
} else {
if (!visited[src_node->id()]) {
has_parent_with_unknown_shape = true;
stack.push_back({src_node, false});
}
}
}
}
if (!has_parent_with_unknown_shape) {
if (n == send_node) {
// The shapes of all the inputs to send_node are statically known. We
// won't have to do any inference at compile time so return now: the
// shapes were stored in static_shape_out above.
graphdef_out->reset();
return Status::OK();
} else {
// Any node that is being processed either is the original send node
// or has at least one output with statically-unknown shape. If the
// latter and it doesn't have any inputs with statically-unknown
// shape, then check that it is one of the recv nodes whose shape we
// can fill in at compile time later. If it isn't one of those, then we
// won't have any additional knowledge at compile time, so we already
// know we won't be able to do shape inference and we can return an
// error now.
if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
return errors::InvalidArgument(
"Shape inference is not possible for outside_compilation "
"SendFromHost node ",
send_node->name(), " because shape of node ", n->name(),
" will not be known at compilation time.");
}
}
}
}
}

graphdef_out->reset(new GraphDef());
graph_out->ToGraphDef(graphdef_out->get());

return Status::OK();
}

Status Encapsulator::MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// First copy all ancestor nodes of sink_nodes into a new graph.
pruned_graph->reset(new Graph(library));
(*pruned_graph)->set_versions(graph.versions());
ReverseDFSFrom(graph, sink_nodes,
/*enter=*/nullptr,
/*leave=*/[&](Node* n) {
if (!n->IsSource()) {
Node* copied = (*pruned_graph)->CopyNode(n);
node_images->emplace(n, copied);
}
});

// Add all the edges between copied nodes.
for (auto entry : *node_images) {
const Node* orig = entry.first;
Node* image = entry.second;
for (const Edge* out_edge : orig->out_edges()) {
auto iter = node_images->find(out_edge->dst());
if (iter != node_images->end()) {
// The source and destination are both in the copied graph.
(*pruned_graph)
->AddEdge(image, out_edge->src_output(), iter->second,
out_edge->dst_input());
}
}
}

// Find all the function call nodes, and inline them.
std::vector<Node*> function_nodes;
for (auto node : (*pruned_graph)->nodes()) {
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
if (op_reg_data->is_function_op) {
function_nodes.push_back(node);
}
}
for (auto node : function_nodes) {
VLOG(2) << "Inlining function " << node->name();
const FunctionDef* fdef = library->Find(node->type_string());
if (fdef == nullptr) {
return errors::Internal("Failed to find function ", node->type_string(),
" in function library.");
}
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, node->attrs(), library,
[library](const string& op, const OpDef** sig) {
return library->LookUpOpDef(op, sig);
},
&fbody));
InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
delete fbody;
}

return Status::OK();
}

Status Encapsulator::MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// Find all the send_from_host nodes in all subgraphs, to use as roots for the
// pruning.
std::vector<Node*> send_from_host_nodes;
for (auto& subgraph_entry : subgraphs_) {
Subgraph& subgraph = subgraph_entry.second;
std::vector<string> outside_compilation_names;
subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
for (const auto& name : outside_compilation_names) {
Node* send_node = subgraph.GetSendFromHostNode(name);
if (send_node != nullptr) {
send_from_host_nodes.push_back(send_node);
}
}
}

// Make a copy of all the graph nodes needed to evaluate the send_from_host
// nodes, inlining any functions as needed.
TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
graph, send_from_host_nodes, pruned_graph, node_images, library));

// Perform shape inference on the pruned graph.
shape_refiner->set_require_shape_inference_fns(false);
FixupSourceAndSinkEdges(pruned_graph->get());
std::vector<Node*> post_order;
GetReversePostOrder(*(*pruned_graph), &post_order);
for (auto node : post_order) {
// Ignore the status returned by the shape_refiner. At this point we want
// the best effort shapes, even if no shape function is registered for a
// node.
Status status = shape_refiner->AddNode(node);
if (!status.ok()) {
VLOG(1) << "Shape inference failed for node: " << status;
}
}

return Status::OK();
}

Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
Graph* graph_out, FunctionLibraryDefinition* library) {
std::unique_ptr<Graph> pruned_graph;
ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
std::unordered_map<const Node*, Node*> node_images;
TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
*graph_out, &pruned_graph, &shape_refiner, &node_images, library));

for (auto& subgraph_entry : subgraphs_) {
Subgraph& subgraph = subgraph_entry.second;
// Find all the recv_at_host nodes in this subgraph.
std::vector<string> outside_compilation_names;
subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
std::unordered_set<string> recv_at_host_names;
for (const auto& name : outside_compilation_names) {
Node* recv_node = subgraph.GetRecvAtHostNode(name);
if (recv_node != nullptr) {
recv_at_host_names.insert(recv_node->name());
}
}
// For each send_from_host node, do as much shape inference as possible
// without knowing the shape of the recv_at_host nodes, and store the
// result, along with enough information to complete the job at compile time
// once the recv_at_host shapes are known.
for (const auto& name : outside_compilation_names) {
Node* send_node = subgraph.GetSendFromHostNode(name);
std::vector<TensorShapeProto> static_shape;
std::unique_ptr<GraphDef> graphdef;
if (send_node != nullptr) {
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
*pruned_graph, shape_refiner, recv_at_host_names,
node_images[send_node], library, &static_shape, &graphdef));
if (graphdef == nullptr) {
VLOG(2) << "Send node " << send_node->name() << " shapes";
for (int i = 0; i < static_shape.size(); ++i) {
VLOG(2) << static_shape[i].DebugString();
}
} else {
VLOG(2) << "Send node " << send_node->name() << " graph\n"
<< graphdef->DebugString();
}
}
TF_RETURN_IF_ERROR(
subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
}
if (!outside_compilation_names.empty()) {
TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
}
}

return Status::OK();
}

Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
FunctionLibraryDefinition* library) {
// Map from nodes in the input graph to nodes in the output graph.
std::unordered_map<const Node*, Node*> node_images;

@@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
TF_RETURN_IF_ERROR(
AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));

TF_RETURN_IF_ERROR(
GetShapeInfoForOutsideCompilationSends(graph_out, library));

return Status::OK();
}

@@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions(
std::unique_ptr<Graph> out(new Graph(library));
out->set_versions(graph_in.versions());
TF_RETURN_IF_ERROR(
encapsulator.BuildOutputGraph(parallel_checking, out.get()));
encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));

*graph_out = std::move(out);
return Status::OK();

@@ -29,17 +29,181 @@ limitations under the License.
namespace tensorflow {
namespace {

bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
// TODO(phawkins) use a more sophisticated equality test.
if (a.DebugString() != b.DebugString()) {
template <class Tkey, class Tvalue>
bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
const std::function<string(const Tkey&)>& key_to_string,
const std::function<string(const Tvalue&)>& value_to_string,
const std::function<bool(const Tkey&, const Tvalue&,
const Tvalue&)>& compare,
const string& map_name, string* diff) {
for (const auto& elt_a : a) {
const auto iter = b.find(elt_a.first);
if (iter == b.end()) {
if (diff) {
*diff = strings::StrCat(
map_name, " expected: contains element with key '",
key_to_string(elt_a.first), "' got: map has no such element");
}
return false;
}
if (!compare(elt_a.first, elt_a.second, iter->second)) {
if (diff) {
*diff = strings::StrCat(map_name, " expected: element with key '",
key_to_string(elt_a.first), " has value '",
value_to_string(elt_a.second), "' got: '",
value_to_string(iter->second), "'");
}
return false;
}
}
for (const auto& elt_b : b) {
const auto iter = a.find(elt_b.first);
if (iter == a.end()) {
if (diff) {
*diff = strings::StrCat(map_name, " got: contains element with key '",
key_to_string(elt_b.first),
"' expected: map has no such element");
}
return false;
}
}
return true;
}
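A usage sketch of EqualProtoMap, matching its instantiation in EqualFunctionDef below: string keys, AttrValue values, and a comparison that falls back to DebugString equality.

bool equal = EqualProtoMap<string, AttrValue>(
    a.attr(), b.attr(), [](const string& s) { return s; },
    [](const AttrValue& v) { return v.DebugString(); },
    [](const string& key, const AttrValue& av, const AttrValue& bv) {
      return av.DebugString() == bv.DebugString();
    },
    strings::StrCat("attr mismatch for function ", a.signature().name()),
    diff);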
bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
const string& diff_preamble, string* diff) {
if (a.op() != b.op()) {
if (diff) {
*diff = strings::StrCat("Definition mismatch for function ",
a.signature().name(), ", expected:\n",
a.DebugString(), "\ngot:\n", b.DebugString());
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
", expected op '", a.op(), "' got '", b.op());
}
return false;
}
if (a.device() != b.device()) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
", expected device '", a.device(), "' got '",
b.device());
}
return false;
}
if (a.input_size() != b.input_size()) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
", expected ", a.input_size(), " inputs got ",
b.input_size(), " expected:\n", a.DebugString(),
"\ngot:\n", b.DebugString());
}
return false;
}
for (int i = 0; i < a.input_size(); ++i) {
if (a.input(i) != b.input(i)) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
" input ", i, ", expected ", a.input(i),
" got ", b.input(i), " expected:\n",
a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
}
return EqualProtoMap<string, AttrValue>(
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
if (key == "shape_inference_graph") {
// Default serialization of GraphDef is unstable because maps don't
// serialize deterministically. Rather than go through the hoops to
// turn on deterministic serialization of this attr just for this
// test, add logic here to compare deterministically.
GraphDef ga;
if (!ga.ParseFromString(av.s())) {
return false;
}
GraphDef gb;
if (!gb.ParseFromString(bv.s())) {
return false;
}
return EqualGraphDef(ga, gb, nullptr);
} else {
return av.DebugString() == bv.DebugString();
}
},
strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
diff);
}

bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
if (a.signature().DebugString() != b.signature().DebugString()) {
if (diff) {
*diff = strings::StrCat("Signature mismatch for function ",
a.signature().name(), ", expected:\n",
a.signature().DebugString(), "\ngot:\n",
b.signature().DebugString());
}
return false;
}
if (!EqualProtoMap<string, AttrValue>(
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
return av.DebugString() == bv.DebugString();
},
strings::StrCat("attr mismatch for function ", a.signature().name()),
diff)) {
return false;
}
if (!EqualProtoMap<string, string>(
a.ret(), b.ret(), [](const string& s) { return s; },
[](const string& s) { return s; },
[](const string& key, const string& av, const string& bv) {
return av == bv;
},
strings::StrCat("ret mismatch for function ", a.signature().name()),
diff)) {
return false;
}
for (int i = 0; i < a.node_def_size(); ++i) {
bool found = false;
for (int j = 0; j < b.node_def_size(); ++j) {
if (a.node_def(i).name() == b.node_def(j).name()) {
if (!EqualFunctionNodeDef(
a.node_def(i), b.node_def(j),
strings::StrCat("Function ", a.signature().name()), diff)) {
return false;
}
found = true;
break;
}
}
if (!found) {
if (diff) {
*diff = strings::StrCat("Function ", a.signature().name(),
", expected: has node '", a.node_def(i).name(),
"' got: no node of that name");
}
return false;
}
}
for (int i = 0; i < b.node_def_size(); ++i) {
bool found = false;
for (int j = 0; j < a.node_def_size(); ++j) {
if (b.node_def(i).name() == a.node_def(j).name()) {
found = true;
break;
}
}
if (!found) {
if (diff) {
*diff = strings::StrCat("Function ", a.signature().name(),
", got: has node '", b.node_def(i).name(),
"' expected: no node of that name");
}
return false;
}
}
return true;
}

@@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,

// TODO(misard): remove these fake registrations once there are real Ops to be
// compiled.
REGISTER_OP("_XlaSendToHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");

REGISTER_OP("_XlaRecvFromHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaHostCompute")
.Input("inputs: Tinputs")
.Output("outputs: Toutputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("Toutputs: list(type) >= 0")
.Attr("key: string")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);

REGISTER_OP("_XlaSendFromHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
.Input("input: Tinputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("key: string")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);

REGISTER_OP("_XlaRecvAtHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
.Output("output: Toutputs")
.Attr("Toutputs: list(type) >= 0")
.Attr("key: string")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);

REGISTER_OP("InputTest").Output("o: float");
REGISTER_OP("InputTest")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape());
return Status::OK();
});

REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
REGISTER_OP("InputTestShaped")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
});

REGISTER_OP("UnaryTest")
.Input("a: float")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
c->set_output(0, o);
return Status::OK();
});
REGISTER_OP("BinaryTest")
.Input("a: float")
.Input("b: float")
.Output("o: float");
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
c->set_output(0, o);
return Status::OK();
});
REGISTER_OP("BinaryTest2")
.Input("a: float")
.Input("b: float")
.Output("o: float")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);

REGISTER_OP("AddNLikeTest")
.Input("inputs: N * T")
@@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTest", opts);
}

Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
Node* InputShaped(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTestShaped", opts);
}

Node* KnownShape(const gtl::ArraySlice<int>& shape,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
opts.op_registry());
TensorProto value;
value.set_dtype(DT_FLOAT);
for (int dim : shape) {
value.mutable_tensor_shape()->add_dim()->set_size(dim);
}
return opts.WithAttr("value", value)
.WithAttr("dtype", DT_FLOAT)
.FinalizeBuilder(&node_builder);
}

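A hypothetical usage sketch of the KnownShape helper above (the builder and node names here are illustrative, not taken from this diff): it yields a float Const with the given dimensions, for tests that need statically known shapes.

GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
Node* k = KnownShape({2}, b.opts().WithName("K"));  // float constant of shape [2]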
Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
return opts.WithAttr("Toutputs", dtypes)
.WithAttr("key", key)
.FinalizeBuilder(&node_builder);
}

Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
const gtl::ArraySlice<DataType>& dtypes,
Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
std::vector<DataType> dtypes;
for (const auto& node : inputs) {
dtypes.push_back(node.dt);
}
return opts.WithAttr("key", key)
.WithAttr("Tinputs", dtypes)
.FinalizeBuilder(&node_builder);
}

Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
@@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b,
return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
}

Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
}

Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
@@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;

string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape_graph;
TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
}

*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
@@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"C:o:0", "c:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", shape_string_expected},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"c"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});

@@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
Node* call = b2.opts().FinalizeBuilder(&node_builder);

Node* recv =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts().WithName("E").WithControlInputs({recv, b}));
Node* send = SendFromHost({e}, {DT_FLOAT},
Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
@@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;

string shape_string_expected_1;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape1.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape1.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape1.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape1_graph;
TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
}

string shape_string_expected_2;
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* recv1 =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
shape2.opts().WithName("E"));
Node* recv2 =
RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
shape2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
SendFromHost("host_compute_channel_F1_O2", {h},
shape2.opts().WithName("outside_compilation_F1_O2_send"));
GraphDef shape2_graph;
TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
}

*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
{{"I"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O2_send"},
"_XlaSendToHost",
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O2_host_compute"},
"_XlaHostCompute",
{"D:o:0", "F:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph", shape_string_expected_2},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"F"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", shape_string_expected_1},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"D"}},
{{"outside_compilation_O2_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O2_send"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"i_0_retval", "I:o:0"}});

@@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
Node* call = b2.opts().FinalizeBuilder(&node_builder);

Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));

Node* recv2 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* g = Binary(e, ops::NodeOut(recv2, 1),
b2.opts().WithName("G").WithControlInputs({recv2, e}));
Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
Node* send2 =
SendFromHost("host_compute_channel_F1_O2", {h},
b2.opts().WithName("outside_compilation_F1_O2_send"));

Node* s = NoOp(b2.opts()
.WithName("F1_sequencer")
@@ -758,8 +1037,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {

{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = InputShaped(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
@@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;

string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape_graph;
TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
}

TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);

*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"},
|
||||
{"f_0_retval:float", "d_0_retval:float"}, {},
|
||||
@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"C:o:0", "outside_compilation_O1_recv:output:0"},
|
||||
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
|
||||
{},
|
||||
{"outside_compilation_O1_recv"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"outside_compilation_O1_host_compute"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{"C:o:0", "D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"key", "host_compute_channel_F1_O1"},
|
||||
{"shape_inference_graph", shape_string_expected},
|
||||
{"shapes", gtl::ArraySlice<DataType>({})}},
|
||||
{"D"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
|
||||
|
||||
@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
|
||||
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
|
||||
{{"I"},
|
||||
"BinaryTest",
|
||||
{"f_0_arg", "outside_compilation_O1_recv:output:0"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{"G:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"key", "host_compute_channel_F2_O1"},
|
||||
{"shape_inference_graph", ""},
|
||||
{"shapes",
|
||||
gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
|
||||
},
|
||||
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
|
||||
|
||||
@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
Node* a = InputShaped(b2.opts().WithName("A"));
|
||||
Node* b = InputShaped(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 =
|
||||
RecvAtHost({DT_FLOAT, DT_FLOAT},
|
||||
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
|
||||
b2.opts().WithName("E").WithControlInputs({recv1, b}));
|
||||
Node* send1 = SendFromHost({e}, {DT_FLOAT},
|
||||
Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
|
||||
Node* s1 = NoOp(
|
||||
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
|
||||
|
||||
Node* recv2 = RecvAtHost(
|
||||
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
|
||||
Node* recv2 =
|
||||
RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F2_O1_recv"));
|
||||
Node* h = Binary(ops::NodeOut(call1, 1), recv2,
|
||||
b2.opts().WithName("H").WithControlInput(s1));
|
||||
Node* send2 = SendFromHost(
|
||||
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
|
||||
Node* send2 =
|
||||
SendFromHost("host_compute_channel_F2_O1", {h},
|
||||
b2.opts().WithName("outside_compilation_F2_O1_send"));
|
||||
|
||||
NodeBuilder node_builder2("F2", "F2", lib_def.get());
|
||||
node_builder2.Input(e).Input(call1);
|
||||
@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* a = InputShaped(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
TensorShapeProto shape_proto_expected;
|
||||
shape_proto_expected.add_dim()->set_size(2);
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"key", "host_compute_channel_F1_O1"},
|
||||
{"shape_inference_graph", ""},
|
||||
{"shapes",
|
||||
gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* a = InputShaped(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* e = Unary(a, b2.opts().WithName("E"));
|
||||
Node* send1 = SendFromHost(
|
||||
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
Node* send1 =
|
||||
SendFromHost("host_compute_channel_F1_O1", {e},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
|
||||
|
||||
{
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = Input(b1.opts().WithName("A"));
|
||||
Node* a = InputShaped(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
|
||||
Node* d =
|
||||
@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
TensorShapeProto shape_proto_expected;
|
||||
shape_proto_expected.add_dim()->set_size(2);
|
||||
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({})}},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"key", "host_compute_channel_F1_O1"},
|
||||
{"shape_inference_graph", ""},
|
||||
{"shapes",
|
||||
gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
|
||||
{"D"}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* a = InputShaped(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 =
|
||||
RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
RecvAtHost("host_compute_channel_F1_O1", {},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
|
||||
Node* send1 = SendFromHost(
|
||||
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
Node* send1 =
|
||||
SendFromHost("host_compute_channel_F1_O1", {e},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
|
||||
@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"}, "UnaryTest", {"D:o:0"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{"D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({})},
|
||||
{"key", "host_compute_channel_F1_O1"},
|
||||
{"shape_inference_graph", ""},
|
||||
{"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 = RecvAtHost(
|
||||
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* recv1 =
|
||||
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Unary(recv1, b2.opts().WithName("E"));
|
||||
NodeBuilder node_builder1("F1", "F1", lib_def.get());
|
||||
node_builder1.Input(a).Input(b);
|
||||
@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
|
||||
{
|
||||
{{"C"}, "UnaryTest", {"a_0_arg"}},
|
||||
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
|
||||
{{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
|
||||
{{"outside_compilation_O1_send"},
|
||||
"_XlaSendToHost",
|
||||
{{"F"},
|
||||
"UnaryTest",
|
||||
{"D:o:0"},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
|
||||
{{"outside_compilation_O1_recv"},
|
||||
"_XlaRecvFromHost",
|
||||
{},
|
||||
{{"dtypes", gtl::ArraySlice<DataType>({})}},
|
||||
{"outside_compilation_O1_send"}},
|
||||
{"outside_compilation_O1_host_compute"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{"D:o:0"},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({})},
|
||||
{"key", "host_compute_channel_F1_O1"},
|
||||
{"shape_inference_graph", ""},
|
||||
{"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
|
||||
Node* a = Input(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
|
||||
Node* recv1 = RecvAtHost(
|
||||
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* recv1 =
|
||||
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = Unary(recv1, b2.opts().WithName("E"));
|
||||
Node* send1 = SendFromHost({}, {},
|
||||
Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
// Test for shape inference of outside compilation.
|
||||
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
|
||||
FunctionDefLibrary library;
|
||||
GraphDef graphdef;
|
||||
|
||||
{
|
||||
*library.add_function() = test::function::XTimesTwo();
|
||||
|
||||
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
|
||||
Node* a = InputShaped(b1.opts().WithName("A"));
|
||||
Node* b = Input(b1.opts().WithName("B"));
|
||||
// Give nodes 'c' and 'd' names that collide after lowercasing.
|
||||
Node* c = Unary(a, b1.opts().WithName("C"));
|
||||
Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Node* e = BinaryUnknownShape(c, d,
|
||||
b1.opts()
|
||||
.WithName("E")
|
||||
.WithControlInputs({b, d})
|
||||
.WithAttr("_encapsulate", "F1")
|
||||
.WithAttr("_outside", "O1"));
|
||||
Node* f = Binary(c, e,
|
||||
b1.opts().WithName("F").WithControlInput(e).WithAttr(
|
||||
"_encapsulate", "F1"));
|
||||
Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
|
||||
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
|
||||
}
|
||||
|
||||
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
|
||||
|
||||
FunctionDefLibrary library_expected;
|
||||
GraphDef graphdef_expected;
|
||||
|
||||
string shape_string_expected;
|
||||
{
|
||||
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
|
||||
Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
|
||||
Node* recv =
|
||||
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
|
||||
shape.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
|
||||
SendFromHost("host_compute_channel_F1_O1", {e},
|
||||
shape.opts().WithName("outside_compilation_F1_O1_send"));
|
||||
GraphDef shape_graph;
|
||||
TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
|
||||
EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
|
||||
}
|
||||
|
||||
*library_expected.add_function() = test::function::XTimesTwo();
|
||||
*library_expected.add_function() = FunctionDefHelper::Create(
|
||||
"F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
|
||||
{
|
||||
{{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
|
||||
{{"F"},
|
||||
"BinaryTest",
|
||||
{"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
|
||||
{},
|
||||
{"outside_compilation_O1_host_compute"}},
|
||||
{{"outside_compilation_O1_host_compute"},
|
||||
"_XlaHostCompute",
|
||||
{"c:o:0"},
|
||||
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
|
||||
{"key", "host_compute_channel_F1_O1"},
|
||||
{"shape_inference_graph", shape_string_expected},
|
||||
{"shapes", gtl::ArraySlice<DataType>({})}},
|
||||
{"c"}},
|
||||
},
|
||||
{{"f_0_retval", "F:o:0"}});
|
||||
|
||||
{
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def(
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
|
||||
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
|
||||
Node* a = InputShaped(b2.opts().WithName("A"));
|
||||
Node* b = Input(b2.opts().WithName("B"));
|
||||
Node* c = Unary(a, b2.opts().WithName("C"));
|
||||
|
||||
NodeBuilder node_builder("F1", "F1", lib_def.get());
|
||||
node_builder.Input(b).Input(c);
|
||||
Node* call =
|
||||
b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);
|
||||
|
||||
Node* recv =
|
||||
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
|
||||
b2.opts().WithName("outside_compilation_F1_O1_recv"));
|
||||
Node* e = BinaryUnknownShape(
|
||||
c, ops::NodeOut(recv, 0),
|
||||
b2.opts().WithName("E").WithControlInputs({recv, b}));
|
||||
Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
|
||||
b2.opts()
|
||||
.WithName("outside_compilation_F1_O1_send")
|
||||
.WithControlInput(e));
|
||||
|
||||
Node* s = NoOp(
|
||||
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
|
||||
|
||||
Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
|
||||
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
|
||||
}
|
||||
|
||||
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
|
||||
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -45,7 +45,7 @@ namespace tensorflow {
|
||||
// see comment on `AllowsAsynchronousDeallocation()`.
|
||||
class XlaAllocator : public xla::DeviceMemoryAllocator {
|
||||
public:
|
||||
XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context);
|
||||
XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context);
|
||||
~XlaAllocator() override;
|
||||
xla::StatusOr<gpu::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
|
||||
bool retry_on_failure) override;
|
||||
@ -79,7 +79,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
|
||||
std::unordered_map<void*, Tensor> tensors_;
|
||||
};
|
||||
|
||||
XlaAllocator::XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context)
|
||||
XlaAllocator::XlaAllocator(const gpu::Platform* platform,
|
||||
OpKernelContext* op_context)
|
||||
: xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
|
||||
|
||||
XlaAllocator::~XlaAllocator() = default;
|
||||
@ -248,12 +249,16 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
|
||||
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
|
||||
|
||||
// Builds an XLA allocator for the device.
|
||||
XlaAllocator xla_allocator(client->platform(), ctx);
|
||||
|
||||
XlaCompiler::Options options;
|
||||
options.client = client;
|
||||
options.device_type = &cache->device_type();
|
||||
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
|
||||
options.device_allocator = &xla_allocator;
|
||||
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
xla::LocalExecutable* executable;
|
||||
@ -264,9 +269,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
// Builds an XLA allocator for the device.
|
||||
XlaAllocator xla_allocator(client->platform(), ctx);
|
||||
|
||||
std::unique_ptr<xla::ShapedBuffer> output;
|
||||
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
|
||||
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
|
||||
@ -374,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES(ctx,
|
||||
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
|
||||
errors::Internal("Invalid input index for variable write."));
|
||||
TensorShape write_shape;
|
||||
OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
|
||||
|
||||
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
|
||||
|
||||
@ -397,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
|
||||
|
||||
// Looks up the owning Tensor by buffer address.
|
||||
OP_REQUIRES_OK(
|
||||
ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape,
|
||||
ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
|
||||
variable->tensor()));
|
||||
++output_num;
|
||||
}
|
||||
|
@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args,
|
||||
XlaCompiler::Argument& arg = (*args)[input_num];
|
||||
arg.kind = XlaCompiler::Argument::kConstant;
|
||||
arg.type = input.dtype();
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
|
||||
arg.shape = input.shape();
|
||||
arg.constant_value = input;
|
||||
++input_num;
|
||||
}
|
||||
@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args,
|
||||
arg.constant_value = input;
|
||||
}
|
||||
arg.type = input.dtype();
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
|
||||
arg.shape = input.shape();
|
||||
++input_num;
|
||||
}
|
||||
|
||||
@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args,
|
||||
if (variable_args[variable_id].present) {
|
||||
const Tensor& value = variable_args[variable_id].value;
|
||||
arg.type = value.dtype();
|
||||
TF_RETURN_IF_ERROR(
|
||||
TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape));
|
||||
arg.shape = value.shape();
|
||||
arg.initialized = true;
|
||||
} else {
|
||||
// The values of uninitialized variables are not passed as inputs, since
|
||||
@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args,
|
||||
// uninitialized variables.
|
||||
arg.initialized = false;
|
||||
arg.type = DT_INVALID;
|
||||
arg.shape = xla::Shape();
|
||||
arg.shape = TensorShape();
|
||||
}
|
||||
++input_num;
|
||||
}
|
||||
@ -223,6 +220,7 @@ Status XlaCompilationCache::BuildExecutable(
|
||||
xla::ExecutableBuildOptions build_options;
|
||||
build_options.set_device_ordinal(client_->default_device_ordinal());
|
||||
build_options.set_result_layout(result.xla_output_shape);
|
||||
build_options.set_device_allocator(options.device_allocator);
|
||||
|
||||
auto compile_result =
|
||||
client_->Compile(*result.computation, argument_layouts, build_options);
|
||||
|
@ -144,6 +144,21 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "matrix_triangular_solve_op_test",
|
||||
size = "small",
|
||||
srcs = ["matrix_triangular_solve_op_test.py"],
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "clustering_test",
|
||||
size = "small",
|
||||
@ -240,6 +255,18 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "extract_image_patches_op_test",
|
||||
size = "small",
|
||||
srcs = ["extract_image_patches_op_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "fft_test",
|
||||
size = "medium",
|
||||
@ -326,6 +353,19 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "matrix_band_part_test",
|
||||
size = "medium",
|
||||
srcs = ["matrix_band_part_test.py"],
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "momentum_test",
|
||||
size = "small",
|
||||
@ -437,6 +477,18 @@ tf_xla_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "reverse_sequence_op_test",
|
||||
size = "small",
|
||||
srcs = ["reverse_sequence_op_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "rmsprop_test",
|
||||
size = "small",
|
||||
|
@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase):
          np.array([4, 5, 6], dtype=np.int32),
          expected=None)

  def testMatrixSetDiag(self):
    for dtype in self.numeric_types:
      # Square
      self._testBinary(
          array_ops.matrix_set_diag,
          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
                   dtype=dtype),
          np.array([1.0, 2.0, 3.0], dtype=dtype),
          expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
                            dtype=dtype))

      self._testBinary(
          array_ops.matrix_set_diag,
          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
                   dtype=dtype),
          np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
          expected=np.array(
              [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
               [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
              dtype=dtype))

      # Rectangular
      self._testBinary(
          array_ops.matrix_set_diag,
          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
          np.array([3.0, 4.0], dtype=dtype),
          expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))

      self._testBinary(
          array_ops.matrix_set_diag,
          np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
          np.array([3.0, 4.0], dtype=dtype),
          expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))

      self._testBinary(
          array_ops.matrix_set_diag,
          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
          np.array([[-1.0, -2.0], [-4.0, -5.0]],
                   dtype=dtype),
          expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
                             [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
                            dtype=dtype))


if __name__ == "__main__":
  googletest.main()
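The expected arrays in testMatrixSetDiag can be cross-checked with plain NumPy. A minimal sketch of the square case, assuming only that matrix_set_diag overwrites the main diagonal of each innermost matrix (variable names here are illustrative, not from the commit):

import numpy as np

# Overwrite the main diagonal of a square matrix, mirroring the first case above.
m = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]])
d = np.array([1.0, 2.0, 3.0])
out = m.copy()
out[np.arange(3), np.arange(3)] = d
print(out)  # [[1. 1. 0.] [1. 2. 1.] [1. 1. 3.]] -- the `expected` value in the test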
134
tensorflow/compiler/tests/extract_image_patches_op_test.py
Normal file
@ -0,0 +1,134 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for ExtractImagePatches op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ExtractImagePatches(XLATestCase):
  """Functional tests for ExtractImagePatches op."""

  def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
    """Tests input-output pairs for the ExtractImagePatches op.

    Args:
      image: Input tensor with shape: [batch, in_rows, in_cols, depth].
      ksizes: Patch size specified as: [ksize_rows, ksize_cols].
      strides: Output strides, specified as [stride_rows, stride_cols].
      rates: Atrous rates, specified as [rate_rows, rate_cols].
      padding: Padding type.
      patches: Expected output.
    """
    ksizes = [1] + ksizes + [1]
    strides = [1] + strides + [1]
    rates = [1] + rates + [1]

    with self.test_session():
      image_placeholder = array_ops.placeholder(dtypes.float32)
      with self.test_scope():
        out_tensor = array_ops.extract_image_patches(
            image_placeholder,
            ksizes=ksizes,
            strides=strides,
            rates=rates,
            padding=padding,
            name="im2col")
      feed_dict = {image_placeholder: image}
      self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict))

  def testKsize1x1Stride1x1Rate1x1(self):
    """Verifies that for 1x1 kernel the output equals the input."""
    # [2, 3, 4, 5]
    image = np.reshape(range(120), [2, 3, 4, 5])
    # [2, 3, 4, 5]
    patches = np.reshape(range(120), [2, 3, 4, 5])
    for padding in ["VALID", "SAME"]:
      self._VerifyValues(
          image,
          ksizes=[1, 1],
          strides=[1, 1],
          rates=[1, 1],
          padding=padding,
          patches=patches)

  def testKsize1x1Stride2x3Rate1x1(self):
    """Test for 1x1 kernel and strides."""
    # [2, 4, 5, 3]
    image = np.reshape(range(120), [2, 4, 5, 3])
    # [2, 2, 2, 3]
    patches = image[:, ::2, ::3, :]
    for padding in ["VALID", "SAME"]:
      self._VerifyValues(
          image,
          ksizes=[1, 1],
          strides=[2, 3],
          rates=[1, 1],
          padding=padding,
          patches=patches)

  def testKsize2x2Stride1x1Rate1x1Valid(self):
    """Test for 2x2 kernel with VALID padding."""
    # [1, 2, 2, 1]
    image = [[[[1], [2]], [[3], [4]]]]
    # [1, 1, 1, 4]
    patches = [[[[1, 2, 3, 4]]]]
    self._VerifyValues(
        image,
        ksizes=[2, 2],
        strides=[1, 1],
        rates=[1, 1],
        padding="VALID",
        patches=patches)

  def testKsize2x2Stride1x1Rate1x1Same(self):
    """Test for 2x2 kernel with SAME padding."""
    # [1, 2, 2, 1]
    image = [[[[1], [2]], [[3], [4]]]]
    # [1, 2, 2, 4]
    patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]]
    self._VerifyValues(
        image,
        ksizes=[2, 2],
        strides=[1, 1],
        rates=[1, 1],
        padding="SAME",
        patches=patches)

  def testKsize2x2Stride1x1Rate2x2Valid(self):
    """Test for 2x2 kernel with 2x2 dilation."""
    # [1, 4, 4, 1]
    image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
    # [1, 2, 2, 4]
    patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
                [[4, 6, 12, 14], [5, 7, 13, 15]]]]
    self._VerifyValues(
        image,
        ksizes=[2, 2],
        strides=[1, 1],
        rates=[2, 2],
        padding="VALID",
        patches=patches)


if __name__ == "__main__":
  test.main()
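The expected `patches` values above can be reproduced with plain NumPy. A minimal sketch for the 2x2 VALID case (stride 1, rate 1), assuming each output position concatenates its window along the depth axis in row-major order:

import numpy as np

# Reference computation for the 2x2 VALID case exercised above.
image = np.array([[[[1], [2]], [[3], [4]]]], dtype=np.float32)  # [1, 2, 2, 1]
batch, rows, cols, depth = image.shape
k = 2  # ksize_rows == ksize_cols == 2
out = np.empty((batch, rows - k + 1, cols - k + 1, k * k * depth), np.float32)
for i in range(rows - k + 1):
  for j in range(cols - k + 1):
    # Flatten the k x k window at (i, j) into the depth dimension.
    out[:, i, j, :] = image[:, i:i + k, j:j + k, :].reshape(batch, -1)
print(out)  # [[[[1. 2. 3. 4.]]]] -- the `patches` expected by the VALID test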
64
tensorflow/compiler/tests/matrix_band_part_test.py
Normal file
@ -0,0 +1,64 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class MatrixBandPartTest(XLATestCase):

  def _testMatrixBandPart(self, dtype, shape):
    with self.test_session():
      batch_shape = shape[:-2]
      # The band is computed on a 2-D matrix; batch dimensions are tiled on.
      mat = np.ones(shape[-2:]).astype(dtype)
      batch_mat = np.tile(mat, batch_shape + [1, 1])
      for lower in -1, 0, 1, shape[-2] - 1:
        for upper in -1, 0, 1, shape[-1] - 1:
          band_np = mat
          if lower >= 0:
            band_np = np.triu(band_np, -lower)
          if upper >= 0:
            band_np = np.tril(band_np, upper)
          if batch_shape:
            band_np = np.tile(band_np, batch_shape + [1, 1])

          placeholder = array_ops.placeholder(dtype)
          with self.test_scope():
            band = array_ops.matrix_band_part(
                placeholder,
                constant_op.constant(lower, dtype=dtypes.int32),
                constant_op.constant(upper, dtype=dtypes.int32))
            feed_dict = {placeholder: batch_mat}
            self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))

  def testMatrixBandPart(self):
    for dtype in self.float_types:
      for batch_shape in [[], [2,], [1, 3, 2]]:
        for rows in 1, 2, 7:
          for cols in 1, 2, 7:
            self._testMatrixBandPart(dtype, batch_shape + [rows, cols])


if __name__ == "__main__":
  test.main()
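The np.triu/np.tril reference in this test captures the band semantics directly. A minimal standalone sketch, assuming matrix_band_part(m, lower, upper) keeps entry (i, j) iff (lower < 0 or i - j <= lower) and (upper < 0 or j - i <= upper):

import numpy as np

# Keep the main diagonal plus one subdiagonal (lower=1, upper=0).
m = np.arange(16, dtype=np.float32).reshape(4, 4)
lower, upper = 1, 0
band = np.tril(np.triu(m, -lower), upper)  # same np.triu/np.tril pattern as the test
print(band)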
130
tensorflow/compiler/tests/matrix_triangular_solve_op_test.py
Normal file
@ -0,0 +1,130 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.MatrixTriangularSolve."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def MakePlaceholder(x):
  return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)


class MatrixTriangularSolveOpTest(XLATestCase):

  def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
                                 placeholder_b, a, clean_a, b, verification,
                                 atol):
    feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
    verification_np = sess.run(verification, feed_dict)
    self.assertAllClose(b, verification_np, atol=atol)

  def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
    clean_a = np.tril(a) if lower else np.triu(a)
    with self.test_session() as sess:
      placeholder_a = MakePlaceholder(a)
      placeholder_ca = MakePlaceholder(clean_a)
      placeholder_b = MakePlaceholder(b)
      with self.test_scope():
        x = linalg_ops.matrix_triangular_solve(
            placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
      verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
      self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
                                      placeholder_b, a, clean_a, b,
                                      verification, atol)

  def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4):
    transp = lambda x: np.swapaxes(x, -1, -2)
    for lower, adjoint in itertools.product([True, False], repeat=2):
      self._VerifyTriangularSolve(
          a if lower else transp(a), b, lower, adjoint, atol)

  def testBasic(self):
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5))
    b = rng.randn(5, 7)
    for dtype in self.float_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBasicNotActuallyTriangular(self):
    rng = np.random.RandomState(0)
    a = rng.randn(5, 5)  # the `a` matrix is not lower-triangular
    b = rng.randn(5, 7)
    for dtype in self.float_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBasicComplexDtypes(self):
    rng = np.random.RandomState(0)
    a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
    b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
    for dtype in self.complex_types:
      self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))

  def testBatch(self):
    rng = np.random.RandomState(0)
    shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)),
              ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))]
    tuples = itertools.product(self.float_types, shapes)
    for dtype, (a_shape, b_shape) in tuples:
      n = a_shape[-1]
      a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
      b = rng.randn(*b_shape)
      self._VerifyTriangularSolveCombo(
          a.astype(dtype), b.astype(dtype), atol=1e-3)

  def testLarge(self):
    n = 1024
    rng = np.random.RandomState(0)
    a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n)
    b = rng.randn(n, n)
    self._VerifyTriangularSolve(
        a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)

  def testNonSquareCoefficientMatrix(self):
    rng = np.random.RandomState(0)
    for dtype in self.float_types:
      a = rng.randn(3, 4).astype(dtype)
      b = rng.randn(4, 4).astype(dtype)
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(a, b)
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(a, b)

  def testWrongDimensions(self):
    randn = np.random.RandomState(0).randn
    for dtype in self.float_types:
      lhs = constant_op.constant(randn(3, 3), dtype=dtype)
      rhs = constant_op.constant(randn(4, 3), dtype=dtype)
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(lhs, rhs)
      with self.assertRaises(ValueError):
        linalg_ops.matrix_triangular_solve(lhs, rhs)


if __name__ == "__main__":
  test.main()
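The verification pattern in this test (solve, then multiply back) can be restated with plain NumPy. A minimal sketch, assuming a well-conditioned lower-triangular system and using a generic dense solve as a stand-in:

import numpy as np

# Solve the triangular system a x = b, then check a @ x reproduces b.
rng = np.random.RandomState(0)
a = np.tril(rng.randn(5, 5))  # lower-triangular by construction
b = rng.randn(5, 7)
x = np.linalg.solve(a, b)
assert np.allclose(a @ x, b, atol=1e-4)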
93
tensorflow/compiler/tests/reverse_sequence_op_test.py
Normal file
@ -0,0 +1,93 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.reverse_sequence_op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ReverseSequenceTest(XLATestCase):

  def _testReverseSequence(self,
                           x,
                           batch_axis,
                           seq_axis,
                           seq_lengths,
                           truth,
                           expected_err_re=None):
    with self.test_session():
      p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
      lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
      with self.test_scope():
        ans = array_ops.reverse_sequence(
            p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
      if expected_err_re is None:
        tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths})
        self.assertAllClose(tf_ans, truth, atol=1e-10)
      else:
        with self.assertRaisesOpError(expected_err_re):
          ans.eval(feed_dict={p: x, lengths: seq_lengths})

  def testSimple(self):
    x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
    expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
    self._testReverseSequence(
        x,
        batch_axis=0,
        seq_axis=1,
        seq_lengths=np.array([1, 3, 2], np.int32),
        truth=expected)

  def _testBasic(self, dtype, len_dtype):
    x = np.asarray(
        [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
         [[17, 18, 19, 20], [21, 22, 23, 24]]],
        dtype=dtype)
    x = x.reshape(3, 2, 4, 1, 1)
    x = x.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2

    # reverse dim 2 up to (0:3, none, 0:4) along dim=0
    seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)

    truth_orig = np.asarray(
        [
            [[3, 2, 1, 4], [7, 6, 5, 8]],  # reverse 0:3
            [[9, 10, 11, 12], [13, 14, 15, 16]],  # reverse none
            [[20, 19, 18, 17], [24, 23, 22, 21]]
        ],  # reverse 0:4 (all)
        dtype=dtype)
    truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
    truth = truth_orig.transpose([2, 1, 0, 3, 4])  # permute axes 0 <=> 2

    seq_axis = 0  # permute seq_axis and batch_axis (originally 2 and 0, resp.)
    batch_axis = 2
    self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth)

  def testSeqLength(self):
    for dtype in self.all_types:
      for seq_dtype in self.int_types:
        self._testBasic(dtype, seq_dtype)


if __name__ == "__main__":
  test.main()
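The semantics exercised by testSimple are easy to restate in NumPy. A minimal sketch, assuming reverse_sequence with batch_axis=0 and seq_axis=1 reverses the first seq_lengths[i] entries of row i and leaves the rest untouched:

import numpy as np

# NumPy restatement of testSimple above.
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
seq_lengths = np.array([1, 3, 2], np.int32)
out = x.copy()
for i, n in enumerate(seq_lengths):
  out[i, :n] = x[i, :n][::-1]  # reverse only the first n entries of row i
print(out)  # [[1 2 3] [6 5 4] [8 7 9]] -- the `expected` array in testSimple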
@ -154,6 +154,21 @@ class UnaryOpsTest(XLATestCase):
|
||||
|
||||
def testFloatOps(self):
|
||||
for dtype in self.float_types:
|
||||
x = np.arange(-0.90, 0.90, 0.25)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.acos,
|
||||
x.astype(dtype),
|
||||
expected=np.arccos(x).astype(dtype))
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.asin,
|
||||
x.astype(dtype),
|
||||
expected=np.arcsin(x).astype(dtype))
|
||||
x = np.arange(-3, 3).reshape(1, 3, 2)
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.atan,
|
||||
x.astype(dtype),
|
||||
expected=np.arctan(x).astype(dtype))
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.acosh,
|
||||
np.array([1, 2, 3, 4], dtype=dtype),
|
||||
|
@ -427,16 +427,36 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
|
||||
// identity nodes are values used by the loop body or condition.
|
||||
// The Identity node may have the wrong device so copy the device from
|
||||
// one of its outputs instead.
|
||||
std::deque<const Edge*> possible_exit;
|
||||
for (const Edge* edge : arg.switch_node->out_edges()) {
|
||||
if (edge->src_output() == 0 && IsExit(edge->dst())) {
|
||||
if (edge->src_output() == 0) {
|
||||
possible_exit.push_back(edge);
|
||||
}
|
||||
if (IsIdentity(edge->dst())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
|
||||
}
|
||||
}
|
||||
// TODO(b/67425339): Allow general graph between switch and exit.
|
||||
while (!possible_exit.empty()) {
|
||||
const Edge* edge = possible_exit.front();
|
||||
possible_exit.pop_front();
|
||||
if (IsExit(edge->dst())) {
|
||||
if (arg.exit != nullptr) {
|
||||
return errors::InvalidArgument("Duplicate Exit successors to ",
|
||||
arg.switch_node->name());
|
||||
}
|
||||
arg.exit = edge->dst();
|
||||
} else if (StringPiece(edge->dst()->type_string()) == "Identity") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
|
||||
} else {
|
||||
if (!IsIdentity(edge->dst())) {
|
||||
return errors::Unimplemented("General graph between switch (",
|
||||
arg.switch_node->name(),
|
||||
") and exit node of frame ",
|
||||
frame->name, " not supported yet.");
|
||||
}
|
||||
for (const Edge* out : edge->dst()->out_edges()) {
|
||||
possible_exit.push_back(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,9 @@ Operator | Type Constraint
|
||||
`Acosh` | `T={complex64,double,float}`
|
||||
`Add` | `T={complex64,double,float,int32,int64}`
|
||||
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`AdjustContrastv2` |
|
||||
`AdjustHue` |
|
||||
`AdjustSaturation` |
|
||||
`All` | `Tidx={int32,int64}`
|
||||
`Angle` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Any` | `Tidx={int32,int64}`
|
||||
@ -34,7 +37,7 @@ Operator | Type Constraint
|
||||
`BroadcastGradientArgs` | `T={int32,int64}`
|
||||
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Ceil` | `T={double,float}`
|
||||
`Cholesky` | `T={complex64,double,float}`
|
||||
`Cholesky` | `T={double,float}`
|
||||
`Complex` | `Tout={complex64}`<br>`T={double,float}`
|
||||
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
|
||||
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
@ -68,7 +71,11 @@ Operator | Type Constraint
|
||||
`Exp` | `T={complex64,double,float}`
|
||||
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Expm1` | `T={complex64,double,float}`
|
||||
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`FFT` |
|
||||
`FFT2D` |
|
||||
`FFT3D` |
|
||||
`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Floor` | `T={double,float}`
|
||||
`FloorDiv` | `T={complex64,double,float,int32,int64}`
|
||||
`FloorMod` | `T={double,float,int32,int64}`
|
||||
@ -80,6 +87,13 @@ Operator | Type Constraint
|
||||
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`HSVToRGB` | `T={double,float}`
|
||||
`IFFT` |
|
||||
`IFFT2D` |
|
||||
`IFFT3D` |
|
||||
`IRFFT` |
|
||||
`IRFFT2D` |
|
||||
`IRFFT3D` |
|
||||
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Imag` | `Tout={double,float}`<br>`T={complex64}`
|
||||
@ -105,11 +119,14 @@ Operator | Type Constraint
|
||||
`MatMul` | `T={complex64,double,float}`
|
||||
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MatrixTriangularSolve` | `T={complex64,double,float}`
|
||||
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`MaxPool` | `T={double,float,int32,int64}`
|
||||
`MaxPool3D` | `T={float}`
|
||||
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
|
||||
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`MaxPoolV2` | `T={double,float,int32,int64}`
|
||||
`Maximum` | `T={double,float,int32,int64}`
|
||||
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
@ -131,6 +148,10 @@ Operator | Type Constraint
|
||||
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`QuantizeAndDequantizeV2` | `T={double,float}`
|
||||
`RFFT` |
|
||||
`RFFT2D` |
|
||||
`RFFT3D` |
|
||||
`RGBToHSV` | `T={double,float}`
|
||||
`RandomStandardNormal` | `dtype={float}`
|
||||
`RandomUniform` | `T={int32,int64}`<br>`dtype={double,float}`
|
||||
`RandomUniformInt` | `T={int32,int64}`<br>`Tout={int32,int64}`
|
||||
@ -146,6 +167,8 @@ Operator | Type Constraint
|
||||
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
|
||||
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ResizeBilinear` | `T={double,float,int32,int64}`
|
||||
`ResizeBilinearGrad` | `T={double,float}`
|
||||
`ResourceApplyAdagrad` | `T={double,float}`
|
||||
`ResourceApplyAdam` | `T={double,float}`
|
||||
`ResourceApplyFtrl` | `T={double,float}`
|
||||
@ -156,6 +179,7 @@ Operator | Type Constraint
|
||||
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
|
||||
`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
|
||||
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
|
||||
`RightShift` | `T={int32,int64,uint32,uint64}`
|
||||
`Rint` | `T={double,float}`
|
||||
|
@ -6,6 +6,9 @@ Operator | Type Constraint
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`AdjustContrastv2` |
`AdjustHue` |
`AdjustSaturation` |
`All` | `Tidx={int32,int64}`
`Angle` | `Tout={double,float}`<br>`T={complex64}`
`Any` | `Tidx={int32,int64}`
@ -34,7 +37,7 @@ Operator | Type Constraint
`BroadcastGradientArgs` | `T={int32,int64}`
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Ceil` | `T={double,float}`
`Cholesky` | `T={complex64,double,float}`
`Cholesky` | `T={double,float}`
`Complex` | `Tout={complex64}`<br>`T={double,float}`
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -68,7 +71,11 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
`FloorMod` | `T={double,float,int32,int64}`
@ -80,6 +87,13 @@ Operator | Type Constraint
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
`HSVToRGB` | `T={double,float}`
`IFFT` |
`IFFT2D` |
`IFFT3D` |
`IRFFT` |
`IRFFT2D` |
`IRFFT3D` |
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Imag` | `Tout={double,float}`<br>`T={complex64}`
@ -105,11 +119,14 @@ Operator | Type Constraint
`MatMul` | `T={complex64,double,float}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixTriangularSolve` | `T={complex64,double,float}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@ -131,6 +148,10 @@ Operator | Type Constraint
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`QuantizeAndDequantizeV2` | `T={double,float}`
`RFFT` |
`RFFT2D` |
`RFFT3D` |
`RGBToHSV` | `T={double,float}`
`Range` | `Tidx={double,float,int32,int64}`
`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -143,6 +164,8 @@ Operator | Type Constraint
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ResizeBilinear` | `T={double,float,int32,int64}`
`ResizeBilinearGrad` | `T={double,float}`
`ResourceApplyAdagrad` | `T={double,float}`
`ResourceApplyAdam` | `T={double,float}`
`ResourceApplyFtrl` | `T={double,float}`
@ -153,6 +176,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`
@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
  for (int i = 0; i < args->size(); ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);

    TF_RETURN_IF_ERROR(
        TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
    arg.shape = ctx->InputShape(i);

    if (arg.type == DT_RESOURCE) {
      return errors::InvalidArgument(
@ -31,6 +31,7 @@ tf_kernel_library(
        "diag_op.cc",
        "dynamic_stitch_op.cc",
        "elu_op.cc",
        "extract_image_patches_op.cc",
        "fft_ops.cc",
        "fill_op.cc",
        "function_ops.cc",
@ -43,6 +44,9 @@ tf_kernel_library(
        "l2loss_op.cc",
        "lrn_ops.cc",
        "matmul_op.cc",
        "matrix_band_part_op.cc",
        "matrix_set_diag_op.cc",
        "matrix_triangular_solve_op.cc",
        "mirror_pad_op.cc",
        "no_op.cc",
        "one_hot_op.cc",
@ -58,6 +62,7 @@ tf_kernel_library(
        "reshape_op.cc",
        "retval_op.cc",
        "reverse_op.cc",
        "reverse_sequence_op.cc",
        "scan_ops.cc",
        "segment_reduction_ops.cc",
        "select_op.cc",
@ -92,6 +97,7 @@ tf_kernel_library(
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/lib:batch_dot",
        "//tensorflow/compiler/tf2xla/lib:cholesky",
        "//tensorflow/compiler/tf2xla/lib:triangular_solve",
        "//tensorflow/compiler/tf2xla/lib:util",
        "//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
        "//tensorflow/compiler/xla:array4d",
@ -28,8 +28,9 @@ class BatchMatMulOp : public XlaOpKernel {
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto result =
        BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_);
    auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1),
                           /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_,
                           /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_);
    OP_REQUIRES_OK(ctx, result.status());
    ctx->SetOutput(0, result.ValueOrDie());
  }
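The rewritten `BatchDot` call splits each adjoint flag into an explicit transpose plus an explicit conjugation. As a sketch of the identity this relies on (our notation, not from the patch):

```latex
\operatorname{op}(z) =
  \begin{cases}
    \overline{z}^{\mathsf{T}} & \text{if adj} \\
    z & \text{otherwise,}
  \end{cases}
\qquad
\text{BatchMatMul}(x, y) = \operatorname{op}(x)\,\operatorname{op}(y)
```

so passing `adj_x_` for both `transpose_x` and `conjugate_x` reproduces the old adjoint behavior while keeping the two effects separately controllable.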
@ -33,7 +33,7 @@ class CholeskyOp : public XlaOpKernel {
  }
};

REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp);
REGISTER_XLA_OP(Name("Cholesky").TypeConstraint("T", kFloatTypes), CholeskyOp);

}  // namespace
}  // namespace tensorflow
tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc (new file, 169 lines)
@ -0,0 +1,169 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

namespace {

class ExtractImagePatchesOp : public XlaOpKernel {
 public:
  explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorFormat data_format = FORMAT_NHWC;
    const int num_dims = ksizes_.size();

    OP_REQUIRES(
        ctx, num_dims >= 3,
        errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
    const int num_spatial_dims = num_dims - 2;

    OP_REQUIRES(ctx, strides_.size() == num_dims,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims, " dimensions"));
    OP_REQUIRES(ctx, dilations_.size() == num_dims,
                errors::InvalidArgument("Dilations field must "
                                        "specify ",
                                        num_dims, " dimensions"));

    int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
    int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
    OP_REQUIRES(
        ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "kernel sizes > 1 in the batch and depth "
                              "dimensions."));
    OP_REQUIRES(
        ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(
        ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
        errors::Unimplemented("Current implementation does not support "
                              "dilations in the batch and depth dimensions."));

    for (int i = 0; i < num_spatial_dims; ++i) {
      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      OP_REQUIRES(
          ctx, ksizes_[input_dim] >= 0,
          errors::Unimplemented("Kernel size values must be non-negative; ", i,
                                "th spatial dimension had kernel size ",
                                ksizes_[input_dim]));
      OP_REQUIRES(ctx, strides_[input_dim] >= 1,
                  errors::Unimplemented("Stride values must be positive; ", i,
                                        "th spatial dimension had stride ",
                                        strides_[input_dim]));
      OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
                  errors::Unimplemented("Dilation values must be positive; ", i,
                                        "th spatial dimension had dilation ",
                                        dilations_[input_dim]));
    }

    xla::PrimitiveType type;
    OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));

    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(
        ctx, input_shape.dims() == num_dims,
        errors::InvalidArgument("input must be ", num_dims, "-dimensional",
                                input_shape.DebugString()));
    const int64 depth = input_shape.dim_size(feature_dim);

    xla::ComputationBuilder* builder = ctx->builder();

    // The following code is equivalent to:
    // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * D])
    int64 kernel_size = 1;
    std::vector<int64> lhs_shape(num_dims, 1);
    for (int i = 0; i < num_spatial_dims; ++i) {
      int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      lhs_shape[i] = ksizes_[input_dim];
      kernel_size *= ksizes_[input_dim];
    }
    lhs_shape[num_spatial_dims] = depth;
    lhs_shape[num_spatial_dims + 1] = 1;

    // Builds an identity matrix as a broadcast equality of iotas.
    // iota = np.arange(np.prod(ksize) * depth)
    // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
    xla::ComputationDataHandle iota;
    TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
                                 kernel_size * depth, &iota));

    auto lhs = builder->Reshape(iota, lhs_shape);
    auto filter = builder->ConvertElementType(
        builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);

    xla::ConvolutionDimensionNumbers dims;
    std::vector<int64> window_strides(num_spatial_dims);
    std::vector<int64> lhs_dilation(num_spatial_dims, 1);
    std::vector<int64> rhs_dilation(num_spatial_dims);
    std::vector<std::pair<int64, int64>> padding(num_spatial_dims);

    dims.set_input_batch_dimension(batch_dim);
    dims.set_output_batch_dimension(batch_dim);
    dims.set_input_feature_dimension(feature_dim);
    dims.set_output_feature_dimension(feature_dim);
    dims.set_kernel_input_feature_dimension(num_spatial_dims);
    dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);

    for (int i = 0; i < num_spatial_dims; ++i) {
      const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
      dims.add_input_spatial_dimensions(dim);
      dims.add_kernel_spatial_dimensions(i);
      dims.add_output_spatial_dimensions(dim);
      window_strides[i] = strides_.at(dim);
      rhs_dilation[i] = dilations_.at(dim);

      int64 unused_output_size;
      OP_REQUIRES_OK(
          ctx, GetWindowedOutputSizeVerboseV2(
                   input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
                   window_strides[i], padding_, &unused_output_size,
                   &padding[i].first, &padding[i].second));
    }

    xla::ComputationDataHandle conv =
        builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
                                    padding, lhs_dilation, rhs_dilation, dims);
    ctx->SetOutput(0, conv);
  }

 protected:
  std::vector<int32> ksizes_;
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
};

REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);

}  // namespace
}  // namespace tensorflow
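The identity-filter construction above can be hard to see through the builder calls. A minimal standalone sketch of the same trick in plain C++, with hypothetical sizes (`kernel_size` and `depth` are illustrative, not from the patch):

```cpp
// Comparing a flat iota against itself with broadcasting yields a one-hot
// "eye" filter; convolving with it gathers each (kernel position, channel)
// pair into its own output channel, which is ExtractImagePatches.
#include <cstdio>
#include <vector>

int main() {
  const int kernel_size = 3, depth = 2;  // kH*kW = 3, D = 2 (hypothetical)
  const int k = kernel_size * depth;
  std::vector<float> filter(k * k);      // flattened [k, k] identity
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < k; ++j)
      filter[i * k + j] = (i == j) ? 1.0f : 0.0f;  // Eq(iota_i, iota_j)
  printf("filter[0]=%g filter[k+1]=%g\n", filter[0], filter[k + 1]);  // 1 1
  return 0;
}
```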
tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc (new file, 98 lines)
@ -0,0 +1,98 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class MatrixBandPartOp : public XlaOpKernel {
 public:
  explicit MatrixBandPartOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    const TensorShape num_lower_in_shape = context->InputShape(1);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape),
                errors::InvalidArgument("num_lower must be scalar, got shape ",
                                        num_lower_in_shape.DebugString()));

    const TensorShape num_upper_in_shape = context->InputShape(2);
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape),
                errors::InvalidArgument("num_upper must be scalar, got shape ",
                                        num_upper_in_shape.DebugString()));

    xla::ComputationBuilder* builder = context->builder();
    xla::ComputationDataHandle input = context->Input(0);
    xla::ComputationDataHandle num_lower = context->Input(1);
    xla::ComputationDataHandle num_upper = context->Input(2);
    DataType input_type = context->input_type(0);
    DataType index_type = context->input_type(1);

    TensorShape batch_shape = input_shape;
    batch_shape.RemoveLastDims(2);
    const int64 m = input_shape.dim_size(input_shape.dims() - 2);
    const int64 n = input_shape.dim_size(input_shape.dims() - 1);

    // Compute 'offset', which is how many diagonals we are above/below the
    // diagonal.
    xla::ComputationDataHandle iota_m;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));

    xla::ComputationDataHandle iota_n;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));

    auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
                               /*broadcast_dimensions=*/{0});

    // If num_lower or num_upper are negative, include all lower/upper
    // diagonals.
    auto zero_index = XlaHelpers::Zero(builder, index_type);
    num_lower = builder->Select(
        builder->Lt(num_lower, zero_index),
        XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
    num_upper = builder->Select(
        builder->Lt(num_upper, zero_index),
        XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);

    auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
                                  builder->Le(offset, num_upper));
    indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());

    auto zero_input = XlaHelpers::Zero(builder, input_type);
    auto output = builder->Select(
        indicator, input,
        builder->Broadcast(zero_input, input_shape.dim_sizes()));

    context->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
};
REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp);

}  // namespace
}  // namespace tensorflow
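Spelled out as a worked equation, the indicator built above implements the usual band-part rule (our notation, with offset `j - i`, and a negative `num_lower`/`num_upper` first widened to `m`/`n` so all diagonals on that side are kept):

```latex
\text{output}[\ldots, i, j] =
  \begin{cases}
    \text{input}[\ldots, i, j] & \text{if } -\text{num\_lower} \le j - i \le \text{num\_upper} \\
    0 & \text{otherwise.}
  \end{cases}
```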
tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc (new file, 93 lines)
@ -0,0 +1,93 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {

class MatrixSetDiagOp : public XlaOpKernel {
 public:
  explicit MatrixSetDiagOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape diag_shape = context->InputShape(1);

    const int rank = input_shape.dims();

    // Preliminary validation of sizes.
    OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
                errors::InvalidArgument(
                    "input must be at least 2-dim, received shape: ",
                    input_shape.DebugString()));

    // Check to make sure the last dimension of diag is equal to the smaller of
    // the last two dimensions of input.
    const int64 m = input_shape.dim_size(rank - 2);
    const int64 n = input_shape.dim_size(rank - 1);
    const int64 min_dim = std::min(m, n);

    TensorShape batch_shape = input_shape;
    batch_shape.RemoveLastDims(2);

    TensorShape expected_diag_shape = batch_shape;
    expected_diag_shape.AddDim(min_dim);
    OP_REQUIRES(context, expected_diag_shape == diag_shape,
                errors::InvalidArgument(
                    "must have diagonal.shape == input.shape[:-2] + "
                    "min(input.shape[-2:]), but received input shape: ",
                    input_shape.DebugString(),
                    " and diagonal shape: ", diag_shape.DebugString()));

    xla::ComputationBuilder* builder = context->builder();
    xla::ComputationDataHandle input = context->Input(0);
    xla::ComputationDataHandle diag = context->Input(1);

    auto zero = XlaHelpers::Zero(builder, context->input_type(0));

    // Create an indicator tensor that is true only on the diagonal.
    xla::ComputationDataHandle iota_m;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
    xla::ComputationDataHandle iota_n;
    OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
    auto indicator = builder->Eq(iota_m,
                                 builder->Broadcast(iota_n, {m}),
                                 /*broadcast_dimensions=*/{0});
    indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());

    // Broadcast diag up to the input shape. Use an implicit broadcast (Add)
    // because we need to broadcast on the right.
    std::vector<int64> diag_broadcast_dims(rank - 1);
    std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
    if (min_dim != m) {
      diag_broadcast_dims.back() = rank - 1;
    }
    diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
                        /*broadcast_dimensions=*/diag_broadcast_dims);

    auto output = builder->Select(indicator, diag, input);
    context->SetOutput(0, output);
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};

REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);

}  // namespace tensorflow
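For intuition, the select-on-indicator logic above reduces, in the 2-D case, to the following plain-C++ sketch (shapes and values are hypothetical, chosen only to show the rule):

```cpp
// Positions where row == col take the diagonal value; everything else keeps
// the input. This is exactly Select(Eq(iota_m, iota_n), diag, input).
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int m = 2, n = 3;
  std::vector<float> input(m * n, 7.0f);   // flattened [m, n] matrix
  std::vector<float> diag = {1.0f, 2.0f};  // length min(m, n)
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      if (i == j) input[i * n + j] = diag[i];
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) printf("%g ", input[i * n + j]);
    printf("\n");  // prints: 1 7 7 / 7 2 7
  }
  return 0;
}
```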
@ -0,0 +1,50 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

namespace tensorflow {
namespace {

class MatrixTriangularSolveOp : public XlaOpKernel {
 public:
  explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    auto result = TriangularSolve(
        ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true,
        /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_);
    if (!result.ok()) {
      ctx->SetStatus(result.status());
      return;
    }
    ctx->SetOutput(0, result.ValueOrDie());
  }

 private:
  bool lower_;
  bool adjoint_;
};

REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp);

}  // namespace
}  // namespace tensorflow
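As with BatchMatMul above, the single `adjoint` attribute is expressed to the library call as a transpose plus a conjugation. A sketch of the system being solved (our notation):

```latex
\operatorname{op}(A)\,X = B, \qquad
\operatorname{op}(A) =
  \begin{cases}
    A^{\mathsf{H}} = \overline{A}^{\mathsf{T}} & \text{if adjoint} \\
    A & \text{otherwise,}
  \end{cases}
```

which is why `adjoint_` is passed for both `transpose_a` and `conjugate_a`.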
@ -37,21 +37,23 @@ class PoolingOp : public XlaOpKernel {
 public:
  PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    std::vector<int32> ksize_int;
    std::vector<int32> stride_int;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
    OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
    OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    for (int i = 0; i < num_dims(); ++i) {
      ksize_.push_back(ksize_int[i]);
      stride_.push_back(stride_int[i]);
    if (ctx->num_inputs() == 1) {
      std::vector<int32> ksize_int;
      std::vector<int32> stride_int;
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
      OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
      OP_REQUIRES(ctx, stride_int.size() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      for (int i = 0; i < num_dims(); ++i) {
        ksize_.push_back(ksize_int[i]);
        stride_.push_back(stride_int[i]);
      }
    }
    Padding padding;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
@ -77,6 +79,33 @@ class PoolingOp : public XlaOpKernel {
    xla::ComputationDataHandle input = ctx->Input(0);
    const TensorShape input_shape = ctx->InputShape(0);

    std::vector<int64> ksize = ksize_;
    std::vector<int64> stride = stride_;
    if (ctx->num_inputs() != 1) {
      const TensorShape ksize_shape = ctx->InputShape(1);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
                  errors::InvalidArgument("Sliding window ksize field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      ksize.clear();
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));

      const TensorShape stride_shape = ctx->InputShape(2);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
                  errors::InvalidArgument("Sliding window stride field must "
                                          "specify ",
                                          num_dims(), " dimensions"));
      stride.clear();
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
    }
    OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
                errors::InvalidArgument("Input to ", type_string(),
                                        " operator must have ", num_dims(),
@ -84,8 +113,8 @@ class PoolingOp : public XlaOpKernel {

    const DataType type = input_type(0);
    xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
        input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_,
        stride_, padding_);
        input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
        stride, padding_);
    ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
  }

@ -130,6 +159,10 @@ class MaxPool2DOp : public MaxPoolOp {
  }
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
                    .CompileTimeConstInput("ksize")
                    .CompileTimeConstInput("strides"),
                MaxPool2DOp);

class MaxPool3DOp : public MaxPoolOp {
 public:
@ -243,22 +276,44 @@ class MaxPoolGradOp : public XlaOpKernel {
 public:
  MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
      : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    if (ctx->num_inputs() == 3) {
      OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
      OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  int num_dims() const { return num_spatial_dims_ + 2; }

  void Compile(XlaOpKernelContext* ctx) override {
    if (ctx->num_inputs() != 3) {
      OP_REQUIRES(
          ctx, ctx->num_inputs() == 5,
          errors::InvalidArgument("Must supply ksize and stride arguments."));
      const TensorShape ksize_shape = ctx->InputShape(3);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
                  errors::InvalidArgument("ksize must be a vector, not shape ",
                                          ksize_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));

      const TensorShape stride_shape = ctx->InputShape(4);
      // Validate input sizes.
      OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
                  errors::InvalidArgument("stride must be a vector, not shape ",
                                          stride_shape.DebugString()));
      OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
    }

    OP_REQUIRES(ctx, ksize_.size() == num_dims(),
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify ",
                                        num_dims(), " dimensions"));
    OP_REQUIRES(ctx, stride_.size() == num_dims(),
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify ",
                                        num_dims(), " dimensions"));

    const TensorShape tensor_in_shape = ctx->InputShape(0);
    const TensorShape tensor_out_shape = ctx->InputShape(1);
    const TensorShape out_backprop_shape = ctx->InputShape(2);
@ -315,6 +370,10 @@ class MaxPool2DGradOp : public MaxPoolGradOp {
  }
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
                    .CompileTimeConstInput("ksize")
                    .CompileTimeConstInput("strides"),
                MaxPool2DGradOp);

class MaxPool3DGradOp : public MaxPoolGradOp {
 public:
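What the V2 registrations change is only where the window parameters come from: `ksize` and `strides` arrive as (compile-time constant) inputs instead of attributes, then feed the same windowed reduction. A plain-C++ sketch of the pooling semantics itself, in a hypothetical 1-D, VALID-padding setting:

```cpp
// Naive 1-D max pooling: slide a window of size ksize with the given stride
// and take the max of each window (this is what ReduceWindow computes).
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<float> MaxPool1D(const std::vector<float>& in, int ksize,
                             int stride) {
  std::vector<float> out;
  for (size_t start = 0; start + ksize <= in.size(); start += stride)
    out.push_back(
        *std::max_element(in.begin() + start, in.begin() + start + ksize));
  return out;  // VALID padding: windows never run past the end
}

int main() {
  auto out = MaxPool1D({1, 5, 2, 4, 3, 0}, /*ksize=*/2, /*stride=*/2);
  for (float v : out) printf("%g ", v);  // prints: 5 4 3
  printf("\n");
  return 0;
}
```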
tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc (new file, 182 lines)
@ -0,0 +1,182 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class ReverseSequenceOp : public XlaOpKernel {
 public:
  explicit ReverseSequenceOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
    OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
  }

  void Compile(XlaOpKernelContext* context) override {
    const TensorShape input_shape = context->InputShape(0);
    const TensorShape seq_lens_shape = context->InputShape(1);

    OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape),
                errors::InvalidArgument("seq_lens input must be 1-dim, not ",
                                        seq_lens_shape.dims()));
    OP_REQUIRES(context, batch_dim_ != seq_dim_,
                errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_));
    OP_REQUIRES(
        context, seq_dim_ < input_shape.dims(),
        errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
                                seq_dim_, " vs. ", input_shape.dims(), ")"));
    OP_REQUIRES(
        context, batch_dim_ < input_shape.dims(),
        errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
                                batch_dim_, " vs. ", input_shape.dims(), ")"));
    OP_REQUIRES(
        context,
        seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_),
        errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_,
                                "), ", "(", seq_lens_shape.num_elements(),
                                " vs. ", input_shape.dim_size(batch_dim_)));

    xla::ComputationBuilder* builder = context->builder();
    const auto input = context->Input(0);
    const auto seq_lens = context->Input(1);

    const int64 batch_size = input_shape.dim_size(batch_dim_);

    const DataType input_type = context->input_type(0);
    const DataType seq_lens_type = context->input_type(1);
    const int64 max_seq_len = input_shape.dim_size(seq_dim_);

    xla::Shape input_xla_shape;
    OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape,
                                                  &input_xla_shape));
    xla::Shape seq_lens_xla_shape;
    OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape,
                                                  &seq_lens_xla_shape));

    const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({
        xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}),
        seq_lens_xla_shape,
        input_xla_shape,
    });

    // For each entry in the batch, reverse the sequence.
    // TODO(b/65689298): generalize the Map() operator to non-scalar cases and
    // use it here, instead of a While loop.

    // Condition: lambda (i, _, _): i < batch_size
    auto condition_builder =
        builder->CreateSubBuilder("reverse_sequence_condition");
    {
      auto param = condition_builder->Parameter(0, tuple_shape, "param");
      auto i = condition_builder->GetTupleElement(param, 0);
      condition_builder->Lt(
          i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type,
                                        batch_size));
    }
    auto condition = condition_builder->Build();
    OP_REQUIRES_OK(context, condition.status());

    auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
    {
      auto param = body_builder->Parameter(0, tuple_shape, "param");
      auto i = body_builder->GetTupleElement(param, 0);
      auto seq_lens = body_builder->GetTupleElement(param, 1);
      auto output = body_builder->GetTupleElement(param, 2);

      // seq_len is the sequence length of the current batch element (rank 1)
      auto seq_len = body_builder->DynamicSlice(
          seq_lens, body_builder->Reshape(i, {1}), {1});

      // Indices is the offset of the batch element in the input.
      auto indices = body_builder->Broadcast(
          XlaHelpers::Zero(body_builder.get(), seq_lens_type),
          {input_shape.dims()});
      indices = body_builder->DynamicUpdateSlice(
          indices, body_builder->Reshape(i, {1}),
          body_builder->Reshape(
              XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
                                         batch_dim_),
              {1}));

      // slice_indices is the offset of the start of the reversed sequence in
      // the input.
      auto slice_indices = body_builder->DynamicUpdateSlice(
          indices,
          body_builder->Sub(XlaHelpers::IntegerLiteral(
                                body_builder.get(), seq_lens_type, max_seq_len),
                            seq_len),
          body_builder->Reshape(
              XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
                                         seq_dim_),
              {1}));

      // Slice out the reversed sequence. The slice will overflow the end of the
      // sequence, and the contents of the overflow are implementation-defined.
      // However, we will mask off these elements and replace them with elements
      // from the original input so their values do not matter.
      TensorShape slice_shape = input_shape;
      slice_shape.set_dim(batch_dim_, 1);
      auto slice = body_builder->DynamicSlice(output, slice_indices,
                                              slice_shape.dim_sizes());

      // Shift the reversed sequence to the left.
      output = body_builder->DynamicUpdateSlice(output, slice, indices);

      body_builder->Tuple(
          {body_builder->Add(
               i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
           seq_lens, output});
    }
    auto body = body_builder->Build();
    OP_REQUIRES_OK(context, body.status());

    auto loop_output = builder->While(
        condition.ValueOrDie(), body.ValueOrDie(),
        builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
                        builder->Rev(input, {seq_dim_})}));
    auto output = builder->GetTupleElement(loop_output, 2);

    // Mask out elements after the sequence length.
    xla::ComputationDataHandle iota;
    OP_REQUIRES_OK(
        context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
    std::vector<int64> dims(input_shape.dims(), 1);
    dims[batch_dim_] = batch_size;
    auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_});

    // Broadcast the mask up to the input shape.
    mask =
        builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false),
                                             input_shape.dim_sizes()));

    output = builder->Select(mask, output, input);
    context->SetOutput(0, output);
  }

 private:
  int32 batch_dim_;
  int32 seq_dim_;
};

REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp);

}  // namespace
}  // namespace tensorflow
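The while-loop-plus-mask construction above computes a simple per-batch operation. A plain-C++ sketch of the op's semantics in a hypothetical 2-D (batch x time) setting, which the XLA kernel reaches indirectly by reversing the whole axis and then shifting and masking:

```cpp
// For each batch row, reverse only the first seq_len elements along the
// sequence axis and leave the tail untouched.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<std::vector<int>> input = {{1, 2, 3, 4}, {5, 6, 7, 8}};
  std::vector<int> seq_lens = {3, 2};
  for (size_t b = 0; b < input.size(); ++b)
    std::reverse(input[b].begin(), input[b].begin() + seq_lens[b]);
  for (const auto& row : input) {
    for (int v : row) printf("%d ", v);
    printf("\n");  // prints: 3 2 1 4 / 6 5 7 8
  }
  return 0;
}
```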
@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
    // Stack has not been initialized.
    xla::ComputationDataHandle zero =
        XlaHelpers::Zero(builder, resource->type());
    TF_RETURN_IF_ERROR(resource->SetValue(
        dtype,
        builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
                        builder->ConstantR0<int32>(0)})));
    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Checks the expected shape matches the actual shape.
    TensorShape actual_shape;
@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel {
    string name = strings::StrCat("Stack: ", stack_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
                               value, &resource));
    resource->set_tensor_array_size(size);
                               TensorShape(), value, /*tensor_array_size=*/size,
                               /*tensor_array_gradients=*/{}, &resource));
    ctx->SetResourceOutput(0, resource);
  }

@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel {

    // TODO(phawkins): We don't check the index is in bounds --- there is no
    // error mechanism in XLA.
    OP_REQUIRES_OK(
        ctx,
        resource->SetValue(
            dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
                              b->Add(index, b->ConstantR0<int32>(1))})));
    OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
                            {b->DynamicUpdateSlice(ta, update, start_indices),
                             b->Add(index, b->ConstantR0<int32>(1))})));

    ctx->SetOutput(0, value);
  }
@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel {
    xla::ComputationDataHandle index = b->GetTupleElement(state, 1);

    index = b->Sub(index, b->ConstantR0<int32>(1));
    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index})));
    OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));

    // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
    auto start_indices =
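The pattern behind this refactor: the resource records its dtype/shape once via `SetTypeAndShape` and knows how to zero-initialize itself via `SetZeroValue`, so call sites no longer build the zero value and no longer re-pass the dtype to every `SetValue`. A heavily simplified plain-C++ sketch of that shape (the `Resource` type here is hypothetical, not the TF class):

```cpp
// Lazy zero-initialization: record the metadata first, materialize the
// zero value on demand.
#include <cstdio>
#include <vector>

struct Resource {
  int size = 0;
  std::vector<float> value;
  bool initialized = false;
  void SetTypeAndShape(int n) { size = n; }
  void SetZeroValue() { value.assign(size, 0.0f); initialized = true; }
};

int main() {
  Resource stack;
  stack.SetTypeAndShape(4);
  if (!stack.initialized) stack.SetZeroValue();  // cf. MaybeInitializeStack
  printf("%zu elements\n", stack.value.size());
  return 0;
}
```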
@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel {
    OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                            &strides_tensor));

    DataType lhs_type;
    TensorShape lhs_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
    xla::ComputationDataHandle lhs;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));

    const TensorShape rhs_shape = ctx->InputShape(4);

@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel {
                    " does not match r-value shape ", rhs_shape.DebugString(),
                    ". Automatic broadcasting not yet implemented."));

    xla::ComputationDataHandle lhs;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));

    xla::ComputationDataHandle rhs = ctx->Input(4);

    gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel {
          lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
    }

    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
  }

 private:
  int32 begin_mask_, end_mask_;
  int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
  DataType index_type_;
  DataType dtype_;
};

REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,

  TF_RET_CHECK(resource->tensor_array_size() >= 0)
      << resource->name() << " size " << resource->tensor_array_size();
  TensorShape ta_shape;
  ta_shape.AddDim(resource->tensor_array_size());
  ta_shape.AppendShape(elem_shape);

  if (!resource->initialized()) {
    xla::ComputationDataHandle zero =
        XlaHelpers::Zero(builder, resource->type());
    TF_RETURN_IF_ERROR(resource->SetValue(
        dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));

    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
  } else {
    // Checks the elem_shape matches the TensorArray shape.
    auto shape_or_status = builder->GetShape(resource->value());
@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
    TensorShape shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));

    TensorShape ta_shape;
    ta_shape.AddDim(resource->tensor_array_size());
    ta_shape.AppendShape(elem_shape);
    if (ta_shape != shape) {
      return errors::InvalidArgument(
          "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
                           xla::ComputationBuilder* builder,
                           TensorShape* shape) {
  TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
  if (shape->dims() < 1) {
    return errors::InvalidArgument("TensorArray rank must be >= 1");
  }
  *shape = resource->shape();
  shape->InsertDim(0, resource->tensor_array_size());
  return Status::OK();
}

@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel {
    // Initializes the TensorArray value if we know the element shape.
    // Otherwise, defer initialization to the first write.
    xla::ComputationDataHandle value;
    TensorShape shape;
    if (element_shape_.IsFullyDefined()) {
      TensorShape shape;
      CHECK(element_shape_.AsTensorShape(&shape));
      TensorShape ta_shape;
      ta_shape.AddDim(size);
@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel {
    string name = strings::StrCat("TensorArray: ", tensor_array_name_);
    OP_REQUIRES_OK(
        ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
                               dtype_, value, &var));
    var->set_tensor_array_size(size);
                               dtype_, shape, value, /*tensor_array_size=*/size,
                               /*tensor_array_gradients=*/{}, &var));
    ctx->SetResourceOutput(0, var);

    Tensor flow(DT_FLOAT, TensorShape({}));
@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
    xla::ComputationDataHandle written =
        DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);

    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
    OP_REQUIRES_OK(ctx, resource->SetValue(written));
    ctx->SetOutput(0, flow);
  }

@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
      }
    }

    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
    OP_REQUIRES_OK(ctx, resource->SetValue(ta));
    ctx->SetOutput(0, flow);
  }

@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel {
                                value_shape.DebugString(), " vs. ",
                                ta_shape.DebugString()));

    OP_REQUIRES_OK(
        ctx, resource->SetValue(
                 dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
    OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
                            ta, b->Reshape(value, ta_shape.dim_sizes()))));

    ctx->SetOutput(0, flow);
  }
@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle handle;
    xla::ComputationBuilder* b = ctx->builder();
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
    DataType type = ctx->input_type(1);
    TensorShape var_shape;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));

    TensorShape alpha_shape = ctx->InputShape(1);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
                errors::InvalidArgument("alpha is not a scalar: ",
                                        alpha_shape.DebugString()));

    TensorShape delta_shape = ctx->InputShape(2);
    OP_REQUIRES(
        ctx, var_shape.IsSameSize(delta_shape),
        errors::InvalidArgument("var and delta do not have the same shape: ",
                                var_shape.DebugString(), " vs ",
                                delta_shape.DebugString()));

    handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(
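The new shape checks guard the same arithmetic as before; the update this kernel emits is plain gradient descent:

```latex
\text{var} \leftarrow \text{var} - \alpha \, \delta
```

where α is input 1 (`alpha`) and δ is input 2 (`delta`), matching the `b->Sub(handle, b->Mul(...))` line above.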
@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel {

    DataType type = ctx->input_type(2);

    DataType var_type, accum_type;
    TensorShape var_shape, accum_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
    OP_REQUIRES_OK(ctx,
                   ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));

    OP_REQUIRES(
        ctx, type == var_type && type == accum_type,
        errors::InvalidArgument(
            "Types of variable arguments to ResourceApplyMomentum must match: ",
            DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
            DataTypeString(accum_type)));
    xla::ComputationDataHandle var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
                errors::InvalidArgument("momentum is not a scalar: ",
                                        momentum_shape.DebugString()));

    xla::ComputationDataHandle var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));

    xla::ComputationDataHandle lr = ctx->Input(2);
    xla::ComputationDataHandle grad = ctx->Input(3);
    xla::ComputationDataHandle momentum = ctx->Input(4);
@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel {

    DataType type = ctx->input_type(2);

    DataType var_type, accum_type;
    TensorShape var_shape, accum_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
    OP_REQUIRES_OK(ctx,
                   ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));

    OP_REQUIRES(
        ctx, type == var_type && type == accum_type,
        errors::InvalidArgument(
            "Types of variable arguments to ResourceApplyAdagrad must match: ",
            DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
            DataTypeString(accum_type)));
    xla::ComputationDataHandle var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));

    OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                errors::InvalidArgument(
@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::ComputationDataHandle var, accum;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
    xla::ComputationDataHandle lr = ctx->Input(2);
    xla::ComputationDataHandle grad = ctx->Input(3);

@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel {
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType var_type, m_type, v_type;
    TensorShape var_shape, m_shape, v_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape));
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape));

    OP_REQUIRES(
        ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type,
        errors::InvalidArgument(
            "Types of variable arguments to ResourceApplyRMSProp must match: ",
            DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ",
            DataTypeString(m_type), " vs. ", DataTypeString(v_type)));
    xla::ComputationDataHandle var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));

    TensorShape beta1_power_shape = ctx->InputShape(3);
    TensorShape beta2_power_shape = ctx->InputShape(4);
@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel {
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::ComputationDataHandle var, m, v;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v));
    xla::ComputationDataHandle beta1_power = ctx->Input(3);
    xla::ComputationDataHandle beta2_power = ctx->Input(4);
    xla::ComputationDataHandle lr = ctx->Input(5);
@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {

    DataType type = ctx->input_type(3);

    DataType var_type, ms_type, mom_type;
    TensorShape var_shape, ms_shape, mom_shape;
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape));
    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape));

    OP_REQUIRES(
        ctx, type == var_type && type == ms_type && type == mom_type,
        errors::InvalidArgument(
            "Types of variable arguments to ResourceApplyRMSProp must match: ",
            DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ",
            DataTypeString(ms_type), " vs. ", DataTypeString(mom_type)));
    xla::ComputationDataHandle var, ms, mom;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));

    TensorShape lr_shape = ctx->InputShape(3);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
                    "var and grad do not have the same shape",
                    var_shape.DebugString(), " ", grad_shape.DebugString()));

    xla::ComputationDataHandle var, ms, mom;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom));
    xla::ComputationDataHandle lr = ctx->Input(3);
    xla::ComputationDataHandle rho = ctx->Input(4);
    xla::ComputationDataHandle momentum = ctx->Input(5);
@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
                 bool has_l2_shrinkage) {
  xla::ComputationBuilder* b = ctx->builder();

  DataType var_type, accum_type, linear_type;
  TensorShape var_shape, accum_shape, linear_shape;
  OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
  OP_REQUIRES_OK(ctx,
                 ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
  OP_REQUIRES_OK(ctx,
                 ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));

  OP_REQUIRES(
      ctx, dtype == var_type && dtype == accum_type && dtype == linear_type,
      errors::InvalidArgument(
          "Types of variable arguments to ResourceApplyFtrlV2 must match: ",
          DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ",
          DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
  xla::ComputationDataHandle var, accum, linear;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));

  OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
              errors::InvalidArgument(
@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
              errors::InvalidArgument("lr_power is not a scalar: ",
                                      lr_power_shape.DebugString()));

  xla::ComputationDataHandle var, accum, linear;
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
  xla::ComputationDataHandle grad = ctx->Input(3);
  xla::ComputationDataHandle lr = ctx->Input(4);
  xla::ComputationDataHandle l1 = ctx->Input(5);
@@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x));
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));

// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
XLAJIT_MAKE_UNARY(
    Acos,
    b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
           b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
                                  b->Mul(x, x)),
                           XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
                    b->Add(XlaHelpers::One(b, input_type(0)), x))));

// acosh(x) = log(x + sqrt(x^2 - 1))
XLAJIT_MAKE_UNARY(
    Acosh,
    b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x),
                                   XlaHelpers::One(b, input_type(0))),
                            XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));

// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
    Asin,
    b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
           b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
                              b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
                                            b->Mul(x, x)),
                                     XlaHelpers::FloatLiteral(b, input_type(0),
                                                              0.5))))));

// asinh(x) = log(x + sqrt(x^2 + 1))
XLAJIT_MAKE_UNARY(
    Asinh,
    b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
                                   XlaHelpers::One(b, input_type(0))),
                            XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));

XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));

// atanh(x) = 0.5 * log((1 + x) / (1 - x))
XLAJIT_MAKE_UNARY(
    Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),
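// The inverse-trig identities above can be sanity-checked on the host against
// <cmath>; a minimal standalone sketch (illustration only, not part of the
// patch):
#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  for (double x = -0.9; x <= 0.9; x += 0.1) {
    // acos(x) = 2 * atan2(sqrt(1 - x^2), 1 + x)
    assert(std::fabs(std::acos(x) -
                     2 * std::atan2(std::sqrt(1 - x * x), 1 + x)) < 1e-12);
    // asin(x) = 2 * atan2(x, 1 + sqrt(1 - x^2))
    assert(std::fabs(std::asin(x) -
                     2 * std::atan2(x, 1 + std::sqrt(1 - x * x))) < 1e-12);
    // atanh(x) = 0.5 * log((1 + x) / (1 - x))
    assert(std::fabs(std::atanh(x) - 0.5 * std::log((1 + x) / (1 - x))) <
           1e-12);
  }
  std::puts("identities hold");
}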
@@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel {
 public:
  explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle handle;
    bool initialized = ctx->ReadVariableInput(0, &handle).ok();
    ctx->SetOutput(0, ctx->builder()->ConstantR0<bool>(initialized));
    XlaResource* variable;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
    ctx->SetOutput(0,
                   ctx->builder()->ConstantR0<bool>(variable->initialized()));
  }
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);

class ReadVariableOp : public XlaOpKernel {
 public:
  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
    OP_REQUIRES_OK(
        ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
    ctx->SetOutput(0, handle);
  }

 private:
  DataType dtype_;
};
REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);

@@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel {
 public:
  explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(1);
    xla::ComputationDataHandle handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
    handle = ctx->builder()->Add(handle, ctx->Input(1));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(
@@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel {
 public:
  explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    DataType type = ctx->input_type(1);
    xla::ComputationDataHandle handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
    OP_REQUIRES_OK(ctx,
                   ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
    handle = ctx->builder()->Sub(handle, ctx->Input(1));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
  }
};
REGISTER_XLA_OP(
@@ -95,28 +107,19 @@ class ResourceGatherOp : public XlaOpKernel {
  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationBuilder* builder = ctx->builder();

    // Get the shape of the resource tensor.
    DataType type = ctx->expected_output_dtype(0);

    TensorShape resource_shape;
    DataType resource_dtype;
    OP_REQUIRES_OK(
        ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape));

    DataType expected_output_dtype = ctx->expected_output_dtype(0);
    OP_REQUIRES(ctx, resource_dtype == expected_output_dtype,
                errors::InvalidArgument(
                    "Variable dtype is ", DataTypeString(resource_dtype),
                    " but expected output dtype is ",
                    DataTypeString(expected_output_dtype), "."));

    xla::ComputationDataHandle resource_handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle));
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
                                               &resource_handle));

    auto indices = ctx->Input(1);
    auto indices_shape = ctx->InputShape(1);
    DataType index_type = ctx->input_type(1);
    xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
        ctx, resource_handle, resource_shape, indices, indices_shape, 0,
        resource_dtype, index_type, builder);
        ctx, resource_handle, resource_shape, indices, indices_shape, 0, type,
        index_type, builder);
    ctx->SetOutput(0, gather);
  }
};
@@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
      }

      arg.type = resource->type();
      if (arg.initialized) {
        TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
      } else {
      arg.shape = resource->shape();
      if (!arg.initialized) {
        *has_uninitialized_vars = true;
      }
      arg.tensor_array_size = resource->tensor_array_size();
@@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs(
      arg.name = resource->name();
      VLOG(2) << " resource " << resource->name()
              << " type: " << DataTypeString(arg.type)
              << " shape: " << xla::ShapeUtil::HumanString(arg.shape)
              << " shape: " << arg.shape.DebugString()
              << " initialized: " << arg.initialized;

    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = ctx->input_type(i);
      TF_RETURN_IF_ERROR(
          TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
      arg.shape = ctx->InputShape(i);
    }
  }
  return Status::OK();
@@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
    XlaCompiler::Argument& arg = arguments[update.input_index];
    if (!arg.initialized) {
      VLOG(2) << "Update shape for argument " << update.input_index << " "
              << xla::ShapeUtil::HumanString(update.shape);
              << update.shape.DebugString();
      arg.initialized = true;

      xla::Shape shape = update.shape;
      if (!update.tensor_array_gradients_accessed.empty()) {
        shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
      }
      std::unique_ptr<xla::Literal> zero =
          xla::Literal::CreateFromShape(shape);
      OP_REQUIRES_OK(ctx, resource->SetValue(
                              update.type, builder->ConstantLiteral(*zero)));
      arg.shape = update.shape;
      OP_REQUIRES_OK(ctx,
                     resource->SetTypeAndShape(update.type, update.shape));

      OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
    }

    // Add any TensorArray gradients touched by the body to the enclosing
@@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
    for (const auto& gradient : resource->tensor_array_gradients()) {
      arg.tensor_array_gradients.insert(gradient.first);
    }

    // Recompute the argument shape.
    OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
  }
  // Recompile the body with the "correct" resource shapes.
  VLOG(1) << "Recompiling body with corrected resource shapes";
@@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
      OP_REQUIRES_OK(ctx,
                     resource->SetFromPack(
                         arguments[update.input_index].tensor_array_gradients,
                         builder->GetTupleElement(while_result, pos),
                         /*reset_initial_values=*/false, builder));
                         builder->GetTupleElement(while_result, pos), builder));
    }
    VLOG(2) << "Loop-carried variable: pos: " << update.input_index
            << " name: " << resource->name() << " modified: " << update.modified
            << " type: " << DataTypeString(update.type)
            << " shape: " << xla::ShapeUtil::HumanString(update.shape);
            << " shape: " << update.shape.DebugString();
    // Copies the identity of the resource variable from input to output
    // unchanged, even if the variable was not modified.
    ctx->op_kernel_context()->set_output(
@@ -60,6 +60,8 @@ cc_library(
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/client:computation",
        "//tensorflow/compiler/xla/client:computation_builder",
        "//tensorflow/core:lib",
@@ -25,11 +25,10 @@ limitations under the License.

namespace tensorflow {

// The current implementation simply unrolls the computation along the batch
// dimension.
xla::StatusOr<xla::ComputationDataHandle> BatchDot(
    xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) {
    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
    bool conjugate_x, bool conjugate_y) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
                      builder->GetShape(x));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
@@ -89,10 +88,10 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(
                         dimensions);
  }

  if (x_shape->element_type() == xla::C64 && transpose_x) {
  if (x_shape->element_type() == xla::C64 && conjugate_x) {
    x = builder->Conj(x);
  }
  if (y_shape->element_type() == xla::C64 && transpose_y) {
  if (y_shape->element_type() == xla::C64 && conjugate_y) {
    y = builder->Conj(y);
  }
@@ -27,7 +27,10 @@ namespace tensorflow {
// viewed as an element of a batch), and arranges the individual results
// in a single output tensor of the same batch size. Each of the
// individual slices can optionally be transposed before multiplication by
// setting the `transpose_x` or `transpose_y` flag to `true`.
// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each
// can be elementwise-complex-conjugated by setting the `conjugate_x` or
// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both
// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`.
//
// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
// and `[..., r_y, c_y]`.
@@ -40,11 +43,10 @@ namespace tensorflow {
// It is computed as:
//
//   output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
// TODO(phawkins): add an option to take the complex conjugate of the LHS or
// RHS.
xla::StatusOr<xla::ComputationDataHandle> BatchDot(
    xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y);
    xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
    bool conjugate_x = false, bool conjugate_y = false);

}  // namespace tensorflow
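// Under the extended signature above, a per-batch Hermitian-adjoint product
// conj(transpose(x)) * y would be requested roughly as follows; `builder`,
// `x`, and `y` are assumed to be in scope (call sketch, not part of the
// patch):
auto result = BatchDot(builder, x, y,
                       /*transpose_x=*/true, /*transpose_y=*/false,
                       /*conjugate_x=*/true, /*conjugate_y=*/false);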
@@ -71,11 +71,14 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
                        SliceInMinorDims(builder, l, {j + 1, 0}, {n, j}));
    TF_ASSIGN_OR_RETURN(auto r_squared,
                        BatchDot(builder, r, r, /*transpose_x=*/false,
                                 /*transpose_y=*/true));
                                 /*transpose_y=*/true, /*conjugate_x=*/false,
                                 /*conjugate_y=*/false));
    new_d_squared = builder->Sub(new_d_squared, r_squared);

    TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false,
                                     /*transpose_y=*/true));
                                     /*transpose_y=*/true,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/false));
  }
  auto new_d_inv = builder->Pow(
      new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5));
@@ -134,7 +137,8 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
                          SliceInMinorDims(builder, l, {i, 0}, {i + k, i}));
      TF_ASSIGN_OR_RETURN(auto delta,
                          BatchDot(builder, lhs, rhs, /*transpose_x=*/false,
                                   /*transpose_y=*/true));
                                   /*transpose_y=*/true, /*conjugate_x=*/false,
                                   /*conjugate_y=*/false));
      TF_ASSIGN_OR_RETURN(auto before,
                          SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
      TF_ASSIGN_OR_RETURN(
@@ -155,6 +159,10 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
                          SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
      TF_ASSIGN_OR_RETURN(auto update,
                          TriangularSolve(builder, factorized, panel,
                                          /*left_side=*/false,
                                          /*lower=*/true,
                                          /*transpose_a=*/true,
                                          /*conjugate_a=*/false,
                                          /*block_size=*/8));
      TF_ASSIGN_OR_RETURN(
          l, UpdateSliceInMinorDims(builder, l, update, {i + k, i}));
@@ -29,6 +29,7 @@ namespace tensorflow {
// the block size to use.
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(mattjj): handle the complex Hermitian case
xla::StatusOr<xla::ComputationDataHandle> Cholesky(
    xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
    int64 block_size = 256);
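// For reference, a sketch of a call site for this declaration; `builder` and
// a symmetric positive-definite batch `a` are assumed to be in scope:
TF_ASSIGN_OR_RETURN(auto l, Cholesky(builder, a, /*block_size=*/256));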
@@ -24,13 +24,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    xla::ComputationDataHandle b, int64 block_size) {
    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, int64 block_size) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                      builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
@@ -60,14 +62,15 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
    batch_dimensions.push_back(a_size);
  }

  const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
  if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
      xla::ShapeUtil::GetDimension(*a_shape, -2)) {
    return errors::InvalidArgument(
        "The 'a' arguments to TriangularSolve must be square matrices: ",
        xla::ShapeUtil::HumanString(*a_shape));
  }
  if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) {
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
  if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
    return errors::InvalidArgument(
        "Arguments to TriangularSolve have incompatible matrix shapes: ",
        xla::ShapeUtil::HumanString(*a_shape), " vs ",
@@ -89,6 +92,14 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
    return output;
  };

  // Applies a complex conjugation operation if `a` is complex and `conjugate_a`
  // is true, otherwise returns its argument.
  auto maybe_conj = [&](xla::ComputationBuilder* builder,
                        xla::ComputationDataHandle x) {
    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
    return perform_conj ? builder->Conj(x) : x;
  };

  std::map<int, xla::Computation> base_computations;
  auto get_base_triangular_solve =
      [&](int k) -> xla::StatusOr<xla::Computation*> {
@@ -103,19 +114,35 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
                           prepend_batch_dims({k, k})),
        "a");

    std::array<int64, 2> b_lastd;
    if (left_side) {
      b_lastd = {k, n};
    } else {
      b_lastd = {m, k};
    }
    auto b_param =
        sub->Parameter(1,
                       xla::ShapeUtil::MakeShape(b_shape->element_type(),
                                                 prepend_batch_dims({m, k})),
                                                 prepend_batch_dims(b_lastd)),
                       "b");

    // TODO(phawkins): it might make sense to use a while loop here, rather
    // than unrolling.
    // TODO(phawkins): the left-looking variant of the algorithm might be more
    // efficient at block size 1.
    TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
                                       /*block_size=*/1)
                           .status());
    // We use a left-looking subroutine on the block diagonal in some common
    // cases, while falling back to a recursive call in unsupported cases. The
    // left-looking subroutine is written with a While loop and so yields much
    // faster compile times. Moreover, the left-looking variant can give
    // higher performance on smaller (sub)problems.
    if (left_side && lower) {
      TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
                                                    b_param, transpose_a,
                                                    conjugate_a)
                             .status());
    } else {
      TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
                                         left_side, lower, transpose_a,
                                         conjugate_a,
                                         /*block_size=*/1)
                             .status());
    }

    TF_ASSIGN_OR_RETURN(computation, sub->Build());
  }
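// The `base_computations` map above caches one compiled base solve per block
// width `k`. The lookup-or-build pattern behind it, sketched in isolation
// with standard-library stand-ins (nothing below is an XLA API):
#include <cstdio>
#include <map>
#include <string>

std::map<int, std::string> cache;  // Stand-in for expensive artifacts.

const std::string& get_or_build(int k) {
  auto it = cache.find(k);
  if (it == cache.end()) {
    // Build once, then reuse on every later request for the same k.
    it = cache.emplace(k, "computation<" + std::to_string(k) + ">").first;
  }
  return it->second;
}

int main() {
  std::printf("%s\n", get_or_build(8).c_str());
  std::printf("%s\n", get_or_build(8).c_str());  // Served from the cache.
}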
@@ -129,47 +156,396 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
  // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
  // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
  // (2008): 4.
  for (int64 i = 0; i < n; i += block_size) {
    int64 k = std::min(block_size, n - i);

    // if k > 1:
    //   output[..., :, i:i+k] = triangular_solve(
    //       a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right',
    //       kind='Lower', transpose=True, block_size=1)
    // else:
    //   output[..., :, i] = b[..., :, i] / a[..., i, i]
    TF_ASSIGN_OR_RETURN(auto a_slice,
                        SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
    TF_ASSIGN_OR_RETURN(auto b_slice,
                        SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
    xla::ComputationDataHandle update;
    if (k > 1) {
      TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                          get_base_triangular_solve(k));
      update = builder->Call(*solve, {a_slice, b_slice});
    } else {
      update = builder->Div(b_slice, a_slice);
  // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
  // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
  // conjugate_a is True.

  if (!left_side && lower == transpose_a) {
    // for i in range(0, a.shape[-1], block_size):
    for (int64 i = 0; i < n; i += block_size) {
      int64 k = std::min(block_size, n - i);

      // output[..., :, i:i+k] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));

      // if i + k < a.shape[-1]:
      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
      if (i + k < n) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
        } else {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, update, a_slice_2,
                                     /*transpose_x=*/false,
                                     /*transpose_y=*/transpose_a,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/conjugate_a));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
      }
    }

    TF_ASSIGN_OR_RETURN(
        output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
    // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k],
    //                           np.transpose(..., a[i+k:, i:i+k]))
    if (i + k < n) {
      TF_ASSIGN_OR_RETURN(auto a_slice_2,
                          SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2,
                                                  /*transpose_x=*/false,
                                                  /*transpose_y=*/true));
  } else if (left_side && lower != transpose_a) {
    // for i in range(0, a.shape[-1], block_size):
    for (int64 i = 0; i < m; i += block_size) {
      int64 k = std::min(block_size, m - i);

      TF_ASSIGN_OR_RETURN(auto b_slice_2,
                          SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
      b_update = builder->Sub(b_slice_2, b_update);
      // output[..., i:i+k, :] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));

      // if i + k < a.shape[-1]:
      //   a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
      if (i + k < m) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
        } else {
          TF_ASSIGN_OR_RETURN(
              a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
                                                    /*transpose_x=*/transpose_a,
                                                    /*transpose_y=*/false,
                                                    /*conjugate_x=*/conjugate_a,
                                                    /*conjugate_y=*/false));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
      }
    }
  } else if (!left_side && lower != transpose_a) {
    // for i in reversed(range(0, a.shape[-1], block_size)):
    const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
      int64 k = std::min(block_size, n - i);

      // output[..., :, i:i+k] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {0, i}));

      // if i - k >= 0:
      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
      if (i - k >= 0) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
        } else {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update,
                            BatchDot(builder, update, a_slice_2,
                                     /*transpose_x=*/false,
                                     /*transpose_y=*/transpose_a,
                                     /*conjugate_x=*/false,
                                     /*conjugate_y=*/conjugate_a));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, 0}, {m, i}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
      }
    }
  } else {  // left_side && lower == transpose_a
    // for i in reversed(range(0, a.shape[-1], block_size)):
    const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
    for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
      int64 k = std::min(block_size, m - i);

      // output[..., i:i+k, :] = triangular_solve(
      //     a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
      TF_ASSIGN_OR_RETURN(auto a_slice,
                          SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
      TF_ASSIGN_OR_RETURN(auto b_slice,
                          SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
      xla::ComputationDataHandle update;
      if (k > 1) {
        TF_ASSIGN_OR_RETURN(xla::Computation * solve,
                            get_base_triangular_solve(k));
        update = builder->Call(*solve, {a_slice, b_slice});
      } else {
        update = builder->Div(b_slice, maybe_conj(builder, a_slice));
      }
      TF_ASSIGN_OR_RETURN(
          output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));

      // if i - k >= 0:
      //   a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
      //   a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
      //   b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
      if (i - k >= 0) {
        xla::ComputationDataHandle a_slice_2;
        if (lower) {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
        } else {
          TF_ASSIGN_OR_RETURN(a_slice_2,
                              SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
        }

        TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
                                                    /*transpose_x=*/transpose_a,
                                                    /*transpose_y=*/false,
                                                    /*conjugate_x=*/conjugate_a,
                                                    /*conjugate_y=*/false));
        TF_ASSIGN_OR_RETURN(auto b_slice_2,
                            SliceInMinorDims(builder, b, {0, 0}, {i, n}));
        b_update = builder->Sub(b_slice_2, b_update);
        TF_ASSIGN_OR_RETURN(
            b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
      }
    }
  }

  return output;
}

xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
                      builder->GetShape(a));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
                      builder->GetShape(b));
  const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
  const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
  const int64 ndims = xla::ShapeUtil::Rank(*a_shape);

  std::vector<int64> batch_dimensions;
  for (int i = 0; i < ndims - 2; ++i) {
    int64 a_size = a_shape->dimensions(i);
    batch_dimensions.push_back(a_size);
  }

  auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
    std::vector<int64> output(ndims);
    std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
    std::copy(indices.begin(), indices.end(),
              output.begin() + batch_dimensions.size());
    return output;
  };

  auto maybe_conj = [&](xla::ComputationBuilder* builder,
                        xla::ComputationDataHandle x) {
    auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
    return perform_conj ? builder->Conj(x) : x;
  };

  // The main computation is performed in a While loop.

  // Allocate the output and set its first or last row,
  // output = np.zeros_like(b)
  // if transpose_a:
  //   output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
  // else:
  //   output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
  xla::ComputationDataHandle output = Zeros(builder, *b_shape);
  {
    auto i = transpose_a ? m - 1 : 0;
    TF_ASSIGN_OR_RETURN(auto a_slice,
                        SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
    TF_ASSIGN_OR_RETURN(auto b_slice,
                        SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
    auto update = builder->Div(b_slice, maybe_conj(builder, a_slice));
    TF_ASSIGN_OR_RETURN(
        output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
  }

  // Construct the initial loop carry tuple,
  // if transpose_a:
  //   init = (m-2, output, a, b)
  // else:
  //   init = (1, output, a, b)
  std::vector<xla::Shape> tuple_shapes = {
      // The loop iteration counter is a scalar, incremented each iteration.
      xla::ShapeUtil::MakeShape(xla::S32, {}),
      // The output has the shape of b, with one row updated each iteration.
      *b_shape,
      // The coefficient matrix a is a loop invariant.
      *a_shape,
      // The right-hand-side matrix b is a loop invariant.
      *b_shape};
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
  auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
  auto init = builder->Tuple({init_i, output, a, b});

  // Construct the loop condition function,
  // def cond_fun(loop_carry):
  //   i, output, a, b = loop_carry
  //   return i >= 0 if transpose_a else i < m
  std::unique_ptr<xla::ComputationBuilder> condb =
      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
  {
    auto i = condb->GetTupleElement(
        condb->Parameter(0, tuple_shape,
                         "TriangularSolveLeftLookingWhileTuple"),
        0);
    if (transpose_a) {
      condb->Ge(i, condb->ConstantR0<int32>(0));
    } else {
      condb->Lt(i, condb->ConstantR0<int32>(m));
    }
  }
  TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

  // Construct the loop body function,
  // def body_fun(loop_carry):
  //   i, output, a, b = loop_carry
  //   if transpose_a:
  //     a_row = np.swapaxes(a[..., i+1:, i:i+1], -1, -2)
  //   else:
  //     a_row = a[..., i:i+1, :i]
  //   result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
  //   output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
  //   if transpose_a:
  //     return (i - 1, output, a, b)
  //   else:
  //     return (i + 1, output, a, b)
  // We have to do some extra FLOPs propagating zeros in the matrix multiply
  // because we can't have the size of its arguments depend on the loop counter.
  std::unique_ptr<xla::ComputationBuilder> bodyb =
      builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
  {
    auto input_tuple = bodyb->Parameter(0, tuple_shape,
                                        "TriangularSolveLeftLookingWhileTuple");

    // i, output, a, b = loop_carry
    auto i = bodyb->GetTupleElement(input_tuple, 0);
    auto body_out = bodyb->GetTupleElement(input_tuple, 1);
    auto body_a = bodyb->GetTupleElement(input_tuple, 2);
    auto body_b = bodyb->GetTupleElement(input_tuple, 3);
    auto zero = bodyb->ConstantR0<int32>(0);

    // Set up some helper functions.
    auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
      auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
      std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
      padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
      padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
      return bodyb->ConcatInDim(padded_starts, 0);
    };

    auto dynamic_slice = [&](xla::ComputationDataHandle x,
                             std::array<xla::ComputationDataHandle, 2> starts,
                             std::array<int64, 2> sizes) {
      auto padded_starts = prepend_zeros(starts);
      auto padded_sizes = prepend_batch_dims(sizes);
      return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
    };

    auto update = [&](xla::ComputationDataHandle x,
                      xla::ComputationDataHandle update,
                      std::array<xla::ComputationDataHandle, 2> starts) {
      auto padded_starts = prepend_zeros(starts);
      return bodyb->DynamicUpdateSlice(x, update, padded_starts);
    };

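// Both helpers above pad a rank-2 minor-dimension index out to the full rank
// by prefixing zeros (for starts) or batch sizes (for sizes). The same index
// manipulation on plain vectors, with hypothetical sizes (illustration only):
#include <array>
#include <cstdio>
#include <vector>

int main() {
  const int ndims = 4;                        // e.g. two batch dims + [m, n]
  const std::vector<long> batch_dims = {2, 3};

  // prepend_batch_dims({m, k}) -> {2, 3, m, k}
  std::array<long, 2> minor_sizes = {5, 1};
  std::vector<long> sizes(batch_dims);
  sizes.insert(sizes.end(), minor_sizes.begin(), minor_sizes.end());

  // prepend_zeros({i, j}) -> {0, 0, i, j}
  std::array<long, 2> starts = {4, 0};
  std::vector<long> padded(ndims, 0);
  padded[ndims - 2] = starts[0];
  padded[ndims - 1] = starts[1];

  for (long s : padded) std::printf("%ld ", s);  // 0 0 4 0
  std::printf("\n");
}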
    // We'd like to implement this:
    //   if transpose_a:
    //     a_row = T(a[..., i+1:, i:i+1])
    //     result_row = (b[..., i:i+1, :]
    //                   - np.matmul(a_row, body_out[..., i+1:, :]))
    //   else:
    //     result_row = (b[..., i:i+1, :]
    //                   - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
    // But since we can't have intermediate array sizes depend on the loop
    // counter, we instead exploit the fact that we initialized the output to
    // all zeros and use that as zero-padding (doing unnecessary FLOPs).
    xla::ComputationDataHandle a_row;
    if (transpose_a) {
      a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
    } else {
      a_row = dynamic_slice(body_a, {i, zero}, {1, m});
    }
    TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
                                                /*transpose_x=*/transpose_a,
                                                /*transpose_y=*/false,
                                                /*conjugate_x=*/conjugate_a,
                                                /*conjugate_y=*/false));
    auto result_row =
        bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);

    // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
    auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
    auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
    body_out = update(body_out, div_result, {i, zero});

    // if transpose_a:
    //   return (i - 1, body_out, a, b)
    // else:
    //   return (i + 1, body_out, a, b)
    auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
    bodyb->Tuple({next_i, body_out, body_a, body_b});
  }
  TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

  // Construct the While loop and return the result,
  // return while_loop(cond_fun, body_fun, init)[1]
  auto triangular_solve_left_looking_while = builder->While(cond, body, init);
  return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
}

}  // namespace tensorflow
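// For intuition, with unit batch, no transpose, and real scalars, the While
// loop above degenerates to textbook left-looking forward substitution; a
// standalone sketch (not XLA code):
#include <cstdio>
#include <vector>

// Solve L * x = b for lower-triangular L, one row at a time; each row i
// subtracts the contribution of all previously solved rows, mirroring the
// row-by-row order of TriangularSolveLeftLooking.
std::vector<double> ForwardSubstitute(const std::vector<std::vector<double>>& l,
                                      const std::vector<double>& b) {
  const int m = static_cast<int>(b.size());
  std::vector<double> x(m, 0.0);
  for (int i = 0; i < m; ++i) {
    double dot = 0.0;
    for (int j = 0; j < i; ++j) dot += l[i][j] * x[j];
    x[i] = (b[i] - dot) / l[i][i];
  }
  return x;
}

int main() {
  std::vector<std::vector<double>> l = {{2, 0}, {3, 6}};
  std::vector<double> b = {1, 5};
  for (double v : ForwardSubstitute(l, b)) std::printf("%g ", v);  // 0.5 0.583333
  std::printf("\n");
}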
@@ -21,25 +21,50 @@ limitations under the License.

namespace tensorflow {

// Solves systems of linear equations with upper or lower triangular matrices by
// backsubstitution.
// Solves systems of linear equations with lower or upper triangular coefficient
// matrices by forward- or back-substitution. Broadcasting along leading
// dimensions, this routine solves one of the matrix systems
//   `op(a) * x = b`,  or `x * op(a) = b`,
// for the variable `x` given `a` and `b`, where `op(a)` is either
//   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
// That is, the innermost matrices in the output satisfy a scalar system
// depending on the value of (left_side, transpose_a, conjugate_a) according to:
//   (F, F, F) => `output[..., i, k]  a[..., k, j] = b[..., i, j]`,
//   (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`,
//   (F, T, F) => `output[..., i, k]  a[..., j, k] = b[..., i, j]`,
//   (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`,
//   (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`,
//   (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`,
//   (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`,
//   (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`,
// where * denotes complex conjugation and where the index `k` is summed over.
//
// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
// square matrices. The strictly upper triangular part of each inner-most matrix
// is assumed to be zero and not accessed.
// `b` is a tensor of shape `[..., M, K]`.
//
// The innermost matrices in the output satisfy matrix equations
// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`.
// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
// square matrices. If lower is true (false), then the strictly upper (lower)
// triangular part of each innermost matrix in `a` is assumed to be zero and is
// not accessed.
// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a
// tensor of shape `[..., K, M]`.
// `left_side` is a boolean, indicating whether to solve a system of the form
// op(a) * x = b (true) or x * op(a) = b (false).
// `lower` is a boolean, indicating whether the argument `a` is lower-triangular
// (true) or upper-triangular (false).
// `transpose_a` is a boolean indicating whether the matrix `a` is transposed.
// `conjugate_a` is a boolean indicating whether the entries of `a` are complex
// conjugated (independently of whether they are transposed), so that when both
// transpose_a and conjugate_a are true the effect is a Hermitian adjoint.
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right,
// kind=lower, and transposed_a=true. Implement the other possible combinations
// of side, kind and transposed_a.
xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    xla::ComputationDataHandle b, int64 block_size = 256);
    xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, int64 block_size = 256);

xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
    const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a);

}  // namespace tensorflow
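// Spelled out once more in compact form, the eight cases reduce to a single
// definition of op; this is a restatement of the comment above, not new
// semantics:
//
//   op(a) = a                     (transpose_a=F, conjugate_a=F)
//   op(a) = transpose(a)          (transpose_a=T, conjugate_a=F)
//   op(a) = conj(a)               (transpose_a=F, conjugate_a=T)
//   op(a) = conj(transpose(a))    (transpose_a=T, conjugate_a=T)
//
//   left_side=T:  solve op(a) * x = b   with a: [..., M, M], b: [..., M, K]
//   left_side=F:  solve x * op(a) = b   with a: [..., M, M], b: [..., K, M]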
@@ -27,32 +27,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace tensorflow {
namespace {

using TriangularSolveTest = xla::ClientLibraryTestBase;
using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase;
using complex64 = xla::complex64;

XLA_TEST_F(TriangularSolveTest, Simple) {
xla::Array2D<float> AValsLower() {
  return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}};
}

xla::Array2D<float> AValsUpper() {
  return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}};
}

xla::Array2D<float> BValsRight() {
  return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
}

xla::Array2D<float> BValsLeft() {
  return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
}

xla::Array2D<complex64> AValsLowerComplex() {
  return {{2, 0, 0, 0},
          {complex64(3, 1), 6, 0, 0},
          {4, complex64(7, 2), 9, 0},
          {5, 8, complex64(10, 3), 11}};
}

xla::Array2D<complex64> AValsUpperComplex() {
  return {{2, 3, complex64(4, 3), 5},
          {0, 6, complex64(7, 2), 8},
          {0, 0, complex64(9, 1), 10},
          {0, 0, 0, 11}};
}

xla::Array2D<complex64> BValsRightComplex() {
  return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
}

xla::Array2D<complex64> BValsLeftComplex() {
  return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
}

xla::Array2D<float> AValsFull() {
  return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
}

XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::Array2D<float> a_vals({
      {2, 0, 0, 0},
      {3, 6, 0, 0},
      {4, 7, 9, 0},
      {5, 8, 10, 11},
  });
  xla::Array2D<float> b_vals({
      {1, 2, 3, 4},
      {5, 6, 7, 8},
      {9, 10, 11, 12},
  });

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(b_vals, 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b, /*block_size=*/2);
  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/false, /*lower=*/true,
                                /*transpose_a=*/true, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
@@ -62,7 +98,267 @@ XLA_TEST_F(TriangularSolveTest, Simple) {
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(2e-3, 2e-3));
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/false, /*lower=*/true,
                                /*transpose_a=*/false, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/false, /*lower=*/false,
                                /*transpose_a=*/true, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {-0.16414141, -0.06902357, -0.07070707, 0.36363636},
      {0.64393939, 0.06565657, -0.03030303, 0.72727273},
      {1.4520202, 0.2003367, 0.01010101, 1.09090909},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/false, /*lower=*/false,
                                /*transpose_a=*/false, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {0.5, 0.08333334, 0.04629629, 0.03367003},
      {2.5, -0.25, -0.1388889, -0.1010101},
      {4.5, -0.58333331, -0.32407406, -0.23569024},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/true, /*lower=*/true,
                                /*transpose_a=*/true, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/true, /*lower=*/true,
                                /*transpose_a=*/false, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/true, /*lower=*/false,
                                /*transpose_a=*/true, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/true, /*lower=*/false,
                                /*transpose_a=*/false, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {-0.89646465, -0.69444444, -0.49242424},
      {-0.27441077, -0.24074074, -0.20707071},
      {-0.23232323, -0.22222222, -0.21212121},
      {0.90909091, 1., 1.09090909},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data =
      CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
  auto b_data =
      CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/false, /*lower=*/true,
                                /*transpose_a=*/true, /*conjugate_a=*/true,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<complex64> expected({
      {0.5, complex64(0.08333333, 0.08333333),
       complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
      {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
       complex64(0.08670034, -0.02104377)},
      {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
       complex64(0.11026936, -0.03114478)},
  });

  ComputeAndCompareR2<complex64>(&builder, expected,
                                 {a_data.get(), b_data.get()},
                                 xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data =
      CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
  auto b_data =
      CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
  auto result = TriangularSolve(&builder, a, b,
                                /*left_side=*/true, /*lower=*/false,
                                /*transpose_a=*/true, /*conjugate_a=*/false,
                                /*block_size=*/2);
  TF_ASSERT_OK(result.status());

  xla::Array2D<complex64> expected({
      {0.5, 1., 1.5},
      {0.41666667, 0.33333333, 0.25},
      {complex64(0.20020325, -2.81504065e-01),
       complex64(0.13821138, -4.22764228e-01),
       complex64(0.07621951, -5.64024390e-01)},
      {complex64(0.19678492, 2.55912786e-01),
       complex64(0.17738359, 3.84331116e-01),
       complex64(0.15798226, 5.12749446e-01)},
  });

  ComputeAndCompareR2<complex64>(&builder, expected,
                                 {a_data.get(), b_data.get()},
                                 xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  auto result = TriangularSolveLeftLooking(&builder, a, b,
                                           /*transpose_a=*/false,
                                           /*conjugate_a=*/false);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
  xla::ComputationBuilder builder(client_, TestName());

  xla::ComputationDataHandle a, b;
  auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
  auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
  auto result = TriangularSolveLeftLooking(&builder, a, b,
                                           /*transpose_a=*/false,
                                           /*conjugate_a=*/false);
  TF_ASSERT_OK(result.status());

  xla::Array2D<float> expected({
      {0.5, 1.0, 1.5},
      {0.41666667, 0.33333333, 0.25},
      {0.23148148, 0.18518519, 0.13888889},
      {0.16835017, 0.13468013, 0.1010101},
  });

  ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
                             xla::ErrorSpec(1e-2, 1e-2));
}

}  // namespace
@@ -107,4 +107,15 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
  return UpdateSlice(builder, x, update, padded_start);
}

xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
  TF_RET_CHECK(n_dims >= 2);
  std::vector<int64> permutation(n_dims);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
  return builder->Transpose(x, permutation);
}

}  // namespace tensorflow
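// The permutation above swaps only the two minor dimensions, e.g. {0, 1, 3, 2}
// for a rank-4 input; the index manipulation in isolation (standalone sketch):
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int n_dims = 4;
  std::vector<long> permutation(n_dims);
  std::iota(permutation.begin(), permutation.end(), 0);  // {0, 1, 2, 3}
  std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
  for (long d : permutation) std::printf("%ld ", d);     // 0 1 3 2
  std::printf("\n");
}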
@@ -49,6 +49,10 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);

// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_
@@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph,
    XlaCompiler::Argument arg;
    arg.kind = XlaCompiler::Argument::kParameter;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
    TensorShape shape;
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
    xla_args->push_back(arg);
  }
@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types,
|
||||
|
||||
bool XlaCompiler::Argument::operator==(
|
||||
const XlaCompiler::Argument& other) const {
|
||||
if (std::tie(kind, resource_kind, type, name, tensor_array_size,
|
||||
if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
|
||||
tensor_array_gradients) !=
|
||||
std::tie(other.kind, other.resource_kind, other.type, other.name,
|
||||
other.tensor_array_size, other.tensor_array_gradients)) {
|
||||
other.initialized, other.tensor_array_size,
|
||||
other.tensor_array_gradients)) {
|
||||
return false;
|
||||
}
|
||||
if (!xla::ShapeUtil::Equal(shape, other.shape)) {
|
||||
if (shape != other.shape) {
|
||||
return false;
|
||||
}
|
||||
if (constant_value.shape() != other.constant_value.shape()) {
@@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
  return Status::OK();
}

// Computes the XLA shape for argument 'arg'.
/*static*/ Status XlaCompiler::XLAShapeForArgument(
    const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
  switch (arg.kind) {
    case XlaCompiler::Argument::kConstant:
      return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
                                   xla_shape);
    case XlaCompiler::Argument::kParameter:
      return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
    case XlaCompiler::Argument::kResource: {
      TF_RET_CHECK(arg.initialized);

      switch (arg.resource_kind) {
        case XlaResource::kVariable:
          return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
        case XlaResource::kTensorArray: {
          if (arg.tensor_array_size < 0) {
            return errors::InvalidArgument(
                "Negative tensor_array_size in XLAShapeForArgument");
          }
          TensorShape shape;
          shape.AddDim(arg.tensor_array_size);
          shape.AppendShape(arg.shape);
          TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));

          if (!arg.tensor_array_gradients.empty()) {
            std::vector<xla::Shape> tuple_shape(
                arg.tensor_array_gradients.size() + 1, *xla_shape);
            *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
          }
          return Status::OK();
        }
        case XlaResource::kStack: {
          if (arg.tensor_array_size < 0) {
            return errors::InvalidArgument(
                "Negative tensor_array_size in XLAShapeForArgument");
          }
          TensorShape shape;
          shape.AddDim(arg.tensor_array_size);
          shape.AppendShape(arg.shape);
          xla::Shape buffer_shape;
          TF_RETURN_IF_ERROR(
              TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
          *xla_shape = xla::ShapeUtil::MakeTupleShape(
              {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
          return Status::OK();
        }

        case XlaResource::kInvalid:
          return errors::Internal(
              "Invalid resource type in XLAShapeForArgument()");
      }
    }
    case XlaCompiler::Argument::kInvalid:
      return errors::Internal("Invalid argument type in XLAShapeForArgument()");
  }
}
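The TensorArray case above implies a simple packed-shape rule: a TensorArray of size S with per-entry shape E and G gradient sources is passed as a tuple of 1 + G identical buffers of shape [S] ++ E (no tuple when G = 0). A small hypothetical helper, not part of the change, that renders this rule as text:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Renders the packed shape XLAShapeForArgument computes for an initialized
// TensorArray argument: a [size ++ entry_dims] buffer for the value, plus
// one identical buffer per gradient source, wrapped in a tuple when any
// gradients exist.
std::string PackedTensorArrayShape(int64_t size,
                                   const std::vector<int64_t>& entry_dims,
                                   int num_gradients) {
  std::string buffer = "[" + std::to_string(size);
  for (int64_t d : entry_dims) buffer += "," + std::to_string(d);
  buffer += "]";
  if (num_gradients == 0) return buffer;
  std::string tuple = "(" + buffer;
  for (int i = 0; i < num_gradients; ++i) tuple += ", " + buffer;
  return tuple + ")";
}
```

For instance, `PackedTensorArrayShape(2, {}, 1)` returns `([2], [2])`, matching the `MakeTupleShape({S32[2], S32[2]})` value the TensorArray tests in this change previously spelled out by hand.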

namespace {

Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
@@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph,

  // Argument numbers of arguments and resources that are to be passed to the
  // XLA computation as runtime parameters.
  std::vector<int> parameters, resources;
  parameters.reserve(args.size());
  input_mapping->clear();
  input_mapping->reserve(args.size());
  std::vector<int> resources;
  resources.reserve(args.size());

  // Fills in constant arguments, and computes non-constant argument order.

@@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph,
        // TODO(phawkins): this code assumes that resource arguments do not
        // alias.
        XlaResource* resource;
        TF_RETURN_IF_ERROR(
            context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
                                    xla::ComputationDataHandle(), &resource));
        resource->set_tensor_array_size(arg.tensor_array_size);
        TF_RETURN_IF_ERROR(context->CreateResource(
            arg.resource_kind, i, arg.name, arg.type, arg.shape,
            xla::ComputationDataHandle(),
            /*tensor_array_size=*/arg.tensor_array_size,
            /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
        arg_expression.set_resource(resource);
        if (arg.initialized) {
          resources.push_back(i);
        }
        break;
      case XlaCompiler::Argument::kParameter:
        parameters.push_back(i);
      case XlaCompiler::Argument::kParameter: {
        input_mapping->push_back(i);
        break;
      }
      case XlaCompiler::Argument::kConstant:
        arg_expression.set_constant_value(arg.constant_value);
        break;

@@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph,

  // Append parameters containing variable values after the other runtime
  // parameters.
  parameters.insert(parameters.end(), resources.begin(), resources.end());
  if (parameters.empty()) {
  input_mapping->insert(input_mapping->end(), resources.begin(),
                        resources.end());
  if (input_mapping->empty()) {
    return Status::OK();
  }

  std::vector<xla::Shape> arg_shapes;
  arg_shapes.reserve(parameters.size());
  input_mapping->resize(parameters.size());
  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
    const XlaCompiler::Argument& arg = args[parameters[i]];
  std::vector<xla::Shape> arg_shapes(input_mapping->size());
  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
    // Computes the shapes of non-constant arguments.
    arg_shapes.push_back(arg.shape);
    (*input_mapping)[i] = parameters[i];
    TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
        args[(*input_mapping)[i]], &arg_shapes[i]));
  }

  if (use_tuple_arg) {
@@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph,
  }

  // Build parameter handles for non-constant arguments.
  std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
  std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
  if (use_tuple_arg) {
    xla::ComputationDataHandle tuple;
    if (is_entry_computation) {
      xla::OpSharding tuple_sharding;
      tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
      for (int64 parameter : parameters) {
      for (int64 parameter : *input_mapping) {
        const int core = (*arg_cores)[parameter];
        const int root_device = 0;
        *tuple_sharding.add_tuple_shardings() =

@@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph,
    } else {
      tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
    }
    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
      const int core = (*arg_cores)[parameters[i]];
    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
      const int core = (*arg_cores)[input_mapping->at(i)];
      xla::ScopedShardingAssignment assign_sharding(
          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));
      arg_handles[i] = builder->GetTupleElement(tuple, i);
    }
  } else {
    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
      const int core = (*arg_cores)[parameters[i]];
    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
      const int core = (*arg_cores)[input_mapping->at(i)];
      xla::ScopedShardingAssignment assign_sharding(
          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                              : xla::sharding_builder::AssignDevice(core));

@@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph,

  // Fill in the handles in non-constant arguments.
  VLOG(2) << "XLA computation inputs:";
  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
    const XlaCompiler::Argument& arg = args[parameters[i]];
  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
    const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
    VLOG(2) << " XLA arg " << i
            << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
            << " name: " << arg.name << " TF arg " << parameters[i];
    XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
            << " name: " << arg.name << " TF arg " << input_mapping->at(i);
    XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
    switch (arg.kind) {
      case XlaCompiler::Argument::kResource: {
        TF_RET_CHECK(arg.initialized);
        XlaResource* resource = arg_expression.resource();
        TF_RETURN_IF_ERROR(
            resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i],
                                  /*reset_initial_values=*/true, builder));
        TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
                                                 arg_handles[i], builder));
        VLOG(2) << " resource: num_gradients: "
                << arg.tensor_array_gradients.size();
        break;
@@ -486,6 +545,7 @@ Status BuildComputation(
    XlaCompiler::ResourceUpdate& update = resource_updates->back();
    update.input_index = resource->arg_num();
    update.type = resource->type();
    update.shape = resource->shape();
    update.modified = modified;
    for (const auto& grad : resource->tensor_array_gradients()) {
      update.tensor_array_gradients_accessed.insert(grad.first);

@@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
      ++computation_output;
    }
  }

  for (std::vector<ResourceUpdate>::size_type i = 0;
       i < result->resource_updates.size(); ++i) {
    result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
        result->xla_output_shape, computation_output);
    ++computation_output;
  }
  return Status::OK();
}
@@ -104,9 +104,17 @@ class XlaCompiler {
    // is the type of the variable's value, not DT_RESOURCE.
    DataType type;

    // The shape of the argument. If the argument is a resource, this is the
    // shape of the resource's value.
    xla::Shape shape;
    // The shape of the argument. For:
    // * a parameter: the shape of the parameter.
    // * a constant: ignored; the shape given by constant_value is used
    //   instead.
    // * an uninitialized resource: ignored. We don't yet know the shape of an
    //   uninitialized resource (otherwise we would have initialized it!)
    // * an initialized variable: the shape of the variable's value.
    // * an initialized TensorArray or Stack resource: the shape of an entry in
    //   the TensorArray/Stack. Note this is the size of a single entry, not the
    //   XLA data structure that represents the complete stack/array.
    TensorShape shape;

    // The value of the argument, if it is a compile-time constant. Must be a
    // host-memory tensor.
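As an example of how these fields combine, the TensorArray tests later in this change populate an `Argument` along these lines (values taken from `CanPassTensorArraysToAndFromComputation`; note that `shape` holds the per-entry shape, not the packed tuple shape):

```cpp
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kTensorArray;
arg.initialized = true;
arg.type = DT_INT32;
arg.shape = TensorShape({});             // shape of one entry
arg.tensor_array_size = 2;               // number of entries
arg.tensor_array_gradients = {"grad2"};  // gradient sources
```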
@@ -175,8 +183,9 @@ class XlaCompiler {
    int input_index;

    // Type and shape of the tensor to be written back.
    // The `shape` field has the same meaning as the Argument::shape field.
    DataType type;
    xla::Shape shape;
    TensorShape shape;

    // Was the value of the variable modified by the computation?
    // (Always true, unless `return_updated_values_for_all_resources` is true.)

@@ -235,6 +244,19 @@ class XlaCompiler {
  // device is created, and can be used to create metadata objects
  // that can be accessed by XLA op kernels.
  std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;

  // If not nullptr, this memory allocator can be used by the compiler for
  // temporary allocations it might want to make during compilation.
  //
  // For example, the compiler may want to try out different algorithms and
  // choose the fastest one, and it might run those algorithms over buffers
  // created using this allocator.
  //
  // The compiler can function correctly without an explicit allocator given
  // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
  // allocate most or all available memory on the device, leaving none for the
  // compiler to access, unless it can use TensorFlow's allocator.
  xla::DeviceMemoryAllocator* device_allocator = nullptr;
};

explicit XlaCompiler(Options options);

@@ -253,11 +275,10 @@ class XlaCompiler {
      const std::vector<Argument>& args,
      CompilationResult* result);

  Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func,
                          const std::vector<DataType>& types,
                          const std::vector<TensorShape>& shapes,
                          const std::vector<const XlaExpression*>& expressions,
                          std::vector<Argument>* args);
  // Returns the shape of the XLA parameter for an argument 'arg'.
  // See the class comment for more details about the argument passing
  // convention.
  static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);

  // Retrieves the channel handle associated with `key`. Allocates
  // a new channel handle if none exists.
@@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) {
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

@@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

@@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[0].shape = TensorShape({2});

  XlaCompiler::Options options = DefaultOptions();
  XlaCompiler compiler(options);

@@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[0].shape = TensorShape({2});

  DummyResourceForTest* resource = new DummyResourceForTest();

@@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
  args[0].shape = TensorShape({2});

  // Compiles the graph.
  auto options = DefaultOptions();

@@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
       xla::ShapeUtil::MakeShape(xla::S32, {2})});
  args[0].shape = TensorShape({});
  args[0].tensor_array_size = 2;
  args[0].tensor_array_gradients = {"grad2"};

@@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
       xla::ShapeUtil::MakeShape(xla::S32, {2})});
  args[0].shape = TensorShape({});
  args[0].tensor_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};

@@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
       xla::ShapeUtil::MakeShape(xla::S32, {2})});
  args[0].shape = TensorShape({});
  args[0].tensor_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};
@@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,

xla::ComputationBuilder* XlaContext::builder() { return builder_; }

Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
                                  string name, DataType type,
                                  const xla::ComputationDataHandle& handle,
                                  XlaResource** resource) {
Status XlaContext::CreateResource(
    XlaResource::Kind kind, int arg_num, string name, DataType type,
    TensorShape shape, const xla::ComputationDataHandle& handle,
    int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
    XlaResource** resource) {
  resources_.emplace_back(
      new XlaResource(kind, arg_num, std::move(name), type, handle));
      new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
                      handle, tensor_array_size, tensor_array_gradients));
  *resource = resources_.back().get();
  return Status::OK();
}

@@ -71,11 +71,15 @@ class XlaContext : public ResourceBase {
  Status AddConstRetval(int retval_index, DataType dtype,
                        const xla::Literal& literal);

  // Creates a resource with resource `kind` and initial type `type` and
  // value `handle`. `name` is a descriptive name for use in error messages.
  // Creates a resource with resource `kind` and initial value `handle`. `name`
  // is a descriptive name for use in error messages. See the `XlaResource`
  // constructor for a description of the remaining arguments.
  // Fails if the resource already exists.
  Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
                        DataType type, const xla::ComputationDataHandle& handle,
                        DataType type, TensorShape shape,
                        const xla::ComputationDataHandle& handle,
                        int64 tensor_array_size,
                        const std::set<string>& tensor_array_gradients,
                        XlaResource** resource);

  const std::vector<std::unique_ptr<XlaResource>>& resources() {
@@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList(
}

Status XlaOpKernelContext::ReadVariableInput(
    int index, xla::ComputationDataHandle* value) {
    int index, DataType type, TensorShape* shape,
    xla::ComputationDataHandle* value) {
  const Tensor& tensor = context_->input(index);
  const XlaExpression* expression = CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();

@@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput(
    return errors::InvalidArgument("Read of uninitialized variable ",
                                   variable->name());
  }
  if (variable->type() != type) {
    return errors::InvalidArgument(
        "Type mismatch for read of variable ", variable->name(), ". Expected ",
        DataTypeString(type), "; got ", DataTypeString(variable->type()));
  }
  *value = variable->value();
  if (shape) {
    *shape = variable->shape();
  }
  return Status::OK();
}

@@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                   variable->name());
  }
  *type = variable->type();
  auto shape_or_status = builder()->GetShape(variable->value());
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
  *shape = variable->shape();
  return Status::OK();
}

@@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable(
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  return variable->SetValue(type, handle);

  auto shape_or_status = builder()->GetShape(handle);
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
  }
  TensorShape shape;
  TF_RETURN_IF_ERROR(
      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));

  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
  return variable->SetValue(handle);
}

XlaCompiler* XlaOpKernelContext::compiler() const {

@@ -164,11 +164,16 @@ class XlaOpKernelContext {
                        TensorShape* shape) const;

  // Reads the current value of the resouce variable referred to by input
  // 'index'.
  Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
  // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
  // variable. Returns an error if the variable has not been initialized, or if
  // its type does not match `type`.
  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
                           xla::ComputationDataHandle* value);

  // Assigns the value `handle` to the variable referenced by input
  // `input_index`. Marks the operator as having side effects.
  // `input_index`. The variable must be of `type`. Returns an error if the
  // variable has been initialized with a different type or with a
  // different shape.
  Status AssignVariable(int input_index, DataType type,
                        const xla::ComputationDataHandle& handle);
@@ -25,51 +25,99 @@ limitations under the License.

namespace tensorflow {

XlaResource::XlaResource(Kind kind, int arg_num, string name,
                         DataType initial_type,
                         const xla::ComputationDataHandle& initial_value)
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
                         TensorShape shape,
                         const xla::ComputationDataHandle& initial_value,
                         int64 tensor_array_size,
                         const std::set<string>& tensor_array_gradients)
    : kind_(kind),
      arg_num_(arg_num),
      name_(std::move(name)),
      type_(initial_type),
      type_(type),
      shape_(std::move(shape)),
      value_(initial_value),
      initial_value_(initial_value) {
      initial_value_(initial_value),
      tensor_array_size_(tensor_array_size) {
  CHECK(kind_ != kInvalid);

  for (const string& gradient : tensor_array_gradients) {
    tensor_array_gradients_[gradient].reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
                        type_, shape_, xla::ComputationDataHandle(),
                        tensor_array_size_, /*tensor_array_gradients=*/{}));
  }
}

Status XlaResource::SetValue(DataType type,
                             const xla::ComputationDataHandle& value) {
  if (type_ == DT_INVALID && type == DT_INVALID) {
    return errors::InvalidArgument("Attempted to initialized resource ", name_,
                                   " to an invalid type");
Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
  if (type == DT_INVALID) {
    return errors::InvalidArgument("Attempted to set type of resource '", name_,
                                   "'' to an invalid type");
  }
  if (type_ != DT_INVALID && type_ != type) {
  if (initialized() && type_ != type) {
    return errors::InvalidArgument("Type of resource ", name_,
                                   " cannot be changed after initialization: "
                                   "old type was ",
                                   DataTypeString(type_), ", new type is ",
                                   DataTypeString(type));
  }
  if (initialized() && shape_ != shape) {
    return errors::InvalidArgument("Shape of resource ", name_,
                                   " cannot be changed after initialization: "
                                   "old shape was ",
                                   shape_.DebugString(), ", new shape is ",
                                   shape.DebugString());
  }
  type_ = type;
  shape_ = shape;
  return Status::OK();
}

Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  value_ = value;
  return Status::OK();
}

Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
                                xla::Shape* shape) const {
  auto shape_or_status = builder->GetShape(value_);
  if (!shape_or_status.ok()) {
    return shape_or_status.status();
Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
  if (type_ == DT_INVALID) {
    return errors::InvalidArgument(
        "Resource '", name_,
        "' must be initialized with a valid type before use.");
  }
  *shape = *shape_or_status.ValueOrDie();
  return Status::OK();
}
  switch (kind_) {
    case kVariable: {
      value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
                                  shape_.dim_sizes());
      break;
    }
    case kTensorArray: {
      TensorShape ta_shape;
      ta_shape.AddDim(tensor_array_size_);
      ta_shape.AppendShape(shape_);
      value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
                                  ta_shape.dim_sizes());
      break;
    }
    case kStack: {
      TensorShape ta_shape;
      ta_shape.AddDim(tensor_array_size_);
      ta_shape.AppendShape(shape_);
      value_ =
          builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
                                             ta_shape.dim_sizes()),
                          builder->ConstantR0<int32>(0)});
      break;
    }

Status XlaResource::GetShape(xla::ComputationBuilder* builder,
                             TensorShape* shape) const {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
    case kInvalid:
    default:
      LOG(FATAL) << "Invalid resource type";
  }
  return Status::OK();
}
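Taken together, `SetZeroValue` materializes the zero state in the same layout the packing convention uses: a variable becomes a zero tensor of shape `shape_`, a TensorArray becomes a zero buffer of shape `[tensor_array_size_] ++ shape_`, and a Stack becomes a tuple of such a zero buffer and a scalar `int32` zero (presumably the stack's element count).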
@@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
  std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
  if (!gradient) {
    TensorShape ta_shape;
    TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
    ta_shape.AddDim(tensor_array_size_);
    ta_shape.AppendShape(shape_);
    xla::ComputationDataHandle gradient_value = builder->Broadcast(
        XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
    gradient.reset(
        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                        /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
                        type_, gradient_value));
    gradient->tensor_array_size_ = tensor_array_size_;
                        type_, shape_, gradient_value, tensor_array_size_,
                        /*tensor_array_gradients=*/{}));
  }
  *gradient_out = gradient.get();
  return Status::OK();
}

Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
                                xla::Shape* packed_shape) const {
  if (tensor_array_gradients_.empty()) {
    return GetXlaShape(builder, packed_shape);
  }
  TF_RET_CHECK(kind_ == kTensorArray);
  std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients_.size());
  int pos = 0;
  TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
  for (const auto& gradient : tensor_array_gradients_) {
    TF_RETURN_IF_ERROR(
        gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
  }
  *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
  return Status::OK();
}

Status XlaResource::Pack(xla::ComputationDataHandle* pack,
                         xla::ComputationBuilder* builder) const {
  if (tensor_array_gradients_.empty()) {

@@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack,

Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                const xla::ComputationDataHandle& pack,
                                bool reset_initial_values,
                                xla::ComputationBuilder* builder) {
  if (gradient_sources.empty()) {
    if (!initialized()) {
      initial_value_ = pack;
    }
    value_ = pack;
  } else {
    TF_RET_CHECK(kind_ == kTensorArray);
    int pos = 0;
    value_ = builder->GetTupleElement(pack, pos++);
    auto v = builder->GetTupleElement(pack, pos++);
    if (!initialized()) {
      initial_value_ = v;
    }
    value_ = v;

    for (const auto& source : gradient_sources) {
      XlaResource* gradient;
      TF_RETURN_IF_ERROR(
          GetOrCreateTensorArrayGradient(source, builder, &gradient));
      gradient->value_ = builder->GetTupleElement(pack, pos++);
      if (reset_initial_values) {
        gradient->initial_value_ = gradient->value_;
      auto v = builder->GetTupleElement(pack, pos++);
      if (!gradient->initialized()) {
        gradient->initial_value_ = v;
      }
      gradient->value_ = v;
    }
  }
  if (reset_initial_values) {
    initial_value_ = value_;
  }
  return Status::OK();
}
@@ -36,8 +36,11 @@ class XlaResource {
    kStack,
  };

  XlaResource(Kind kind, int arg_num, string name, DataType initial_type,
              const xla::ComputationDataHandle& initial_value);
  XlaResource(Kind kind, int arg_num, string name, DataType type,
              TensorShape shape,
              const xla::ComputationDataHandle& initial_value,
              int64 tensor_array_size,
              const std::set<string>& tensor_array_gradients);

  XlaResource(const XlaResource&) = delete;
  XlaResource(XlaResource&&) = delete;

@@ -60,6 +63,12 @@ class XlaResource {
  // a resource is first initialized we do not yet know its type, so we keep
  // track of its type dynamically.
  DataType type() const { return type_; }

  // Shape of the resource. For an uninitialized resource, this is ignored.
  // For a Variable, this is the shape of the value. For a TensorArray or Stack
  // this is the shape of each entry in the TensorArray/Stack.
  const TensorShape& shape() const { return shape_; }

  const xla::ComputationDataHandle& value() const { return value_; }

  // Value of the resource at computation entry. Used to detect which

@@ -68,17 +77,19 @@ class XlaResource {
    return initial_value_;
  }

  // A variable is initialized if it has a value.
  bool initialized() const { return value_.handle() > 0; }

  // Sets the current type/value of the resource.
  Status SetValue(DataType type, const xla::ComputationDataHandle& value);
  // Sets the type and shape of the resource. The type and shape of a resource
  // must not change once the variable has been initialized.
  Status SetTypeAndShape(DataType type, const TensorShape& shape);

  // Returns the shape of the resource as an xla::Shape.
  Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
  // Sets the current value of the resource. Returns an error if the type is not
  // set to a valid value.
  Status SetValue(const xla::ComputationDataHandle& value);

  // Returns the shape of the resource as an TensorShape. Fails if the shape is
  // not representable as a TensorShape.
  Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
  // Sets the current value of the resource to an all-zero value.
  Status SetZeroValue(xla::ComputationBuilder* builder);

  // Looks up the gradient for `source`, or creates it if it does not already
  // exist. The call target must be an initialized TensorArray resource. A

@@ -96,10 +107,6 @@ class XlaResource {
  Status Pack(xla::ComputationDataHandle* pack,
              xla::ComputationBuilder* builder) const;

  // Returns the shape of the `pack` value computed by `Pack()`.
  Status PackedShape(xla::ComputationBuilder* builder,
                     xla::Shape* packed_shape) const;

  // Updates the resource with values from `pack`. If `gradient_sources` is
  // non-empty, treats `pack` as a tuple that represents a TensorArray and
  // its gradients, and unpacks and updates the gradient resources.

@@ -108,14 +115,14 @@ class XlaResource {
  // Opposite of Pack().
  Status SetFromPack(const std::set<string>& gradient_sources,
                     const xla::ComputationDataHandle& pack,
                     bool reset_initial_values,
                     xla::ComputationBuilder* builder);

  // TensorArray-specific fields
  // TensorArray and Stack specific fields

  // 'tensor_array_size' stores the expected size of the TensorArray or Stack.
  // We need to store this since sometimes TensorArrays must be initialized
  // lazily since we do not know the element shape at construction time.
  // Used by both TensorArrays and Stacks.
  int64 tensor_array_size() const { return tensor_array_size_; }
  void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }

@@ -136,6 +143,7 @@ class XlaResource {
  const string name_;

  DataType type_;
  TensorShape shape_;
  xla::ComputationDataHandle value_;
  xla::ComputationDataHandle initial_value_;
@@ -88,7 +88,6 @@ cc_library(
    visibility = [":friends"],
    deps = [
        "//tensorflow/core:framework_lite",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
    ],
)

@@ -80,6 +80,18 @@ cc_library(
    ],
)

cc_library(
    name = "executable_build_options",
    srcs = ["executable_build_options.cc"],
    hdrs = ["executable_build_options.h"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "local_client",
    srcs = ["local_client.cc"],

@@ -87,6 +99,7 @@ cc_library(
    deps = [
        ":client",
        ":computation",
        ":executable_build_options",
        "//tensorflow/compiler/xla:executable_run_options",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
79 tensorflow/compiler/xla/client/executable_build_options.cc Normal file
@@ -0,0 +1,79 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/executable_build_options.h"

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace xla {

ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
    DeviceMemoryAllocator* allocator) {
  device_allocator_ = allocator;
  return *this;
}

DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
  return device_allocator_;
}

ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
    int device_ordinal) {
  CHECK_GE(device_ordinal, 0);
  device_ordinal_ = device_ordinal;
  return *this;
}

int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }

ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
    const Shape& shape_with_layout) {
  result_layout_set_ = true;
  result_layout_ = shape_with_layout;
  return *this;
}

const Shape* ExecutableBuildOptions::result_layout() const {
  return result_layout_set_ ? &result_layout_ : nullptr;
}

string ExecutableBuildOptions::ToString() const {
  string result_layout = "nullopt";
  if (result_layout_set_) {
    result_layout = ShapeUtil::HumanStringWithLayout(result_layout_);
  }
  string generate_hlo_graph = "nullopt";
  if (generate_hlo_graph_.has_value()) {
    generate_hlo_graph = generate_hlo_graph_.value();
  }
  return tensorflow::strings::Printf(
      "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, "
      "generate_hlo_graph=%s}",
      device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str());
}

ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph(
    string regex) {
  generate_hlo_graph_ = std::move(regex);
  return *this;
}

const tensorflow::gtl::optional<string>&
ExecutableBuildOptions::generate_hlo_graph() const {
  return generate_hlo_graph_;
}

} // namespace xla
74 tensorflow/compiler/xla/client/executable_build_options.h Normal file
@@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/optional.h"

namespace xla {

// Class containing options for building an LocalExecutable with
// LocalClient::Compile.
class ExecutableBuildOptions {
 public:
  // If set, this is the device to build the computation for. Valid
  // device_ordinal values are: 0 to # of devices - 1. These values are
  // identical to the device ordinal values used by StreamExecutor. The built
  // executable will be executable on any device equivalent to the specified
  // device as determined by Backend::devices_equivalent(). A value of -1
  // indicates this option has not been set.
  ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
  int device_ordinal() const;

  // If set, this specifies the layout of the result of the computation. If not
  // set, the service will chose the layout of the result. A Shape is used to
  // store the layout to accommodate tuple result shapes. A value of nullptr
  // indicates the option has not been set.
  ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
  const Shape* result_layout() const;

  // If set, this specifies an allocator that can be used to allocate temporary
  // space on the device during compilation. For example, the compiler might
  // want to run various algorithms on the device and pick the fastest one -- it
  // might allocate buffers for use by these algorithms using this allocator.
  //
  // This does not need to be the same as the DeviceMemoryAllocator passed when
  // running the executable.
  ExecutableBuildOptions& set_device_allocator(
      DeviceMemoryAllocator* allocator);
  DeviceMemoryAllocator* device_allocator() const;

  // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions).
  ExecutableBuildOptions& set_generate_hlo_graph(string regex);
  const tensorflow::gtl::optional<string>& generate_hlo_graph() const;

  // Returns a string representation of the build options, suitable for
  // debugging.
  string ToString() const;

 private:
  int device_ordinal_ = -1;
  Shape result_layout_;
  bool result_layout_set_ = false;
  tensorflow::gtl::optional<string> generate_hlo_graph_;
  DeviceMemoryAllocator* device_allocator_ = nullptr;
};

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
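A minimal usage sketch for the new header (assumes an already-constructed `xla::LocalClient*` and `xla::Computation`; illustrative only, with error handling left to the caller):

```cpp
#include <memory>

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"

// Compiles `computation` for device 0. Leaving device_ordinal at -1 would
// instead let LocalClient::Compile substitute the default device ordinal.
xla::StatusOr<std::unique_ptr<xla::LocalExecutable>> CompileForDevice0(
    xla::LocalClient* client, const xla::Computation& computation) {
  xla::ExecutableBuildOptions options;
  options.set_device_ordinal(0);
  return client->Compile(computation, /*argument_layouts=*/{}, options);
}
```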
@@ -30,25 +30,6 @@ using xla::source_map_util::InvalidParameterArgument;

namespace xla {

ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
    int device_ordinal) {
  device_ordinal_ = device_ordinal;
  return *this;
}

int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }

ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
    const Shape& shape_with_layout) {
  result_layout_set_ = true;
  result_layout_ = shape_with_layout;
  return *this;
}

const Shape* ExecutableBuildOptions::result_layout() const {
  return result_layout_set_ ? &result_layout_ : nullptr;
}

namespace {
StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
                                                   Backend* backend) {

@@ -60,16 +41,18 @@ StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
} // namespace

LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
                                 Backend* backend, int device_ordinal,
                                 const ExecutableBuildOptions& build_options)
                                 Backend* backend,
                                 ExecutableBuildOptions build_options)
    : executable_(std::move(executable)),
      backend_(backend),
      build_device_ordinal_(device_ordinal),
      build_options_(build_options) {}
      build_options_(std::move(build_options)) {
  CHECK_GE(build_options_.device_ordinal(), 0)
      << "Must have a valid device ordinal that the executable was built for.";
}

tensorflow::Status LocalExecutable::ValidateExecutionOptions(
    const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
    const ExecutableRunOptions& options, const Backend& backend) {
    const ExecutableRunOptions& run_options, const Backend& backend) {
  const ComputationLayout& computation_layout =
      executable_->module_config().entry_computation_layout();

@@ -93,14 +76,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
    }
  }

  if (options.stream() != nullptr) {
    if (!options.stream()->ok()) {
  if (run_options.stream() != nullptr) {
    if (!run_options.stream()->ok()) {
      return InvalidArgument("stream is uninitialized or in an error state");
    }

    // Check stream matches service platform.
    const se::Platform* stream_platform =
        options.stream()->parent()->platform();
        run_options.stream()->parent()->platform();
    if (stream_platform != backend_->platform()) {
      return InvalidArgument(
          "stream is for platform %s, but service targets platform %s",

@@ -110,7 +93,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(

  // Cannot specify device_ordinal with a stream. The stream determines these
  // values.
  if (options.device_ordinal() != -1) {
  if (run_options.device_ordinal() != -1) {
    return InvalidArgument(
        "cannot set both device ordinal and stream options in "
        "ExecutableRunOptions; the stream determines the device ordinal");

@@ -119,34 +102,34 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(

  // Verify that the device the executable was built for is equivalent to the
  // device it will run on.
  int run_device_ordinal = options.device_ordinal() == -1
  int run_device_ordinal = run_options.device_ordinal() == -1
                               ? backend_->default_device_ordinal()
                               : options.device_ordinal();
  TF_ASSIGN_OR_RETURN(
      bool devices_equivalent,
      backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_));
                               : run_options.device_ordinal();
  TF_ASSIGN_OR_RETURN(bool devices_equivalent,
                      backend_->devices_equivalent(
                          run_device_ordinal, build_options_.device_ordinal()));
  if (!devices_equivalent) {
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
                        backend_->stream_executor(run_device_ordinal));
    TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
                        backend_->stream_executor(build_device_ordinal_));
                        backend_->stream_executor(build_device_ordinal()));
    return InvalidArgument(
        "executable is built for device %s of type \"%s\"; cannot run it on "
        "device %s of type \"%s\"",
        backend_->device_name(build_device_ordinal_).c_str(),
        backend_->device_name(build_device_ordinal()).c_str(),
        build_executor->GetDeviceDescription().name().c_str(),
        backend_->device_name(run_device_ordinal).c_str(),
        run_executor->GetDeviceDescription().name().c_str());
  }

  if (!options.allocator()) {
  if (!run_options.allocator()) {
    return InvalidArgument("an allocator must be provided to ExecuteLocally");
  }

  if (options.allocator()->platform() != backend.platform()) {
  if (run_options.allocator()->platform() != backend.platform()) {
    return InvalidArgument(
        "allocator platform (%s) does not match service platform (%s)",
        options.allocator()->platform()->Name().c_str(),
        run_options.allocator()->platform()->Name().c_str(),
        backend.platform()->Name().c_str());
  }

@@ -155,23 +138,22 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(

StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
    const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
    const ExecutableRunOptions& options) {
  TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_));

  ExecutableRunOptions actual_options = options;
    ExecutableRunOptions run_options) {
  TF_RETURN_IF_ERROR(
      ValidateExecutionOptions(arguments, run_options, *backend_));

  Backend::StreamPtr stream;
  if (options.stream() == nullptr) {
  if (run_options.stream() == nullptr) {
    // NB! The lifetime of `stream` needs to match the lifetime of
    // `actual_options` (otherwise we will end up using a returned stream in
    // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
    // scope.
    TF_ASSIGN_OR_RETURN(
        stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
    actual_options.set_stream(stream.get());
        stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
    run_options.set_stream(stream.get());
  }
  if (options.allocator() == nullptr) {
    actual_options.set_allocator(backend_->memory_allocator());
  if (run_options.allocator() == nullptr) {
    run_options.set_allocator(backend_->memory_allocator());
  }

  // For local client execution on CPU backends:

@@ -180,7 +162,7 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
  // *) The thread pool used for XLA CPU ops is from
  // backend_->eigen_intra_op_thread_pool().
  ServiceExecutableRunOptions service_options(
      actual_options, backend_->StreamBorrower(),
      run_options, backend_->StreamBorrower(),
      backend_->eigen_intra_op_thread_pool());

  if (executable_->dumping()) {

@@ -189,9 +171,8 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ShapedBuffer> result,
      executable_->ExecuteOnStreamWrapper(
          &service_options, options.execution_profile(), arguments));
  return ScopedShapedBuffer::MakeScoped(result.get(),
                                        actual_options.allocator());
          &service_options, run_options.execution_profile(), arguments));
  return ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator());
}

StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::ExecuteAndDump(

@@ -267,16 +248,19 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
    const Computation& computation,
    const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
    const ExecutableBuildOptions& options) {
  int device_ordinal = options.device_ordinal() == -1
                           ? default_device_ordinal()
                           : options.device_ordinal();
  TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                      local_service_->CompileExecutable(
                          computation.handle(), argument_layouts,
                          options.result_layout(), device_ordinal));
  ExecutableBuildOptions updated_options = options;
  if (options.device_ordinal() == -1) {
    updated_options.set_device_ordinal(default_device_ordinal());
    VLOG(3) << "Set device ordinal to default value of: "
            << updated_options.device_ordinal();
  }
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      local_service_->CompileExecutable(computation.handle(), argument_layouts,
                                        updated_options));
  return WrapUnique(new LocalExecutable(std::move(executable),
                                        local_service_->mutable_backend(),
                                        device_ordinal, options));
                                        updated_options));
}

StatusOr<std::unique_ptr<ScopedShapedBuffer>>

@@ -20,6 +20,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"

@@ -33,39 +34,13 @@ limitations under the License.

namespace xla {

// Class containing options for building an LocalExecutable with
// LocalClient::Compile.
class ExecutableBuildOptions {
 public:
  // If set, this is the device to build the computation for. Valid
  // device_ordinal values are: 0 to # of devices - 1. These values are
  // identical to the device ordinal values used by StreamExecutor. The built
  // executable will be executable on any device equivalent to the specified
  // device as determined by Backend::devices_equivalent(). A value of -1
  // indicates this option has not been set.
  ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
  int device_ordinal() const;

  // If set, this specifies the layout of the result of the computation. If not
  // set, the service will chose the layout of the result. A Shape is used to
  // store the layout to accommodate tuple result shapes. A value of nullptr
  // indicates the option has not been set.
  ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
  const Shape* result_layout() const;

 private:
  int device_ordinal_ = -1;
  Shape result_layout_;
  bool result_layout_set_ = false;
};

class LocalExecutable {
 public:
  // Run the compiled computation with the given arguments and options and
  // return the result.
  StatusOr<std::unique_ptr<ScopedShapedBuffer>> Run(
      const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
      const ExecutableRunOptions& options);
      ExecutableRunOptions run_options);

  // Return the layout (contained in a shape) of the result produced by the
  // computation.

@@ -88,8 +63,7 @@ class LocalExecutable {

  // Constructor invoked by LocalClient.
  LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
                  int device_ordinal,
                  const ExecutableBuildOptions& build_options);
                  ExecutableBuildOptions build_options);

  // Validates that the given arguments and options satisfy various constraints
  // of the computation.

@@ -117,19 +91,19 @@ class LocalExecutable {
  StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
      const ShapedBuffer& shaped_buffer);

  // The ordinal of the device which this executable was compiled for. The
  // executable can run on all equivalent devices (as determined by
  // Backend::devices_equivalent).
  int build_device_ordinal() const { return build_options_.device_ordinal(); }

  // Compiled computation.
  std::unique_ptr<Executable> executable_;

  // Execution backend.
  Backend* backend_;

  // The ordinal of the device which this executable was compiled for. The
  // executable can run on all equivalent devices (as determined by
  // Backend::devices_equivalent).
  int build_device_ordinal_;
  Backend* backend_ = nullptr;

  // Options used to build the executable.
  const ExecutableBuildOptions& build_options_;
  const ExecutableBuildOptions build_options_;
};

// An XLA Client specialization for use when the client and service run in
|
||||
|
@ -221,13 +221,19 @@ void AllocateFlags() {
flag_values->xla_gpu_disable_multi_streaming(),
"If true, multi-streaming in the GPU backend is disabled."),
tensorflow::Flag(
"xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(),
"Dump compilation artifacts as proto binary into this directory."),
"xla_dump_optimized_hlo_proto_to",
flag_values->mutable_xla_dump_optimized_hlo_proto_to(),
"Dump Hlo after all hlo passes are executed as proto binary into "
"this directory."),
tensorflow::Flag(
"xla_dump_prepass_hlo_proto_to",
flag_values->mutable_xla_dump_prepass_hlo_proto_to(),
"Dump compilation artifacts, before hlo passes are executed, as "
"proto binary into this directory."),
"xla_dump_unoptimized_hlo_proto_to",
flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(),
"Dump HLO before any hlo passes are executed as proto binary into "
"this directory."),
tensorflow::Flag("xla_dump_per_pass_hlo_proto_to",
flag_values->mutable_xla_dump_per_pass_hlo_proto_to(),
"Dump HLO after each pass as an HloProto in binary file "
"format into this directory."),
tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
@ -486,6 +486,7 @@ class Literal {
std::vector<std::unique_ptr<Literal>> elements);

// Returns a string representation of the literal value.
// Warning: this function can take minutes for multi-million element Literals.
string ToString(bool print_layout = false) const;

// Invokes the "per cell" callback for each element in the provided
@ -49,6 +49,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",
@ -98,15 +98,25 @@ const std::unique_ptr<ScopedShapedBuffer>& LocalShapedBuffer::shaped_buffer()
return shaped_buffer_;
}

static StatusOr<std::unique_ptr<ScopedShapedBuffer>> ToBuffer(
LocalClient* client, int device_ordinal, const Literal& arg) {
return client->LiteralToShapedBuffer(arg, device_ordinal,
client->backend().memory_allocator());
}

/* static */
LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) {
LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
const Literal& argument,
const tensorflow::gtl::optional<Shape>& shape_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
std::unique_ptr<ScopedShapedBuffer> buf =
client
->LiteralToShapedBuffer(argument,
/*device_ordinal=*/0,
client->backend().memory_allocator())
.ConsumeValueOrDie();
std::unique_ptr<ScopedShapedBuffer> buf;
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie();
} else {
buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
}
return new LocalShapedBuffer(std::move(buf));
}

@ -120,7 +130,8 @@ CompiledLocalComputation::CompiledLocalComputation(
: executable_(std::move(executable)) {}

StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments) {
const std::vector<Literal>& arguments,
const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();

VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas.";
@ -133,7 +144,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
GetReplicaCount());

for (int replica = 0; replica < GetReplicaCount(); ++replica) {
pool.Schedule([this, client, replica, &arguments, &results] {
pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
&results] {
StatusOr<int> device_ordinal_status =
client->ReplicaNumberToDeviceOrdinal(replica);
if (!device_ordinal_status.ok()) {
@ -144,18 +156,28 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: "
<< device_ordinal;

// Transfer arguments in
std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
scoped_buffers.reserve(arguments.size());
for (const Literal& argument : arguments) {
StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed =
client->LiteralToShapedBuffer(
argument, device_ordinal,
client->backend().memory_allocator());
for (int i = 0; i < arguments.size(); ++i) {
const Literal& argument = arguments[i];
const tensorflow::gtl::optional<Shape>& shape_with_layout =
shapes_with_layout[i];

StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed;
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
pushed = ToBuffer(client, device_ordinal, *relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}
if (!pushed.ok()) {
results[replica] = pushed.status();
return;
}

scoped_buffers.push_back(std::move(pushed).ValueOrDie());
}

@ -233,7 +255,8 @@ LocalComputation::LocalComputation(Computation computation)
: computation_(std::move(computation)) {}

StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
const std::vector<Shape>& argument_shapes) {
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options) {
std::vector<const Shape*> argument_shape_pointers;
argument_shape_pointers.reserve(argument_shapes.size());
for (auto& argument_shape : argument_shapes) {
@ -242,6 +265,9 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile(

LocalClient* client = GetOrCreateLocalClient();
ExecutableBuildOptions options;
if (build_options != nullptr) {
options = *build_options;
}
TF_ASSIGN_OR_RETURN(
auto local_executable,
client->Compile(computation_, argument_shape_pointers, options));
@ -363,12 +389,6 @@ LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
source, init_value, scatter.computation());
}

ComputationDataHandle LocalComputationBuilder::Select(
const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
const ComputationDataHandle& on_false) {
return builder_.Select(pred, on_true, on_false);
}

ComputationDataHandle LocalComputationBuilder::Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
return builder_.Tuple(elements);
@ -384,6 +404,12 @@ ComputationDataHandle LocalComputationBuilder::Dot(
return builder_.Dot(lhs, rhs);
}

ComputationDataHandle LocalComputationBuilder::DotGeneral(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
const DotDimensionNumbers& dimension_numbers) {
return builder_.DotGeneral(lhs, rhs, dimension_numbers);
}

ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
@ -483,6 +509,15 @@ ComputationDataHandle LocalComputationBuilder::While(
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
(lhs, rhs, broadcast_dimensions))

#define _FORWARD_TRIOP(method_name) \
_FORWARD( \
method_name, ComputationDataHandle, \
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
const ComputationDataHandle& ehs), \
(lhs, rhs, ehs))

_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
@ -503,6 +538,7 @@ _FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
@ -519,6 +555,7 @@ _FORWARD_UNOP(Sort)
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP

void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
delete local_shaped_buffer;
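
The C++ plumbing above accepts an optional Shape (with layout) per argument and re-lays out each literal before transfer. A minimal sketch of how this surfaces on the Python side, assuming the in-tree import path and the `xla_client` API introduced later in this diff (the `columnar` helper is illustrative, not part of the change):

```python
import numpy as np

from tensorflow.compiler.xla.python import xla_client


def columnar(shape):
  # Request a column-major layout (minor_to_major = (0, 1, ...)) for each
  # leaf array shape; returning None keeps the default layout.
  if shape.is_tuple():
    return None
  return shape.update_minor_to_major(tuple(range(shape.rank())))

# Sketch: `computation` would be a LocalComputation built elsewhere.
# compiled = computation.CompileWithExampleArguments(
#     arguments=(np.ones((2, 3), dtype=np.float32),), layout_fn=columnar)
# result = compiled.Execute(
#     arguments=(np.ones((2, 3), dtype=np.float32),), layout_fn=columnar)
```
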
@ -18,6 +18,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -58,7 +59,9 @@ StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
// client.
class LocalShapedBuffer {
public:
static LocalShapedBuffer* FromLiteral(const Literal& argument);
static LocalShapedBuffer* FromLiteral(
const Literal& argument,
const tensorflow::gtl::optional<Shape>& shape_with_layout);
LocalShapedBuffer(std::unique_ptr<ScopedShapedBuffer> shaped_buffer);
const std::unique_ptr<ScopedShapedBuffer>& shaped_buffer() const;
std::unique_ptr<Literal> ToLiteral() const;
@ -76,8 +79,15 @@ class LocalShapedBuffer {
class CompiledLocalComputation {
public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);

// Execute the computation with the given argument literals, and
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
StatusOr<std::unique_ptr<Literal> > Execute(
const std::vector<Literal>& arguments);
const std::vector<Literal>& arguments,
const std::vector<tensorflow::gtl::optional<Shape> >& shapes_with_layout);

LocalShapedBuffer* ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);

@ -93,7 +103,8 @@ class LocalComputation {
public:
LocalComputation(Computation computation);
StatusOr<CompiledLocalComputation*> Compile(
const std::vector<Shape>& argument_shapes);
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options);
const Computation& computation() const;

private:
@ -172,10 +183,6 @@ class LocalComputationBuilder {
const ComputationDataHandle& source,
const ComputationDataHandle& init_value, const LocalComputation& scatter);

ComputationDataHandle Select(const ComputationDataHandle& pred,
const ComputationDataHandle& on_true,
const ComputationDataHandle& on_false);

ComputationDataHandle Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);

@ -185,6 +192,10 @@ class LocalComputationBuilder {
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);

ComputationDataHandle DotGeneral(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
const DotDimensionNumbers& dimension_numbers);

ComputationDataHandle ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
@ -252,6 +263,14 @@ class LocalComputationBuilder {
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))

#define _FORWARD_TRIOP(method_name) \
_FORWARD( \
method_name, ComputationDataHandle, \
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
const ComputationDataHandle& ehs))

_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
@ -272,6 +291,7 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
@ -288,6 +308,7 @@ class LocalComputationBuilder {
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP

private:
ComputationBuilder builder_;
@ -27,12 +27,14 @@ limitations under the License.
// ArraySlice<ComputationDataHandle> <- sequence of int
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape <-> pair holding (dtype, dimensions)
// std::vector<Shape> <- sequence of shape information pairs
// Shape -> pair holding (dtype, dimensions)
// <- object duck-typed as xla_client.Shape
// std::vector<Shape> <- sequence of xla_client.Shape objects
// PrimitiveType <- int
// ArraySlice<pair<int64, int64>> <- sequence of int pairs
// PaddingConfig proto <- corresponding Python proto
// ConvolutionDimensionNumbers proto <- corresponding Python proto
// DotDimensionNumbers proto <- corresponding Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally.
@ -55,7 +57,7 @@ limitations under the License.
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// The Python objects corresponding to C++ Shapes have the type:
// Shapes output by C++ become Python objects with the type:
//
// T = (dtype, S)
// S = DIMENSIONS | TUPLE_SHAPES
@ -176,6 +178,16 @@ tensorflow::ImportNumpy();
}
}

%typemap(out) StatusOr< std::unique_ptr<Literal> > {
if ($1.ok()) {
std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
return NULL;
}
}

%typemap(out) StatusOr<xla::swig::LocalComputation*> {
if ($1.ok()) {
auto* value = $1.ValueOrDie();
@ -343,15 +355,31 @@ tensorflow::ImportNumpy();
// Shape

%typemap(in) const Shape& (Shape temp) {
Status shape_status = numpy::CheckPyShapeInfo($input);
if (!shape_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temp = numpy::XlaShapeFromPyShapeInfo($input);
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
}

%typemap(in) const tensorflow::gtl::optional<Shape>& (
tensorflow::gtl::optional<Shape> temp) {
if ($input == Py_None) {
temp = tensorflow::gtl::nullopt;
$1 = &temp;
} else {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
}
}

%typemap(out) std::unique_ptr<Shape> {
$result = numpy::PyShapeInfoFromXlaShape(*$1);
}
@ -364,14 +392,37 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
Status shape_status = numpy::CheckPyShapeInfo(o);
if (!shape_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
Py_DECREF(o);
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
Py_DECREF(o);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
Py_DECREF(o);
temps.push_back(statusor.ConsumeValueOrDie());
}
$1 = &temps;
}

%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
std::vector<tensorflow::gtl::optional<Shape> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (o == Py_None) {
temps.push_back(tensorflow::gtl::nullopt);
} else {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
Py_DECREF(o);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temps.push_back(statusor.ConsumeValueOrDie());
}
}
$1 = &temps;
}
@ -461,6 +512,135 @@ tensorflow::ImportNumpy();
$1 = temps;
}

// DotDimensionNumbers

%typemap(in) const DotDimensionNumbers&
(DotDimensionNumbers dimension_numbers) {
int length;

/* lhs_contracting_dimensions */
PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
$input, "lhs_contracting_dimensions");
if (!lhs_contracting_dimensions) {
return NULL;
}

length = PySequence_Size(lhs_contracting_dimensions);
if (length == -1) {
Py_DECREF(lhs_contracting_dimensions);
return NULL;
}

for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
if (!item) {
Py_DECREF(lhs_contracting_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(lhs_contracting_dimensions);
return NULL;
}
dimension_numbers.add_lhs_contracting_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(lhs_contracting_dimensions);

/* rhs_contracting_dimensions */
PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
$input, "rhs_contracting_dimensions");
if (!rhs_contracting_dimensions) {
return NULL;
}

length = PySequence_Size(rhs_contracting_dimensions);
if (length == -1) {
Py_DECREF(rhs_contracting_dimensions);
return NULL;
}

for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
if (!item) {
Py_DECREF(rhs_contracting_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(rhs_contracting_dimensions);
return NULL;
}
dimension_numbers.add_rhs_contracting_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(rhs_contracting_dimensions);

/* lhs_batch_dimensions */
PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
$input, "lhs_batch_dimensions");
if (!lhs_batch_dimensions) {
return NULL;
}

length = PySequence_Size(lhs_batch_dimensions);
if (length == -1) {
Py_DECREF(lhs_batch_dimensions);
return NULL;
}

for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
if (!item) {
Py_DECREF(lhs_batch_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(lhs_batch_dimensions);
return NULL;
}
dimension_numbers.add_lhs_batch_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(lhs_batch_dimensions);

/* rhs_batch_dimensions */
PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
$input, "rhs_batch_dimensions");
if (!rhs_batch_dimensions) {
return NULL;
}

length = PySequence_Size(rhs_batch_dimensions);
if (length == -1) {
Py_DECREF(rhs_batch_dimensions);
return NULL;
}

for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
if (!item) {
Py_DECREF(rhs_batch_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(rhs_batch_dimensions);
return NULL;
}
dimension_numbers.add_rhs_batch_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(rhs_batch_dimensions);

$1 = &dimension_numbers;
}

// PaddingConfig

%typemap(in) const PaddingConfig&
@ -623,6 +803,30 @@ tensorflow::ImportNumpy();
$1 = &dimension_numbers;
}

// ExecutableBuildOptions

%typemap(in) const ExecutableBuildOptions*
(ExecutableBuildOptions build_options) {
if ($input == Py_None) {
$1 = NULL;
} else {
PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
if (!o) {
return NULL;
}
if (o != Py_None) {
if (!PyString_Check(o)) {
PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
return NULL;
}
build_options.set_generate_hlo_graph(PyString_AsString(o));
}
Py_DECREF(o);

$1 = &build_options;
}
}

%ignoreall
%unignore xla;
%unignore xla::swig;
@ -667,6 +871,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Rev;
%unignore xla::swig::LocalComputationBuilder::Clamp;
%unignore xla::swig::LocalComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
@ -681,6 +886,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::DotGeneral;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
@ -696,6 +902,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;
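
The ExecutableBuildOptions typemap above only reads a `generate_hlo_graph` attribute off the Python object, which is exactly what the `CompileOptions` class added later in this diff carries. A hedged sketch of the round trip (the regex value is an illustrative placeholder):

```python
from tensorflow.compiler.xla.python import xla_client

options = xla_client.CompileOptions()
# Must be a string or None, per the typemap's PyString_Check.
options.generate_hlo_graph = '.*'

# Sketch: `computation` and `example_arg` would come from the caller.
# compiled = computation.CompileWithExampleArguments(
#     arguments=(example_arg,), compile_options=options)
```
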
@ -176,85 +176,107 @@ static string PyObjectCppRepr(PyObject* o) {
return ExtractStringAndDecref(r);
}

Status CheckPyShapeInfo(PyObject* o) {
StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
auto error = [o](const string& prefix) {
return InvalidArgument("%s; got %s", prefix.c_str(),
PyObjectCppRepr(o).c_str());
};
// The object is a tuple (a pair)
if (!PyTuple_Check(o)) {
return error("Shape record must be a tuple");
}
if (PyTuple_Size(o) != 2) {
return error("Shape record tuple must be of length 2");
}

// It has a first element, which is a numpy dtype object
PyObject* first = PyTuple_GetItem(o, 0);
if (first == nullptr) {
return error("Tuple has no item 0 (shape dtype)");
}
if (first->ob_type != &PyArrayDescr_Type) {
return error(
"Shape record does not have a numpy dtype as its first element");
}
const int np_type = NumpyTypenum(first);
if (!NumpyTypeIsValid(np_type)) {
return error("Shape record has an invalid integer dtype");
}

// It has a second element, which is a tuple, either of shape
// records or of Python ints
PyObject* second = PyTuple_GetItem(o, 1);
if (!second) {
return error("Tuple has no item 1 (shape dimensions)");
}
if (!PyTuple_Check(second)) {
return error("Shape record does not have a tuple as its second element");
}
const int length = PyTuple_Size(second);
const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
for (int i = 0; i < length; i++) {
PyObject* dimension = PyTuple_GetItem(second, i);
if (element_type == TUPLE) {
VLOG(3) << "element_type is tuple, checking member: " << i;
Status result = CheckPyShapeInfo(dimension);
if (!result.ok()) {
return AddStatus(
result, tensorflow::strings::StrCat("Validating tuple member ", i,
" of ", PyObjectCppRepr(o)));
}
} else if (!CheckPyIntOrLong(dimension)) {
return error("Non-tuple shape record has a non-integer dimension");
auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> {
PyObject* result =
PyObject_GetAttrString(o, const_cast<char*>(field.c_str()));
if (result == nullptr) {
return error(tensorflow::strings::StrCat(
"Failed to get attribute of Shape object:", field));
}
return result;
};

auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
PyObject* result =
PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
if (result == nullptr) {
return error(tensorflow::strings::StrCat(
"Failed to call method of shape object:", method));
}
return result;
};

PyObject* np_type;
TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype"));
if (np_type->ob_type != &PyArrayDescr_Type) {
return error("Shape attribute np_dtype is not an integer numpy dtype");
}
if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
return error("Shape attribute np_dtype is not a valid integer numpy dtype");
}
const PrimitiveType element_type =
NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
Py_DECREF(np_type);

return Status::OK();
}

// Precondition: CheckPyShapeInfo(o)
Shape XlaShapeFromPyShapeInfo(PyObject* o) {
const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
PyObject* py_dimensions = PyTuple_GetItem(o, 1);
const int length = PyTuple_Size(py_dimensions);
if (element_type == TUPLE) {
PyObject* py_subshapes;
TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
if (!PyTuple_Check(py_subshapes)) {
return error(
"Return value of Shape method tuple_shapes() is not a tuple");
}
const int length = PyTuple_Size(py_subshapes);
std::vector<Shape> subshapes;
subshapes.reserve(length);
for (int i = 0; i < length; i++) {
subshapes.push_back(
XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
TF_ASSIGN_OR_RETURN(
const Shape& subshape,
XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
subshapes.push_back(subshape);
}
Py_DECREF(py_subshapes);
return ShapeUtil::MakeTupleShape(subshapes);
} else {
PyObject* py_dimensions;
PyObject* py_minor_to_major;
TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
if (!PyTuple_Check(py_dimensions)) {
return error("Return value of Shape method dimensions() is not a tuple");
}
if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
return error(
"Return value of Shape method minor_to_major() is neither a tuple "
"nor None");
}
const int length = PyTuple_Size(py_dimensions);
if (py_minor_to_major != Py_None &&
length != PyTuple_Size(py_minor_to_major)) {
return error(
"Shape methods dimensions() and minor_to_major() return "
"different-length tuples");
}
std::vector<int64> dimensions(length);
std::vector<int64> minor_to_major(length);
for (int i = 0; i < length; i++) {
dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
if (dimensions[i] == -1) {
CHECK(!PyErr_Occurred());
if (dimensions[i] == -1 && PyErr_Occurred()) {
return error("Dimension is not an int");
}

if (py_minor_to_major != Py_None) {
minor_to_major[i] =
PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
if (minor_to_major[i] == -1 && PyErr_Occurred()) {
return error("Minor-to-major value is not an int");
}
}
}
return ShapeUtil::MakeShape(element_type, dimensions);
bool with_layout = py_minor_to_major != Py_None;
Py_DECREF(py_dimensions);
Py_DECREF(py_minor_to_major);
if (with_layout) {
return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
minor_to_major);
} else {
return ShapeUtil::MakeShape(element_type, dimensions);
}
}
}

@ -56,15 +56,11 @@ bool NumpyTypeIsValid(int np_type);
// The return value is a new reference.
PyObject* PyShapeInfoFromXlaShape(const Shape& shape);

// Returns the outcome of a best-effort check that the Python object
// is a pair of the form (numpy dtype, dimensions), as produced by
// PyShapeInfoFromXlaShape.
Status CheckPyShapeInfo(PyObject* o);

// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
// Converts a Python object with a method interface matching that of
// xla_client.Shape into an XLA Shape object.
//
// The return value is a new reference.
Shape XlaShapeFromPyShapeInfo(PyObject* o);
StatusOr<Shape> XlaShapeFromPyShape(PyObject* o);

// Converts a PyObject that represents operation metadata into protocol buffer
// form.
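
XlaShapeFromPyShape consumes any object with the xla_client.Shape method interface: an `np_dtype` attribute plus `dimensions()`, `minor_to_major()`, and (for tuple shapes) `tuple_shapes()` methods. A minimal stand-in satisfying that protocol might look like the following sketch (the class is hypothetical, not part of the change):

```python
import numpy as np


class MinimalShape(object):
  """Duck-typed stand-in for xla_client.Shape (sketch only)."""

  def __init__(self, np_dtype, dims, minor_to_major=None):
    self.np_dtype = np.dtype(np_dtype)  # read via the np_dtype attribute
    self._dims = tuple(dims)
    self._minor_to_major = minor_to_major

  def dimensions(self):
    # The bridge requires a tuple return value.
    return self._dims

  def minor_to_major(self):
    # None means "no explicit layout": the bridge then calls
    # ShapeUtil::MakeShape instead of MakeShapeWithLayout.
    return self._minor_to_major

  def tuple_shapes(self):
    raise ValueError('not a tuple shape')


shape = MinimalShape(np.float32, (2, 3), minor_to_major=(1, 0))
```
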
@ -89,6 +89,7 @@ _UNARY_OPS = [
'Abs',
'Exp',
'Floor',
'Round',
'Ceil',
'Log',
'Sign',
@ -155,9 +156,14 @@ class LocalBuffer(object):
self._delete = c_api.DeleteLocalShapedBuffer

@staticmethod
def from_py(npval):
def from_py(npval, layout_fn=None):
npval = require_numpy_array_layout(npval)
return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval))
if layout_fn:
shape = Shape.from_numpy(npval)
shape = shape.map_leaves(layout_fn)
else:
shape = None
return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape))

def to_py(self):
return self.c_local_shaped_buffer.ToLiteral()
@ -182,13 +188,17 @@ class Shape(object):
represents an XLA tuple.
"""

def __init__(self, np_dtype, dimensions):
def __init__(self, np_dtype, dimensions, minor_to_major=None):
assert isinstance(dimensions, tuple)
self.np_dtype = np_dtype
self._dimensions = dimensions
self._minor_to_major = minor_to_major
self._check_minor_to_major()

def __repr__(self):
return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format(
self.np_dtype, self._dimensions)
return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
self._minor_to_major)

def element_type(self):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
@ -201,11 +211,49 @@ class Shape(object):
raise ValueError('Tuple shape has no dimensions')
return self._dimensions

def minor_to_major(self):
return self._minor_to_major

def tuple_shapes(self):
if not self.is_tuple():
raise ValueError('Shape is not a tuple shape')
return self._dimensions

def rank(self):
return len(self.dimensions())

def map_leaves(self, f):
"""Map f over each leaf-level array subshape.

Args:
f: The function to apply. Whenever f returns None, the identity is
applied instead.

Returns:
A new Shape with the mapped leaves.
"""
if self.is_tuple():
children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
return Shape(np.dtype('O'), children)
else:
mapped = f(self)
return self if mapped is None else mapped

def _check_minor_to_major(self):
mtm = self._minor_to_major
if self.is_tuple():
assert mtm is None, self
if mtm is not None:
assert self.rank() == len(mtm), self
assert sorted(mtm) == range(len(mtm)), self

def update_minor_to_major(self, minor_to_major):
if not isinstance(minor_to_major, tuple):
raise TypeError('minor_to_major must be a tuple')
updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major)
updated._check_minor_to_major()  # pylint: disable=protected-access
return updated

@staticmethod
def from_numpy(npval):

@ -222,23 +270,10 @@ def _wrap_shape(shape_info):
dtype, dims = shape_info
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
if element_type == xla_data_pb2.TUPLE:
dims = [_wrap_shape(subshape_info) for subshape_info in dims]
dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
return Shape(dtype, dims)


def _unwrap_shape(shape):
if shape.is_tuple():
components = tuple(
_unwrap_shape(subshape) for subshape in shape.tuple_shapes())
else:
components = shape.dimensions()
return (shape.np_dtype, components)


def _unwrap_shapes(shapes):
return [_unwrap_shape(shape) for shape in shapes]


def _wrap_data_handle(handle):
cdh = xla_data_pb2.ComputationDataHandle()
cdh.handle = handle
@ -260,6 +295,17 @@ def require_numpy_array_layout(value):
return np.require(value, requirements=['C', 'A'])


class CompileOptions(object):
"""Python object for XLA compile options.

These options can be passed to the 'compile' step when using a local XLA
client.
"""

def __init__(self):
self.generate_hlo_graph = None


def transfer_to_infeed(value, replica_number=None):
"""Transfers the given value into the XLA infeed queue.

@ -291,8 +337,7 @@ def transfer_from_outfeed(shape, replica_number=None):
Returns:
The literal value that is produced from the outfeed queue.
"""
return c_api.TransferFromOutfeedLocalReplica(
_unwrap_shape(shape), replica_number or 0)
return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0)


class LocalComputation(object):
@ -313,22 +358,39 @@ class LocalComputation(object):
else:
self._delete = c_api.DeleteLocalComputation

def Compile(self, argument_shapes=()):
def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
return LocalComputation(
self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)),
self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)

def CompileWithExampleArguments(self, arguments=()):
def CompileWithExampleArguments(self,
arguments=(),
compile_options=None,
layout_fn=None):
return self.Compile(
argument_shapes=[Shape.from_numpy(arg) for arg in arguments])
argument_shapes=[Shape.from_numpy(arg) for arg in arguments],
compile_options=compile_options,
layout_fn=layout_fn)

def Execute(self, arguments=()):
def Execute(self, arguments=(), layout_fn=None):
"""Execute with Python values as arguments and return value."""
if not self.is_compiled:
raise ValueError('Cannot execute an uncompiled local XLA computation.')
argument_shapes = [Shape.from_numpy(arg) for arg in arguments]
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
else:
argument_shapes = [None for shape in argument_shapes]
arguments = tuple(map(require_numpy_array_layout, arguments))
return self.c_local_computation.Execute(arguments)
return self.c_local_computation.Execute(arguments, argument_shapes)

def ExecuteWithLocalBuffers(self, arguments=()):
"""Execute with LocalBuffer arguments and return value."""
@ -384,7 +446,7 @@ class ComputationBuilder(object):
Returns:
A ComputationDataHandle message.
"""
return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape)))
return _wrap_data_handle(self._client.Infeed(shape))

def Outfeed(self, operand):
"""Enqueues an outfeed op onto the computation.
@ -393,7 +455,7 @@ class ComputationBuilder(object):
outfeed queue for subsequent dequeue via the client API.
"""
self._client.Outfeed(
_unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)),
_unwrap_data_handle(operand), self.GetShape(operand),
''.encode('utf-8'))

def Constant(self, value):
@ -484,8 +546,7 @@ class ComputationBuilder(object):
parameter_num = next(self._parameter_numbering)

return _wrap_data_handle(
self._client.Parameter(
parameter_num, _unwrap_shape(shape), name.encode('utf8')))
self._client.Parameter(parameter_num, shape, name.encode('utf8')))

def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
@ -606,6 +667,13 @@ class ComputationBuilder(object):
return _wrap_data_handle(
self._client.Rev(_unwrap_data_handle(operand), dimensions))

def Clamp(self, min, operand, max):  # pylint: disable=redefined-builtin
"""Clamp op."""
return _wrap_data_handle(
self._client.Clamp(_unwrap_data_handle(min),
_unwrap_data_handle(operand),
_unwrap_data_handle(max)))

def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter):
"""Select and scatter op, used by the gradient of ReduceWindow.
@ -825,8 +893,7 @@ class ComputationBuilder(object):
shape = Shape(self.GetShape(mu).np_dtype, dims)
return _wrap_data_handle(
self._client.RngNormal(
_unwrap_data_handle(mu), _unwrap_data_handle(sigma),
_unwrap_shape(shape)))
_unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))

def RngUniform(self, a, b, dims):
"""Enqueues an RngUniform operation onto the computation.
@ -846,8 +913,7 @@ class ComputationBuilder(object):
shape = Shape(self.GetShape(a).np_dtype, dims)
return _wrap_data_handle(
self._client.RngUniform(
_unwrap_data_handle(a), _unwrap_data_handle(b),
_unwrap_shape(shape)))
_unwrap_data_handle(a), _unwrap_data_handle(b), shape))

def While(self, cond, body, init):
"""Enqueues a While operation onto the computation.
@ -865,10 +931,37 @@ class ComputationBuilder(object):
_unwrap_data_handle(init)))

def Dot(self, lhs, rhs):
"""Matrix multiplication between lhs and rhs."""
"""Enqueues a dot operation onto the computation.

Args:
lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array.
rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array.

Returns: a ComputationDataHandle representing the Dot operation.
"""
return _wrap_data_handle(
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))

def DotGeneral(self, lhs, rhs, dimension_numbers):
"""Enqueues a general dot operation onto the computation.

Args:
lhs: ComputationDataHandle for the left-hand-side array.
rhs: ComputationDataHandle for the right-hand-side array.
dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested
tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
integers representing the dimensions to treat as contracting dimensions
and batch dimensions on each input operand.

Returns: a ComputationDataHandle representing the DotGeneral operation.
"""
if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return _wrap_data_handle(
self._client.DotGeneral(
_unwrap_data_handle(lhs), _unwrap_data_handle(rhs),
dimension_numbers))

def Conv(self, lhs, rhs, window_strides, padding):
"""Enqueues a Conv operation onto the computation.

@ -979,7 +1072,7 @@ def initialize_replica_count(replica_count):

Args:
replica_count: number of replicas that are desired for set up during XLA
initalization.
initialization.

Raises:
A runtime exception if the XLA service has already been initialized.
@ -1005,3 +1098,13 @@ def GetPaddingConfigFromTriples(triples):
dimension.edge_padding_high = hi
dimension.interior_padding = interior
return padding_config


def GetDotDimensionsFromLists(dimension_numbers):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
dot_dims_proto = xla_data_pb2.DotDimensionNumbers()
dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
return dot_dims_proto
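
GetDotDimensionsFromLists turns the nested-tuple form into a DotDimensionNumbers proto, so the two spellings below are equivalent for a batched matmul over [10, 3, 4] x [10, 4, 5] operands (mirroring the tests that follow):

```python
from tensorflow.compiler.xla.python import xla_client

# Nested-tuple form: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)).
nested = (([2], [1]), ([0], [0]))

# Equivalent proto form, as built by GetDotDimensionsFromLists(nested).
proto = xla_client.xla_data_pb2.DotDimensionNumbers()
proto.lhs_contracting_dimensions.extend([2])
proto.rhs_contracting_dimensions.extend([1])
proto.lhs_batch_dimensions.extend([0])
proto.rhs_batch_dimensions.extend([0])
```
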
@ -444,6 +444,30 @@ class SingleOpTest(LocalComputationTest):
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))

def testDotGeneral(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))
dimension_numbers = (([2], [1]), ([0], [0]))
c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))

def testDotGeneralWithDotDimensionNumbersProto(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))

dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers()
dimension_numbers.lhs_contracting_dimensions.append(2)
dimension_numbers.rhs_contracting_dimensions.append(1)
dimension_numbers.lhs_batch_dimensions.append(0)
dimension_numbers.rhs_batch_dimensions.append(0)

c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))

def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
@ -496,6 +520,12 @@ class SingleOpTest(LocalComputationTest):
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))

def testRound(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Round(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.round(arr))

def testLog(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@ -699,6 +729,23 @@ class SingleOpTest(LocalComputationTest):
self._ExecuteAndCompareExact(
c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])

def testClampF32(self):
c = self._NewComputation()
c.Clamp(
c.Constant(NumpyArrayF32(-1)),
c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
c.Constant(NumpyArrayF32(2)))
self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])

# TODO(b/72689392): re-enable when the S32 bug is resolved
def DISABLED_testClampS32(self):
c = self._NewComputation()
c.Clamp(
c.Constant(NumpyArrayS32(-1)),
c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
c.Constant(NumpyArrayS32(2)))
self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])

def testSelect(self):
c = self._NewComputation()
c.Select(
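
The same layout hooks apply at buffer-creation time: from_py (changed earlier in this diff) takes a layout_fn, so device buffers can be created with an explicit layout. A hedged sketch, assuming the in-tree import path:

```python
import numpy as np

from tensorflow.compiler.xla.python import xla_client


def column_major(shape):
  # Force minor_to_major = (0, 1, ...) on leaf shapes; None on tuples
  # keeps the identity, per Shape.map_leaves.
  return (None if shape.is_tuple()
          else shape.update_minor_to_major(tuple(range(shape.rank()))))

arr = np.arange(6, dtype=np.float32).reshape((2, 3))
buf = xla_client.LocalBuffer.from_py(arr, layout_fn=column_major)
# buf.to_py() round-trips the value regardless of the on-device layout.
```
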
@ -509,6 +509,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
@ -1110,8 +1111,6 @@ cc_library(
":hlo",
":hlo_evaluator",
":hlo_pass",
":tuple_util",
":while_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
],
@ -1156,6 +1155,34 @@ tf_cc_test(
],
)

cc_library(
name = "implicit_broadcast_remover",
srcs = ["implicit_broadcast_remover.cc"],
hdrs = ["implicit_broadcast_remover.h"],
deps = [
":hlo",
":hlo_dce",
":hlo_pass",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)

tf_cc_test(
name = "implicit_broadcast_remover_test",
srcs = ["implicit_broadcast_remover_test.cc"],
deps = [
":hlo_matchers",
":implicit_broadcast_remover",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
],
)

cc_library(
name = "dot_decomposer",
srcs = ["dot_decomposer.cc"],
@ -1825,7 +1852,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)

@ -1856,6 +1885,7 @@ cc_library(
":hlo",
":hlo_graph_dumper",
":hlo_pass",
":hlo_proto_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@ -1618,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
}

// A Transpose feeding a reduce can simply permute the reduction dimensions
// field.
if (arg->opcode() == HloOpcode::kTranspose) {
// field if the output of the reduce is a vector or scalar. Higher ranked
// result may require a transpose of the output.
if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
arg->opcode() == HloOpcode::kTranspose) {
auto transpose_dimensions = arg->dimensions();
std::vector<int64> new_reduce_dimensions;
for (auto dim : dimensions) {
@ -997,14 +997,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
auto color = single_colored_set.first;
VLOG(2) << "Simulating heap for color " << color;
int64 alignment = assignment->color_alignment_(color);
HeapSimulator::Options options;
options.buffers_to_assign = &single_colored_set.second;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
MakeUnique<LazyBestFitHeap>(alignment)),
assignment->module(), module_sequence,
assignment->points_to_analysis(),
assignment->buffer_size_,
&single_colored_set.second));
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@ -1024,14 +1025,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
auto color = single_colored_set.first;
VLOG(2) << "Simulating heap for color " << color;
int64 alignment = assignment->color_alignment_(color);
HeapSimulator::Options options;
options.buffers_to_assign = &single_colored_set.second;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
MakeUnique<LazyBestFitHeap>(alignment)),
*computation, *instruction_sequence,
assignment->points_to_analysis(),
assignment->buffer_size_,
&single_colored_set.second));
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@ -72,8 +72,18 @@ class AotCompilationOptions {
// Returns the ID of the platform to which these options apply.
virtual perftools::gputools::Platform::Id PlatformId() const = 0;

// Optional allocator that may be used for allocating temp space on the device
// during compilation.
DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
device_allocator_ = device_allocator;
}

protected:
AotCompilationOptions() = default;

private:
DeviceMemoryAllocator* device_allocator_ = nullptr;
};

// Abstract compiler interface that is subclassed for compilation on a
@ -99,9 +109,16 @@ class Compiler {

// Runs Hlo passes to optimize the given Hlo module, returns the optimized
// module.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* executor) = 0;
perftools::gputools::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator) = 0;

// Compiles the HLO module for execution on a device given by the executor,
// and returns an executable object or an error status. No HLO passes are
@ -112,21 +129,27 @@ class Compiler {
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// device_allocator is optional; see RunHloPasses.
//
// Use the overload below to compile computations that run in parallel.
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* executor) = 0;
perftools::gputools::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator) = 0;

// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding
// sequence of executable objects.
//
// device_allocator is optional; see RunHloPasses.
//
// TODO(b/68666782): Remove this method after adding support for multiple
// modules to RunHloPasses and RunBackends.
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::vector<std::unique_ptr<HloModule>> modules,
std::vector<std::vector<perftools::gputools::StreamExecutor*>>
stream_exec) = 0;
stream_exec,
DeviceMemoryAllocator* device_allocator) = 0;

// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.
|
@ -437,7 +437,8 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) {

StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module,
    perftools::gputools::StreamExecutor* /*stream_exec*/) {
    perftools::gputools::StreamExecutor* /*stream_exec*/,
    DeviceMemoryAllocator* /*device_allocator*/) {
  VLOG(2) << "Before optimization:";
  XLA_VLOG_LINES(2, module->ToString());

@ -450,7 +451,8 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module,
    perftools::gputools::StreamExecutor* stream_exec) {
    perftools::gputools::StreamExecutor* stream_exec,
    DeviceMemoryAllocator* /*device_allocator*/) {
  const string timer_message =
      "Compiling [" + module->name() + "] for CPU using JIT";
  XLA_SCOPED_LOGGING_TIMER(timer_message);

@ -517,8 +519,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
  // ownership is std::moved.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();
  const string xla_dump_hlo_proto_to =
      module->config().debug_options().xla_dump_hlo_proto_to();
  const string xla_dump_optimized_hlo_proto_to =
      module->config().debug_options().xla_dump_optimized_hlo_proto_to();

  if (options::CpuParallelBackendRequested(module->config())) {
    VLOG(1) << "Using parallel cpu backend";

@ -538,10 +540,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    // print one ourselves.
    XLA_VLOG_LINES(2, assignment->ToString());

    if (!xla_dump_hlo_proto_to.empty()) {
    if (!xla_dump_optimized_hlo_proto_to.empty()) {
      HloProto proto = MakeHloProto(*module, *assignment);
      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
          proto, xla_dump_hlo_proto_to, module->name()));
          proto, xla_dump_optimized_hlo_proto_to, module->name()));
    }

    // If we are using the parallel CPU backend, we need to create map from

@ -647,10 +649,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    // print one ourselves.
    XLA_VLOG_LINES(2, assignment->ToString());

    if (!xla_dump_hlo_proto_to.empty()) {
    if (!xla_dump_optimized_hlo_proto_to.empty()) {
      HloProto proto = MakeHloProto(*module, *assignment);
      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
          proto, xla_dump_hlo_proto_to, module->name()));
          proto, xla_dump_optimized_hlo_proto_to, module->name()));
    }

    // Each computation is a single function. Emit all embedded computations

@ -826,12 +828,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
    // print one ourselves.
    XLA_VLOG_LINES(2, assignment->ToString());

    const string xla_dump_hlo_proto_to =
        module->config().debug_options().xla_dump_hlo_proto_to();
    if (!xla_dump_hlo_proto_to.empty()) {
    const string xla_dump_optimized_hlo_proto_to =
        module->config().debug_options().xla_dump_optimized_hlo_proto_to();
    if (!xla_dump_optimized_hlo_proto_to.empty()) {
      HloProto proto = MakeHloProto(*module, *assignment);
      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
          proto, xla_dump_hlo_proto_to, module->name()));
          proto, xla_dump_optimized_hlo_proto_to, module->name()));
    }

    IrEmitter ir_emitter(*module, *assignment, &llvm_module,
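Given the rename from xla_dump_hlo_proto_to to xla_dump_optimized_hlo_proto_to, requesting a dump of the post-optimization HLO proto now goes through the new field. A minimal sketch, assuming the standard proto-generated setter on DebugOptions:

// The setter name follows from the proto field; the getter is the one read
// back in RunBackend/CompileAheadOfTime above.
xla::DebugOptions debug_options;
debug_options.set_xla_dump_optimized_hlo_proto_to("/tmp/hlo_dumps");
// Later, inside compilation:
//   module->config().debug_options().xla_dump_optimized_hlo_proto_to()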
@ -118,11 +118,13 @@ class CpuCompiler : public LLVMCompiler {

  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* stream_exec) override;
      perftools::gputools::StreamExecutor* stream_exec,
      DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module,
      perftools::gputools::StreamExecutor* stream_exec) override;
      perftools::gputools::StreamExecutor* stream_exec,
      DeviceMemoryAllocator* device_allocator) override;

  StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
@ -479,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {

Status IrEmitter::HandleSort(HloInstruction* sort) {
  // TODO(b/26783907): Implement sort on CPU.
  return Unimplemented("Sort is not supported on CPU (b/26783907).");
  return Unimplemented("Sort is not implemented on CPU.");
}

Status IrEmitter::HandleTuple(HloInstruction* tuple) {

@ -522,7 +522,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
  // TODO(b/31410564): Implement dilation for reduce-window.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for reduce-window not implemented on CPU. See b/31410564.");
        "Dilation for ReduceWindow is not implemented on CPU.");
  }

  // The called computation should have been emitted previously.

@ -625,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
  // TODO(b/31410564): Implement dilation for select-and-scatter.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for select-and-scatter not implemented on CPU. "
        "See b/31410564.");
        "Dilation for SelectAndScatter is not implemented on CPU. ");
  }

  // The select and scatter computations should have been emitted previously.

@ -1196,8 +1195,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
  }

  // TODO(b/33011107): Support cross replica sum on CPU.
  return Unimplemented(
      "Cross replica sum is not implemented on CPU. See b/33011107.");
  return Unimplemented("CrossReplicaSum is not implemented on CPU.");
}

// Fills up the free variables in 'index_with_free_var' with values from

@ -1811,12 +1809,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {

Status IrEmitter::HandleSend(HloInstruction* send) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Send is not implemented on CPU. See b/33942983.");
  return Unimplemented("Send is not implemented on CPU.");
}

Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
  return Unimplemented("Send-done is not implemented on CPU.");
}

Status IrEmitter::HandleSlice(HloInstruction* slice) {

@ -1981,12 +1979,12 @@ Status IrEmitter::HandleDynamicUpdateSlice(

Status IrEmitter::HandleRecv(HloInstruction* recv) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
  return Unimplemented("Recv is not implemented on CPU.");
}

Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
  // TODO(b/33942983): Support Send/Recv on CPU.
  return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
  return Unimplemented("Recv-done is not implemented on CPU.");
}

Status IrEmitter::HandlePad(HloInstruction* pad) {

@ -1995,10 +1993,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
  for (auto& padding_dimension : pad->padding_config().dimensions()) {
    if (padding_dimension.edge_padding_low() < 0 ||
        padding_dimension.edge_padding_high() < 0) {
      return Unimplemented(
          "Negative padding not supported in the CPU backend (b/34628603); "
          "this should have been eliminated at the HLO level: %s",
          pad->padding_config().ShortDebugString().c_str());
      return InternalErrorStrCat(
          "Encountered negative padding in IrEmitter on CPU. "
          "This should have been eliminated at the HLO level. ",
          pad->ToString());
    }
  }
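The HandlePad change above reclassifies negative padding: it is no longer reported as a missing feature but as an internal invariant violation, since an earlier HLO pass is expected to rewrite it away before the IR emitter runs. A hedged sketch of that convention (CheckPadding is a hypothetical helper; the error constructors are the ones used in the diff):

Status CheckPadding(const HloInstruction& pad) {
  for (const auto& dim : pad.padding_config().dimensions()) {
    if (dim.edge_padding_low() < 0 || dim.edge_padding_high() < 0) {
      // Reaching this point is a compiler bug, not an unsupported input,
      // so report an internal error rather than Unimplemented.
      return InternalErrorStrCat("Encountered negative padding: ",
                                 pad.ToString());
    }
  }
  return Status::OK();
}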
@ -24,7 +24,7 @@ limitations under the License.
namespace xla {

StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
    perftools::gputools::Platform* platform,
    const perftools::gputools::Platform* platform,
    tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
        stream_executors)
    : DeviceMemoryAllocator(platform),
@ -33,7 +33,7 @@ class DeviceMemoryAllocator {
 public:
  // Parameter platform indicates which platform the allocator allocates memory
  // on. Must be non-null.
  explicit DeviceMemoryAllocator(perftools::gputools::Platform* platform)
  explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform)
      : platform_(platform) {}
  virtual ~DeviceMemoryAllocator() {}

@ -49,14 +49,14 @@ class DeviceMemoryAllocator {
      int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0;

  // Return the platform that the allocator allocates memory on.
  perftools::gputools::Platform* platform() const { return platform_; }
  const perftools::gputools::Platform* platform() const { return platform_; }

  // Can we call Deallocate() as soon as a computation has been scheduled on
  // a stream, or do we have to wait for the computation to complete first?
  virtual bool AllowsAsynchronousDeallocation() const = 0;

 protected:
  perftools::gputools::Platform* platform_;
  const perftools::gputools::Platform* platform_;
};

// Default memory allocator for a platform which uses

@ -64,7 +64,7 @@ class DeviceMemoryAllocator {
class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
 public:
  StreamExecutorMemoryAllocator(
      perftools::gputools::Platform* platform,
      const perftools::gputools::Platform* platform,
      tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
          stream_executors);
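With the const-qualification above, the platform can only be inspected, never mutated, through the allocator interface. A small illustration (DescribeAllocator is a hypothetical helper; Platform::Name() is the stock StreamExecutor accessor):

// Callers that only inspect the platform can now take the allocator by
// const reference; mutating the platform through it no longer compiles.
std::string DescribeAllocator(const xla::DeviceMemoryAllocator& allocator) {
  const perftools::gputools::Platform* platform = allocator.platform();
  return platform->Name();
}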
@ -428,7 +428,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
          llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kSign: {
      // TODO(b/32151903): Ensure consistent sign behavior for -0.0
      // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
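Context for the TODO in this hunk: under IEEE 754, -0.0 compares equal to +0.0, so an equality-based sign maps both zeros to 0. A scalar sketch of the compare/select chain the emitter builds (NaN handling left aside):

float SignOf(float x) {
  if (x == 0.0f) return 0.0f;  // FCmpOEQ: true for both +0.0f and -0.0f.
  return x < 0.0f ? -1.0f : 1.0f;
}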
@ -870,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
                                                      llvm::Value* x) const {
  if (prim_type != F32) {
    return Unimplemented("inverse erf only implemented for F32 (b/34339814)");
    // TODO(b/34339814): Implement inverse erf for F64.
    return Unimplemented(
        "Inverse erf is only implemented for element "
        "type F32.");
  }
  auto getFloat = [&](const float f) {
    return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
@ -1040,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
          is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kMinimum:
      return ir_builder_->CreateSelect(
          ir_builder_->CreateICmp(
              is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
              lhs_value, rhs_value),
          lhs_value, rhs_value);
      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
    case HloOpcode::kMaximum:
      return ir_builder_->CreateSelect(
          ir_builder_->CreateICmp(
              is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
              lhs_value, rhs_value),
          lhs_value, rhs_value);
      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
    case HloOpcode::kAnd:
      return ir_builder_->CreateAnd(lhs_value, rhs_value);
    case HloOpcode::kOr:
@ -1067,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
  }
}

llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) const {
  return ir_builder_->CreateSelect(
      ir_builder_->CreateICmp(
          is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
          lhs_value, rhs_value),
      lhs_value, rhs_value);
}

llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) const {
  return ir_builder_->CreateSelect(
      ir_builder_->CreateICmp(
          is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
          lhs_value, rhs_value),
      lhs_value, rhs_value);
}

llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
    const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
    int64 operand_no) const {
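The two helpers factored out above are each a single compare-and-select; only the comparison predicate changes with signedness. A scalar analogue for illustration, with the corresponding LLVM predicate noted in comments:

#include <cstdint>

int64_t IntegralMax(int64_t lhs, int64_t rhs, bool is_signed) {
  bool ge = is_signed ? (lhs >= rhs)  // ICMP_SGE
                      : (static_cast<uint64_t>(lhs) >=
                         static_cast<uint64_t>(rhs));  // ICMP_UGE
  return ge ? lhs : rhs;
}

int64_t IntegralMin(int64_t lhs, int64_t rhs, bool is_signed) {
  bool le = is_signed ? (lhs <= rhs)  // ICMP_SLE
                      : (static_cast<uint64_t>(lhs) <=
                         static_cast<uint64_t>(rhs));  // ICMP_ULE
  return le ? lhs : rhs;
}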
@ -1363,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
        TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
                            operand_to_generator.at(hlo->operand(2))(
                                ElementwiseSourceIndex(index, *hlo, 2)));
        return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
        PrimitiveType prim_type = hlo->shape().element_type();
        if (primitive_util::IsFloatingPointType(prim_type)) {
          return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
        } else if (primitive_util::IsIntegralType(prim_type)) {
          bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
          return EmitIntegralMin(
              max_value, EmitIntegralMax(min_value, arg_value, is_signed),
              is_signed);
        } else {
          return Unimplemented("Clamp unimplemented for %s",
                               PrimitiveType_Name(prim_type).c_str());
        }
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
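The integral clamp path added above is just min/max composition, threading the same signedness flag through both helpers. A signed scalar sketch (the unsigned path is identical but with unsigned comparisons):

#include <cstdint>

int64_t SignedClamp(int64_t min_v, int64_t x, int64_t max_v) {
  int64_t raised = x >= min_v ? x : min_v;  // EmitIntegralMax(min, arg)
  return raised <= max_v ? raised : max_v;  // EmitIntegralMin(max, raised)
}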
@ -86,6 +86,12 @@ class ElementalIrEmitter {
  virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
                                    llvm::Value* rhs_value) const;

  llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
                               bool is_signed) const;

  llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
                               bool is_signed) const;

  virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
                                            llvm::Value* value) const;
@ -131,6 +131,7 @@ cc_library(
        "ir_emitter_context.h",
    ],
    deps = [
        ":cudnn_convolution_runner",
        ":elemental_ir_emitter",
        ":gpu_constants",
        ":gpu_executable",

@ -262,6 +263,7 @@ cc_library(
    ],
    deps = [
        ":buffer_allocations",
        ":cudnn_convolution_runner",
        ":infeed_manager",
        ":ir_emission_utils",
        ":partition_assignment",
@ -309,9 +311,41 @@ cc_library(
)

cc_library(
    name = "convolution_folding",
    srcs = ["convolution_folding.cc"],
    hdrs = ["convolution_folding.h"],
    name = "cudnn_convolution_algorithm_picker",
    srcs = ["cudnn_convolution_algorithm_picker.cc"],
    hdrs = ["cudnn_convolution_algorithm_picker.h"],
    deps = [
        ":cudnn_convolution_runner",
        ":gpu_executable",
        ":ir_emission_utils",
        "//tensorflow/compiler/xla/service:device_memory_allocator",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
)

cc_library(
    name = "cudnn_convolution_runner",
    srcs = ["cudnn_convolution_runner.cc"],
    hdrs = ["cudnn_convolution_runner.h"],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:stream_executor_no_cuda",
    ],
)

cc_library(
    name = "cudnn_convolution_rewriter",
    srcs = ["cudnn_convolution_rewriter.cc"],
    hdrs = ["cudnn_convolution_rewriter.h"],
    deps = [
        ":ir_emission_utils",
        "//tensorflow/compiler/xla:literal_util",
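For orientation, a hedged sketch of how the renamed and new targets compose at the pass-pipeline level. HloPassPipeline and AddPass are the standard XLA pass-manager API, but the algorithm picker's constructor arguments shown here are an assumption, not taken from this diff:

// Assumes stream_exec and device_allocator are in scope, as in the
// GpuCompiler changes that depend on these targets.
HloPassPipeline pipeline("conv-canonicalization");
// Rewrites convolutions into explicit cuDNN custom calls.
pipeline.AddPass<CudnnConvolutionRewriter>();
// Auto-tunes each cuDNN convolution call; the allocator supplies scratch
// space for the timing trials (constructor arguments are assumed).
pipeline.AddPass<CudnnConvolutionAlgorithmPicker>(stream_exec,
                                                  device_allocator);
TF_RETURN_IF_ERROR(pipeline.Run(module).status());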
@ -325,15 +359,18 @@ cc_library(
)

tf_cc_test(
    name = "convolution_folding_test",
    srcs = ["convolution_folding_test.cc"],
    name = "cudnn_convolution_rewriter_test",
    srcs = ["cudnn_convolution_rewriter_test.cc"],
    deps = [
        ":convolution_folding",
        ":cudnn_convolution_rewriter",
        ":ir_emission_utils",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_matchers",
        "//tensorflow/compiler/xla/service:shape_inference",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//tensorflow/core:test",
    ],
)
@ -446,7 +483,8 @@ cc_library(
    srcs = ["gpu_compiler.cc"],
    hdrs = ["gpu_compiler.h"],
    deps = [
        ":convolution_folding",
        ":cudnn_convolution_algorithm_picker",
        ":cudnn_convolution_rewriter",
        ":fusion_merger",
        ":gpu_constants",
        ":gpu_copy_insertion",

@ -514,7 +552,6 @@ cc_library(
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_pass",
        "@llvm//:core",
    ],
)
Some files were not shown because too many files have changed in this diff.