Merge branch 'master' of https://github.com/tensorflow/tensorflow into tensorrt to fix some of the failed tests.

This commit is contained in:
gracehoney 2018-02-05 19:44:53 -08:00
commit 149fc8dbd6
1402 changed files with 38564 additions and 17667 deletions

View File

@ -4,7 +4,7 @@ https://stackoverflow.com/questions/tagged/tensorflow
If you open a GitHub issue, here is our policy:
1. It must be a bug or a feature request.
1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead).
2. The form below must be filled out.
3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues).

View File

@ -6,7 +6,7 @@
| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** |
|-----------------|---------------------|------------------|-------------------|---------------|
| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) |
| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) |
**TensorFlow** is an open source software library for numerical computation using
data flow graphs. The graph nodes represent mathematical operations, while

View File

@ -1,18 +1,39 @@
# Release 1.5.0
## Breaking Changes
* Prebuilt binaries are now built against CUDA 9 and cuDNN 7.
* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7.
* Our Linux binaries are built using ubuntu 16 containers, potentially
introducing glibc incompatibility issues with ubuntu 14.
* Starting from 1.6 release, our prebuilt binaries will use AVX instructions.
This may break TF on older CPUs.
## Known Bugs
* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or
`CUDA_ILLEGAL_ADDRESS` failures.
Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9
and CUDA 9.1 sometimes does not properly compute the carry bit when
decomposing 64-bit address calculations with large offsets (e.g. `load [x +
large_constant]`) into 32-bit arithmetic in SASS.
As a result, these versions of `ptxas` miscompile most XLA programs which use
more than 4GB of temp memory. This results in garbage results and/or
`CUDA_ERROR_ILLEGAL_ADDRESS` failures.
A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a
fix for CUDA 9.0.x. Until the fix is available, the only workaround is to
[downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x
or disable XLA:GPU.
TensorFlow will print a warning if you use XLA:GPU with a known-bad version of
CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122.
## Major Features And Improvements
* [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager)
preview version is now available.
* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite)
dev preview is now available.
* CUDA 9 and cuDNN 7 support.
* CUDA 9.0 and cuDNN 7 support.
* Accelerated Linear Algebra (XLA):
* Add `complex64` support to XLA compiler.
* `bfloat` support is now added to XLA infrastructure.

View File

@ -298,7 +298,7 @@ def get_var(environ_cp,
System".
enabled_by_default: boolean for default behavior.
question: optional string for how to ask for user input.
yes_reply: optionanl string for reply when feature is enabled.
yes_reply: optional string for reply when feature is enabled.
no_reply: optional string for reply when feature is disabled.
Returns:
@ -411,7 +411,7 @@ def set_action_env_var(environ_cp,
System".
enabled_by_default: boolean for default behavior.
question: optional string for how to ask for user input.
yes_reply: optionanl string for reply when feature is enabled.
yes_reply: optional string for reply when feature is enabled.
no_reply: optional string for reply when feature is disabled.
"""
var = int(
@ -1354,6 +1354,7 @@ def main():
environ_cp['TF_NEED_GCP'] = '0'
environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_KAFKA'] = '0'
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
@ -1372,6 +1373,8 @@ def main():
'with_hdfs_support', True, 'hdfs')
set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
'with_s3_support', True, 's3')
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
'with_kafka_support', False, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
False, 'xla')
set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support',

View File

@ -211,6 +211,12 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "with_kafka_support",
define_values = {"with_kafka_support": "true"},
visibility = ["//visibility:public"],
)
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(
@ -544,8 +550,10 @@ filegroup(
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/py2tf:all_files",
"//tensorflow/contrib/py2tf/converters:all_files",
"//tensorflow/contrib/py2tf/impl:all_files",
"//tensorflow/contrib/py2tf/pyct:all_files",
"//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
"//tensorflow/contrib/py2tf/utils:all_files",
"//tensorflow/contrib/quantize:all_files",
"//tensorflow/contrib/receptive_field:all_files",
"//tensorflow/contrib/reduce_slice_ops:all_files",

239
tensorflow/SECURITY.md Normal file
View File

@ -0,0 +1,239 @@
# Using TensorFlow Securely
This document discusses how to safely deal with untrusted programs (models or
model parameters), and input data. Below, we also provide guidelines on how to
report vulnerabilities in TensorFlow.
## TensorFlow models are programs
TensorFlow's runtime system interprets and executes programs. What machine
learning practitioners term
[**models**](https://developers.google.com/machine-learning/glossary/#model) are
expressed as programs that TensorFlow executes. TensorFlow programs are encoded
as computation
[**graphs**](https://developers.google.com/machine-learning/glossary/#graph).
The model's parameters are often stored separately in **checkpoints**.
At runtime, TensorFlow executes the computation graph using the parameters
provided. Note that the behavior of the computation graph may change
depending on the parameters provided. TensorFlow itself is not a sandbox. When
executing the computation graph, TensorFlow may read and write files, send and
receive data over the network, and even spawn additional processes. All these
tasks are performed with the permissions of the TensorFlow process. Allowing
for this flexibility makes for a powerful machine learning platform,
but it has implications for security.
The computation graph may also accept **inputs**. Those inputs are the
data you supply to TensorFlow to train a model, or to use a model to run
inference on the data.
**TensorFlow models are programs, and need to be treated as such from a security
perspective.**
## Running untrusted models
As a general rule: **Always** execute untrusted models inside a sandbox (e.g.,
[nsjail](https://github.com/google/nsjail)).
There are several ways in which a model could become untrusted. Obviously, if an
untrusted party supplies TensorFlow kernels, arbitrary code may be executed.
The same is true if the untrusted party provides Python code, such as the
Python code that generates TensorFlow graphs.
Even if the untrusted party only supplies the serialized computation
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
set of computation primitives available to TensorFlow is powerful enough that
you should assume that the TensorFlow process effectively executes arbitrary
code. One common solution is to whitelist only a few safe Ops. While this is
possible in theory, we still recommend you sandbox the execution.
It depends on the computation graph whether a user provided checkpoint is safe.
It is easily possible to create computation graphs in which malicious
checkpoints can trigger unsafe behavior. For example, consider a graph that
contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of
the `tf.cond` is harmless, but the other is unsafe. Since the `tf.Variable` is
stored in the checkpoint, whoever provides the checkpoint now has the ability to
trigger unsafe behavior, even though the graph is not under their control.
In other words, graphs can contain vulnerabilities of their own. To allow users
to provide checkpoints to a model you run on their behalf (e.g., in order to
compare model quality for a fixed model architecture), you must carefully audit
your model, and we recommend you run the TensorFlow process in a sandbox.
## Accepting untrusted Inputs
It is possible to write models that are secure in a sense that they can safely
process untrusted inputs assuming there are no bugs. There are two main reasons
to not rely on this: first, it is easy to write models which must not be exposed
to untrusted inputs, and second, there are bugs in any software system of
sufficient complexity. Letting users control inputs could allow them to trigger
bugs either in TensorFlow or in dependent libraries.
In general, it is good practice to isolate parts of any system which is exposed
to untrusted (e.g., user-provided) inputs in a sandbox.
A useful analogy to how any TensorFlow graph is executed is any interpreted
programming language, such as Python. While it is possible to write secure
Python code which can be exposed to user supplied inputs (by, e.g., carefully
quoting and sanitizing input strings, size-checking input blobs, etc.), it is
very easy to write Python programs which are insecure. Even secure Python code
could be rendered insecure by a bug in the Python interpreter, or in a bug in a
Python library used (e.g.,
[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)).
## Running a TensorFlow server
TensorFlow is a platform for distributed computing, and as such there is a
TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for
internal communication only. It is not built for use in an untrusted network.**
For performance reasons, the default TensorFlow server does not include any
authorization protocol and sends messages unencrypted. It accepts connections
from anywhere, and executes the graphs it is sent without performing any checks.
Therefore, if you run a `tf.train.Server` in your network, anybody with
access to the network can execute what you should consider arbitrary code with
the privileges of the process running the `tf.train.Server`.
When running distributed TensorFlow, you must isolate the network in which the
cluster lives. Cloud providers provide instructions for setting up isolated
networks, which are sometimes branded as "virtual private cloud." Refer to the
instructions for
[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and
[AWS](https://aws.amazon.com/vpc/)) for details.
Note that `tf.train.Server` is different from the server created by
`tensorflow/serving` (the default binary for which is called `ModelServer`).
By default, `ModelServer` also has no built-in mechanism for authentication.
Connecting it to an untrusted network allows anyone on this network to run the
graphs known to the `ModelServer`. This means that an attacker may run
graphs using untrusted inputs as described above, but they would not be able to
execute arbitrary graphs. It is possible to safely expose a `ModelServer`
directly to an untrusted network, **but only if the graphs it is configured to
use have been carefully audited to be safe**.
Similar to best practices for other servers, we recommend running any
`ModelServer` with appropriate privileges (i.e., using a separate user with
reduced permisisons). In the spirit of defense in depth, we recommend
authenticating requests to any TensorFlow server connected to an untrusted
network, as well as sandboxing the server to minimize the adverse effects of
any breach.
## Vulnerabilities in TensorFlow
TensorFlow is a large and complex system. It also depends on a large set of
third party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`).
It is possible that TensorFlow or its dependent libraries contain
vulnerabilities that would allow triggering unexpected or dangerous behavior
with specially crafted inputs.
### What is a vulnerability?
Given TensorFlow's flexibility, it is possible to specify computation graphs
which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models
can perform arbitrary computations means that they may read and write files,
communicate via the network, produce deadlocks and infinite loops, or run out
of memory. It is only when these behaviors are outside the specifications of the
operations involved that such behavior is a vulnerability.
A `FileWriter` writing a file is not unexpected behavior and therefore is not a
vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution
**is** a vulnerability.
This is more subtle from a system perspective. For example, it is easy to cause
a TensorFlow process to try to allocate more memory than available by specifying
a computation graph containing an ill-considered `tf.tile` operation. TensorFlow
should exit cleanly in this case (it would raise an exception in Python, or
return an error `Status` in C++). However, if the surrounding system is not
expecting the possibility, such behavior could be used in a denial of service
attack (or worse). Because TensorFlow behaves correctly, this is not a
vulnerability in TensorFlow (although it would be a vulnerability of this
hypothetical system).
As a general rule, it is incorrect behavior for Tensorflow to access memory it
does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to
such behaviors constitute a vulnerability.
One of the most critical parts of any system is input handling. If malicious
input can trigger side effects or incorrect behavior, this is a bug, and likely
a vulnerability.
### Reporting vulnerabilities
Please email reports about any security related issues you find to
`security@tensorflow.org`. This mail is delivered to a small security team. Your
email will be acknowledged within one business day, and you'll receive a more
detailed response to your email within 7 days indicating the next steps in
handling your report. For critical problems, you may encrypt your report (see
below).
Please use a descriptive subject line for your report email. After the initial
reply to your report, the security team will endeavor to keep you informed of
the progress being made towards a fix and announcement.
If you believe that an existing (public) issue is security-related, please send
an email to `security@tensorflow.org`. The email should include the issue ID and
a short description of why it should be handled according to this security
policy.
Once an issue is reported, TensorFlow uses the following disclosure process:
* When a report is received, we confirm the issue and determine its severity.
* If we know of specific third-party services or software based on TensorFlow
that require mitigation before publication, those projects will be notified.
* An advisory is prepared (but not published) which details the problem and
steps for mitigation.
* Wherever possible, fixes are prepared for the last minor release of the two
latest major releases, as well as the master branch. We will attempt to
commit these fixes as soon as possible, and as close together as
possible.
* Patch releases are published for all fixed released versions, a
notification is sent to discuss@tensorflow.org, and the advisory is published.
Past security advisories are listed below. We credit reporters for identifying
security issues, although we keep your name confidential if you request it.
#### Encryption key for `security@tensorflow.org`
If your disclosure is extremely sensitive, you may choose to encrypt your
report using the key below. Please only use this for critical security
reports.
```
-----BEGIN PGP PUBLIC KEY BLOCK-----
mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI
LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP
aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs
SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e
OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY
e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj
dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm
gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA
AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC
o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a
T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum
HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf
3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw
2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR
N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe
zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p
osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b
nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+
OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR
BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y
BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH
J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP
A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy
7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt
Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81
v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
=CDME
-----END PGP PUBLIC KEY BLOCK-----
```
### Known vulnerabilities
| Type | Versions affected | Reported by | Additional Information |
|------|:-----------------:|---------------------------------------|
| out of bounds read| <=1.4 | @zhangbo5891001 | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) |

View File

@ -195,10 +195,10 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims,
reinterpret_cast<intptr_t>(data) % EIGEN_MAX_ALIGN_BYTES != 0) {
// TF_STRING and TF_RESOURCE tensors have a different representation in
// TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste
// (any alignement requirements will be taken care of by TF_TensorToTensor
// (any alignment requirements will be taken care of by TF_TensorToTensor
// and TF_TensorFromTensor).
//
// Other types have the same represntation, so copy only if it is safe to do
// Other types have the same representation, so copy only if it is safe to do
// so.
buf->data_ = allocate_tensor("TF_NewTensor", len);
std::memcpy(buf->data_, data, len);
@ -2144,7 +2144,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph,
opts.return_tensors.push_back(ToTensorId(nodes_to_return[i]));
}
// TOOD(skyewm): change to OutputTensor
// TODO(skyewm): change to OutputTensor
tensorflow::ImportGraphDefResults results;
TF_RETURN_IF_ERROR(
ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results));

View File

@ -46,6 +46,7 @@ tf_cuda_library(
"//tensorflow/c:c_api",
"//tensorflow/c:c_api_internal",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib_internal",

View File

@ -85,15 +85,7 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
return nullptr;
}
TFE_Context* ret = new TFE_Context(session);
ret->policy = opts->policy;
ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
ret->session->device_mgr, opts->session_options.options.env,
TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
ret->rendezvous =
new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
return ret;
return new TFE_Context(*opts, session);
}
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
@ -261,15 +253,6 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
void TFE_DeleteOp(TFE_Op* op) { delete op; }
static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device,
TF_Status* status) {
// Questionable heuristic: Place the op on the same device as the first input
// placed outside of host memory?
if (IsCPU(op->device) && !IsCPU(device)) {
op->device = device;
}
}
void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
tensorflow::Device* d = nullptr;
if (device_name != nullptr && strlen(device_name) > 0) {
@ -277,11 +260,24 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
op->ctx->session->device_mgr->LookupDevice(device_name, &d);
if (!status->status.ok()) return;
}
TFE_OpSetDeviceHelper(op, d, status);
op->device = d;
}
const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
tensorflow::Device* device =
(op->device == nullptr) ? op->ctx->devices()[0] : op->device;
return device->name().c_str();
}
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
TFE_OpSetDeviceHelper(op, h->d, status);
// Questionable heuristic ...
//
// Motivation: After an 'op' is placed on GPU because some of its earlier
// inputs are on GPU, we want to keep the 'op' there, even if some later
// inputs of it are not on GPU.
if (IsCPU(op->device) && !IsCPU(h->d)) {
op->device = h->d;
}
if (!status->status.ok()) return;
op->inputs.push_back(h->t);
op->input_devices.push_back(h->d);
@ -298,7 +294,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
return TF_ATTR_INT; // The compiler requires that we return something.
}
status->status =
tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list);
tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
return ret;
}

View File

@ -154,6 +154,9 @@ TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op);
TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name,
TF_Status* status);
// The returned string remains valid throughout the lifetime of 'op'.
TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op,
TF_Status* status);
TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status);

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
struct TFE_ContextOptions {
TF_SessionOptions session_options;
@ -43,9 +44,15 @@ struct TFE_ContextOptions {
};
struct TFE_Context {
explicit TFE_Context(TF_Session* s) : session(s) {}
explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s)
: policy(opts.policy),
session(s),
rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)),
pflr(new tensorflow::ProcessFunctionLibraryRuntime(
session->device_mgr, opts.session_options.options.env,
TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {}
TFE_ContextDevicePlacementPolicy policy;
const TFE_ContextDevicePlacementPolicy policy;
// Note: we cannot use C++11 thread_local here as there is no concept of a
// thread-local-object-local variable in C++11.
@ -54,8 +61,8 @@ struct TFE_Context {
thread_local_policies GUARDED_BY(policy_map_mu);
// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
TF_Session* session;
tensorflow::Rendezvous* rendezvous;
TF_Session* const session;
tensorflow::Rendezvous* const rendezvous;
tensorflow::mutex functions_mu;
tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){
@ -64,14 +71,14 @@ struct TFE_Context {
// One FunctionLibraryRuntime per device.
// func_libs[i] is the FunctionLibraryRuntime corresponding to
// session->devices[i].
std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
const std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
tensorflow::mutex cache_mu;
std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
tensorflow::Fprint128Hasher>
kernel_cache GUARDED_BY(cache_mu);
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const {
return pflr->GetFLR(d->name());
}
@ -100,6 +107,8 @@ struct TFE_TensorHandle {
};
struct TFE_Op {
// t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a
// primitive operation.
TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t)
: ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {}

View File

@ -60,6 +60,31 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
return op;
}
// If there is a GPU device, returns true and sets 'gpu_device_name'
// accordingly.
bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
for (int i = 0; i < num_devices; ++i) {
const string device_type(TF_DeviceListType(devices, i, status.get()));
CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
const string device_name(TF_DeviceListName(devices, i, status.get()));
CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
if (device_type == "GPU") {
*gpu_device_name = device_name;
LOG(INFO) << "Found GPU device " << device_name;
TF_DeleteDeviceList(devices);
return true;
}
}
TF_DeleteDeviceList(devices);
return false;
}
void BM_InitOp(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
@ -288,22 +313,15 @@ TEST(CAPI, TensorHandleSilentCopy) {
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
// Disable the test if no GPU is present.
if (num_devices > 1) {
const int device_to_use = 1;
const string name(TF_DeviceListName(devices, device_to_use, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hgpu =
TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
string gpu_device_name;
if (GetGPUDeviceName(ctx, &gpu_device_name)) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, name.c_str(), status.get());
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@ -314,7 +332,6 @@ TEST(CAPI, TensorHandleSilentCopy) {
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
@ -337,22 +354,15 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
// Disable the test if no GPU is present.
if (num_devices > 1) {
const int device_to_use = 1;
const string name(TF_DeviceListName(devices, device_to_use, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hgpu =
TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
string gpu_device_name;
if (GetGPUDeviceName(ctx, &gpu_device_name)) {
TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice(
hcpu, ctx, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, name.c_str(), status.get());
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
@ -363,13 +373,43 @@ TEST(CAPI, TensorHandleSilentCopyLocal) {
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, SetAndGetOpDevices) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
// Disable the test if no GPU is present.
string gpu_device_name;
if (GetGPUDeviceName(ctx, &gpu_device_name)) {
TFE_OpSetDevice(matmul, "GPU:0", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
const char* device_name = TFE_OpGetDevice(matmul, status);
ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr);
TFE_OpSetDevice(matmul, "CPU:0", status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
device_name = TFE_OpGetDevice(matmul, status);
ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr);
}
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
}
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();

View File

@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
return Status::OK();
}
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
TF_AttrType* out, unsigned char* is_list) {
CHECK(m);
auto* t = gtl::FindOrNull(*m, attr_name);
auto* t = gtl::FindOrNull(m, attr_name);
if (t == nullptr) {
return errors::InvalidArgument("Attribute '", attr_name,
"' does not exist for this operation");
@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a,
b->high64 += a.high64;
}
inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s,
inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s,
const tensorflow::Fprint128& b) {
// TODO(agarwal): avoid ToString().
tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString());
return FingerprintCat128(a, b);
}
inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) {
inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) {
return CacheKeyHelper(s, {b, b});
}

View File

@ -43,7 +43,7 @@ typedef std::unordered_map<string, uint32> AttrTypeMap;
Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out);
// Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'.
Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name,
Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name,
TF_AttrType* out, unsigned char* is_list);
// KernelAndDevice::Init needs a NodeDef only to pass the attribute map through.

View File

@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) {
TF_AttrType t;
unsigned char is_list = 1;
s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list);
EXPECT_FALSE(s.ok());
EXPECT_NE(is_list, 0);
s = AttrTypeByName(m, "transpose_a", &t, &is_list);
s = AttrTypeByName(*m, "transpose_a", &t, &is_list);
ASSERT_TRUE(s.ok()) << s;
EXPECT_EQ(TF_ATTR_BOOL, t);
EXPECT_EQ(is_list, 0);
s = AttrTypeMapForOp("Squeeze", &m);
ASSERT_TRUE(s.ok()) << s;
s = AttrTypeByName(m, "squeeze_dims", &t, &is_list);
s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list);
ASSERT_TRUE(s.ok()) << s;
EXPECT_EQ(TF_ATTR_INT, t);
EXPECT_NE(is_list, 0);

View File

@ -18,12 +18,12 @@ limitations under the License.
// Language-agnostic gradient tape. Does not perform backpropagation, just
// maintains the data structures required to do so.
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -54,11 +54,11 @@ struct OpTapeEntry {
// Map from tensor_id to internally-defined operation-id of the operation which
// produced this tensor. A value of -1 means that the tensor was directly
// watched and not the result of any operation in the tape.
using TensorTape = std::unordered_map<int64, int64>;
using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
template <typename BackwardFunction>
using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@ -159,7 +159,7 @@ class GradientTape {
// Map from tensor id to number of remaining usages (i.e. how many entries in
// the tape refer to it); to aid in tape garbage collection.
std::unordered_map<int64, int64> tensor_usage_;
gtl::FlatMap<int64, int64> tensor_usage_;
// If false, all activations are deleted in the first call to ComputeGradient.
// Else, only when this is destructed.
@ -286,11 +286,11 @@ struct BackpropInitialState {
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
std::unordered_map<int64, int64> tensor_usage_counts;
gtl::FlatMap<int64, int64> tensor_usage_counts;
// Maps from op ID to how many output tensors of this op still need to have
// their gradients computed.
std::unordered_map<int64, int64> op_missing_tensor;
gtl::FlatMap<int64, int64> op_missing_tensor;
};
// If `persistent_tape` is true, op_tape is not changed and none of the
@ -301,8 +301,8 @@ struct BackpropInitialState {
template <typename BackwardFunction>
BackpropInitialState<BackwardFunction> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
OpTape<BackwardFunction>* op_tape,
const std::unordered_set<int64>& sources_set, bool persistent_tape) {
OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
@ -362,7 +362,7 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
template <typename BackwardFunction>
std::vector<int64> InitialStack(
const OpTape<BackwardFunction>& op_tape,
const std::unordered_map<int64, int64>& op_missing_tensor) {
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
@ -373,13 +373,13 @@ std::vector<int64> InitialStack(
}
template <typename Gradient, typename BackwardFunction>
Status InitialGradients(
const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
const std::unordered_map<int64, int64>& tensor_usage_counts,
std::unordered_map<int64, std::vector<Gradient*>>* result) {
Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
const TensorTape& tensor_tape,
const OpTape<BackwardFunction>& op_tape,
const gtl::FlatMap<int64, int64>& tensor_usage_counts,
gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
@ -441,13 +441,13 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
BackpropInitialState<BackwardFunction> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
std::unordered_map<int64, std::vector<Gradient*>> gradients;
gtl::FlatMap<int64, std::vector<Gradient*>> gradients;
Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
tensor_tape_, state.op_tape,
state.tensor_usage_counts, &gradients);
@ -463,7 +463,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
cleanup();
return s;
}
std::unordered_map<int64, int64> gradients_size;
gtl::FlatMap<int64, int64> gradients_size;
// TODO(apassos) multiple threads could be dequeuing from op_stack at the same
// time, for better CPU backprop performance.
VLOG(1) << "Initial stack:";
@ -472,11 +472,10 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
VLOG(1) << " " << t;
}
}
std::unordered_map<string, std::unordered_set<int>>
functions_accept_none_for_indices({
{"SoftmaxCrossEntropyWithLogits", {1}},
{"FusedBatchNorm", {1, 2, 3, 4}},
});
gtl::FlatMap<string, gtl::FlatSet<int>> functions_accept_none_for_indices({
{"SoftmaxCrossEntropyWithLogits", {1}},
{"FusedBatchNorm", {1, 2, 3, 4}},
});
while (!op_stack.empty()) {
const int64 op = op_stack.back();
VLOG(1) << "Popped " << op;

View File

@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc(
"linalg_ops",
"logging_ops",
"lookup_ops",
"manip_ops",
"math_ops",
"nn_ops",
"no_op",

View File

@ -96,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
session->reset(NewSession(session_options));
Session* session_p = nullptr;
TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
session->reset(session_p);
return (*session)->Create(meta_graph_def.graph_def());
}

View File

@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) {
<< st.error_message();
}
TEST_F(LoaderTest, SessionCreationFailure) {
SavedModelBundle bundle;
// Use invalid SessionOptions to cause session creation to fail. Default
// options work, so provide an invalid value for the target field.
SessionOptions session_options;
constexpr char kInvalidTarget[] = "invalid target";
session_options.target = kInvalidTarget;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
Status st = LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget))
<< st.error_message();
}
TEST_F(LoaderTest, PbtxtFormat) {
SavedModelBundle bundle;
SessionOptions session_options;

View File

@ -23,7 +23,6 @@ cc_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
],
)

View File

@ -4,7 +4,7 @@
To use from your BUILD file, add the following line to load the macro:
load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
Then call the macro like this:
@ -16,14 +16,15 @@ tf_library(
)
"""
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
load("//tensorflow:tensorflow.bzl",
"if_android", "tf_cc_test", "tf_copts")
def tf_library(name, graph, config,
freeze_checkpoint=None, freeze_saver=None,
cpp_class=None, gen_test=True, gen_benchmark=True,
visibility=None, testonly=None,
tfcompile_flags=None,
tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile",
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
include_standard_runtime_deps=True, deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
@ -119,9 +120,9 @@ def tf_library(name, graph, config,
out_nodes_file,
] + freeze_saver_srcs,
outs=[freeze_file],
cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" +
cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
freeze_args),
tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"],
tools=["//tensorflow/python/tools:freeze_graph"],
tags=tags,
)
tfcompile_graph = freeze_file
@ -213,22 +214,22 @@ def tf_library(name, graph, config,
# These deps are required by all tf_library targets even if
# include_standard_runtime_deps is False. Without them, the
# generated code will fail to compile.
"@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
"@org_tensorflow//tensorflow/core:framework_lite",
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
"//tensorflow/core:framework_lite",
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the proto.
"@org_tensorflow//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_data_proto",
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
"//third_party/eigen3",
] or []) + (deps or []),
tags=tags,
@ -254,28 +255,32 @@ def tf_library(name, graph, config,
name=("gen_" + test_name),
testonly=1,
srcs=[
"@org_tensorflow//tensorflow/compiler/aot:test.cc",
"//tensorflow/compiler/aot:test.cc",
header_file,
],
outs=[test_file],
cmd=("sed " + sed_replace +
" $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " +
" $(location //tensorflow/compiler/aot:test.cc) " +
"> $(OUTS)"),
tags=tags,
)
# The cc_test rule for the generated code.
native.cc_test(
# The cc_test rule for the generated code. To ensure that this works
# reliably across build configurations, we must use tf_cc_test instead of
# native.cc_test. This is related to how we build
# //tensorflow/core:lib -- see the note in tensorflow/core/BUILD
# for more details.
tf_cc_test(
name=test_name,
srcs=[test_file],
deps=[
":" + name,
"@org_tensorflow//tensorflow/compiler/aot:runtime",
"@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main",
"@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/aot:tf_library_test_main",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:test",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
tags=tags,
)
@ -283,7 +288,7 @@ def tf_library(name, graph, config,
if gen_benchmark:
benchmark_name = name + "_benchmark"
benchmark_file = benchmark_name + ".cc"
benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" +
benchmark_main = ("//tensorflow/compiler/aot:" +
"benchmark_main.template")
# Rule to rewrite benchmark.cc to produce the benchmark_file.
@ -301,7 +306,9 @@ def tf_library(name, graph, config,
tags=tags,
)
# The cc_benchmark rule for the generated code.
# The cc_benchmark rule for the generated code. This does not need the
# tf_cc_binary since we (by deliberate design) do not depend on
# //tensorflow/core:lib.
#
# Note: to get smaller size on android for comparison, compile with:
# --copt=-fvisibility=hidden
@ -315,12 +322,12 @@ def tf_library(name, graph, config,
linkopts = if_android(["-pie", "-s"]),
deps=[
":" + name,
"@org_tensorflow//tensorflow/compiler/aot:benchmark",
"@org_tensorflow//tensorflow/compiler/aot:runtime",
"@org_tensorflow//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/aot:benchmark",
"//tensorflow/compiler/aot:runtime",
"//tensorflow/compiler/xla:executable_run_options",
"//third_party/eigen3",
] + if_android([
"@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android",
"//tensorflow/compiler/aot:benchmark_extra_android",
]),
tags=tags,
)
@ -330,11 +337,11 @@ def target_llvm_triple():
# TODO(toddw): Add target_triple for other targets. For details see:
# http://llvm.org/docs/doxygen/html/Triple_8h_source.html
return select({
"@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android",
"@org_tensorflow//tensorflow:android_arm": "armv7-none-android",
"@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android",
"@org_tensorflow//tensorflow:android_x86": "i686-none-android",
"@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin",
"//tensorflow:android_armeabi": "armv5-none-android",
"//tensorflow:android_arm": "armv7-none-android",
"//tensorflow:android_arm64": "aarch64-none-android",
"//tensorflow:android_x86": "i686-none-android",
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
"//tensorflow:darwin": "x86_64-none-darwin",
"//conditions:default": "x86_64-pc-linux",
})

View File

@ -30,12 +30,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@ -141,8 +143,7 @@ struct NodeSlot {
// everything to use it.
static const char* const kArgOp = "_Arg";
static const char* const kRetValOp = "_Retval";
static const char* const kSendToHostOp = "_XlaSendToHost";
static const char* const kRecvFromHostOp = "_XlaRecvFromHost";
static const char* const kHostComputeOp = "_XlaHostCompute";
static const char* const kSendFromHostOp = "_XlaSendFromHost";
static const char* const kRecvAtHostOp = "_XlaRecvAtHost";
@ -171,7 +172,8 @@ class Encapsulator {
// Write a copy of the input graph to 'graph_out', where the subgraphs are
// replaced with calls to the new functions.
Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
Status BuildOutputGraph(bool parallel_checking, Graph* graph_out,
FunctionLibraryDefinition* library);
private:
// A subgraph of the input, all marked with a common 'group_attribute'
@ -201,21 +203,29 @@ class Encapsulator {
// .. .
// RAH --> C --> SFH
//
// The compiled cluster is as follows. STH is a SendToHost node which is the
// source of a channel to the RAH node above. RFH is a RecvFromHost node which
// is the destination of a channel from the SFH node above. There is a control
// edge that ensures RFH follows STH, which is used in shape inference to
// ensure that the shapes on the STH host channel are known before the RFH
// channel is compiled.
// The compiled cluster is as follows. HC is a HostCompute node which is the
// source of a channel to the RAH node above and the destination of a channel
// from the SFH node above.
//
// Arg --> B --> STH ..> RFH --> D --> Retval
// Arg --> B --> HC --> D --> Retval
//
// The channels STH/RAH and SFH/RFH each transmit a tuple, so there is at most
// one RAH and SFH in each compiled cluster. This design is preferred over
// adding separate Arg/Retval nodes for each transmitted value because it
// simplifies the host code that would like to limit communication between
// host and device and, e.g., raise only one interrupt per channel rather than
// one per transmitted value.
// The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is
// at most one RAH and SFH in each outside_compilation cluster. This design is
// preferred over adding separate Arg/Retval nodes for each transmitted value
// because it allows optimizations to the host code that would like to limit
// communication between host and device and, e.g., raise only one interrupt
// per channel rather than one per transmitted value.
//
// The shapes of the outputs from the HC node in general cannot be determined
// until the shapes of its inputs are known at compile time, since e.g.,
// above, the shape of C's outputs aren't known until the shape of its inputs
// are known. If the shapes of the HC's outputs can be determined during the
// rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal
// graph is stored in the shape_inference_graph attr. This graph can be used
// when compiling the HC Op to determined the shape of the SFH inputs given
// the shapes of any ancestor RAH outputs. If it can be determined that the
// shape of the SFH inputs will not be inferrable even once the shapes of the
// RAH outputs are known, an error is returned by the rewriter.
class Subgraph {
public:
// Creates a graph to build the subgraph in, if it doesn't already exist,
@ -246,6 +256,10 @@ class Encapsulator {
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out);
// Returns the names of all the outside_compilation subgraphs in this
// Subgraph.
void GetOutsideCompilationSubgraphNames(std::vector<string>* names) const;
// Returns the Node that inputs to the function should be wired up to.
Node* GetCallNodeForInputs() const;
@ -305,15 +319,9 @@ class Encapsulator {
void RecordOutsideCompilationOutputOrControl(
const string& outside_compilation_id, const Edge* edge);
// Adds the SendToHost nodes for each outside_compilation subgraph once the
// edges have all been recorded via RecordOutsideCompilationInputOrControl.
Status AddSendsToOutsideCompilation(
const std::unordered_map<const Node*, Node*>& node_images);
// Adds the RecvFromHost nodes for each outside_compilation subgraph once
// the edges have all been recorded via
// RecordOutsideCompilationOutputOrControl.
Status AddRecvsFromOutsideCompilation(
// Adds the HostCompute nodes for each outside_compilation subgraph.
Status AddHostComputes(
const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images);
// Creates the sequencer node if it doesn't exist, adding it to graph_out.
@ -323,10 +331,16 @@ class Encapsulator {
// all the downstream nodes of call_node_outputs.
void ConnectSequencerToOutputs(Graph* graph_out);
Status AddShapeInferenceInfo(
const string& outside_compilation_subgraph_name,
const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph);
Status ReplaceFunctionDef(FunctionLibraryDefinition* library);
private:
struct OutsideCompilationSubgraph {
// Map from source (producer node/slot) tensors in the original graph to
// input index (slot number in the SendToHost/RecvAtHost nodes that will
// input index (slot number in the HostCompute/RecvAtHost nodes that will
// be created) for the outside_compilation subgraph.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> inputs;
@ -335,14 +349,14 @@ class Encapsulator {
// outside_compilation subgraph. These are recorded by
// RecordOutsideCompilationInputOrControl while walking all the subgraph
// edges, and lifted control edges within the subgraph are added by
// AddSendsToOutsideCompilation once the _SendToHost node has been
// AddSendsToOutsideCompilation once the _HostCompute node has been
// created. The matching control edge from _RecvAtHost to the
// destination is added by CopyEdgeToOutputGraph.
std::unordered_set<const Node*> control_inputs;
// Maps from source (producer node/slot) and destination (consumer
// node/slot) tensors in the original graph to output index (slot number
// in the SendFromHost/RecvFromHost nodes that will be created) for the
// in the SendFromHost/HostCompute nodes that will be created) for the
// outside_compilation subgraph.
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_src;
std::unordered_map<NodeSlot, int, NodeSlot::Hasher> outputs_by_dst;
@ -352,13 +366,13 @@ class Encapsulator {
// containing compiled subgraph. These are recorded by
// RecordOutsideCompilationOutputOrControl while walking all the subgraph
// edges, and lifted control edges within the subgraph are added by
// AddRecvsFromToOutsideCompilation once the _RecvFromHost node has been
// AddRecvsFromToOutsideCompilation once the _HostCompute node has been
// created. The matching control edge from the source to _SendFromHost to
// the destination is added by CopyEdgeToOutputGraph.
std::unordered_set<const Node*> control_outputs;
// _SendToHost node in the subgraph. Not owned.
Node* send_to_host = nullptr;
// Name of the _HostCompute node in the subgraph.
string host_compute_name;
// _RecvAtHost node in the output graph. Not owned.
Node* recv_at_host = nullptr;
@ -516,6 +530,59 @@ class Encapsulator {
const std::unordered_map<const Node*, Node*>& node_images,
bool parallel_checking, Graph* graph_out);
// Constructs a minimal shape inference graph that can be used to determine
// the shape of send_node at the time that the subgraph is compiled.
// recv_at_host_nodes contains the names of all the recv_at_host nodes that
// send_node might depend on. These recv_at_host nodes have shapes that are
// not known during the rewrite pass, but will be known at compile time.
//
// If the shapes of all the inputs to send_node can be determined during the
// rewrite pass, on exit graphdef_out is empty and the shapes are returned in
// static_shape_out. Otherwise graphdef_out contains a graph that can be used
// for shape inference at compile time, where all the source nodes of the
// graph are either constants with known shapes, or nodes named in
// recv_at_host_nodes.
//
// A non-OK status is returned if neither of the above conditions can be
// satisfied, e.g., because send_node depends on a node that doesn't have a
// registered shape inference function.
Status DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
std::unique_ptr<GraphDef>* graphdef_out);
// Makes a copy of graph containing only nodes that are ancestors of at least
// one node in send_from_host_nodes and store it in pruned_graph. On exit
// nodes_images contains a mapping from nodes in graph to nodes in
// pruned_graph. All functions in the copied graph are inlined.
Status MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);
// Makes a copy of graph containing only nodes that are ancestors of a
// send_from_host node in an outside_compilation subgraph, and store it in
// pruned_graph. Also perform shape inference on the pruned graph, using
// shape_refiner. On exit node_images contains a mapping from nodes in graph
// to nodes in pruned_graph.
Status MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library);
// Performs static shape inference, as far as possible, for the send_from_host
// nodes in each outside_compilation subgraph. Where it is not possible to
// determine the shape statically, stores a serialized GraphDef in the
// HostCompute 'shape_inference_graph' attr, to be used at compile time for
// final inference. If the shapes are known statically they are stored in the
// HostCompute 'shapes' attr.
Status GetShapeInfoForOutsideCompilationSends(
Graph* graph_out, FunctionLibraryDefinition* library);
const string group_attribute_;
const string outside_compilation_attribute_;
const Graph* graph_in_;
@ -682,16 +749,20 @@ void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl(
}
}
Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
Status Encapsulator::Subgraph::AddHostComputes(
const string& subgraph_name,
const std::unordered_map<const Node*, Node*>& node_images) {
for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
const string& oc_subgraph_name = oc_subgraph_iter.first;
OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) {
// Build a _SendToHost node sending all the args of the appropriate
// types.
std::vector<DataType> dtypes(oc_subgraph.inputs.size(), DT_INVALID);
if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty() ||
!oc_subgraph.outputs_by_src.empty() ||
!oc_subgraph.control_outputs.empty()) {
// Build a _HostCompute node.
std::vector<NodeDefBuilder::NodeOut> inputs(oc_subgraph.inputs.size());
std::vector<DataType> input_dtypes(oc_subgraph.inputs.size(), DT_INVALID);
std::vector<DataType> output_dtypes(oc_subgraph.outputs_by_src.size(),
DT_INVALID);
for (const auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
@ -700,94 +771,64 @@ Status Encapsulator::Subgraph::AddSendsToOutsideCompilation(
int input_index = input_src.second;
DataType dtype = src_node->output_type(src_slot);
dtypes[input_index] = dtype;
inputs[input_index].Reset(src_image->name(), src_slot, dtype);
input_dtypes[input_index] = dtype;
}
NodeDef send_def;
NodeDefBuilder builder(
strings::StrCat("outside_compilation_", oc_subgraph_name, "_send"),
kSendToHostOp);
builder.Attr("dtypes", dtypes);
for (const auto& output : oc_subgraph.outputs_by_src) {
DataType dtype = output.first.dtype;
int output_index = output.second;
output_dtypes[output_index] = dtype;
}
NodeDef host_compute_def;
NodeDefBuilder builder(strings::StrCat("outside_compilation_",
oc_subgraph_name, "_host_compute"),
kHostComputeOp);
builder.Input(inputs);
Status s = builder.Finalize(&send_def);
builder.Attr("Tinputs", input_dtypes);
builder.Attr("Toutputs", output_dtypes);
builder.Attr("key",
strings::StrCat("host_compute_channel_", subgraph_name, "_",
oc_subgraph_name));
Status s = builder.Finalize(&host_compute_def);
if (!s.ok()) return s;
oc_subgraph.send_to_host = graph_->AddNode(send_def, &s);
Node* host_compute = graph_->AddNode(host_compute_def, &s);
if (!s.ok()) return s;
oc_subgraph.host_compute_name = host_compute->name();
// Connect the _SendToHost node to its producers in the subgraph.
// Connect the _HostCompute node to its producers in the subgraph.
for (auto& input_src : oc_subgraph.inputs) {
const Node* src_node = input_src.first.node;
Node* src_image = node_images.at(src_node);
int src_slot = input_src.first.slot;
int input_index = input_src.second;
graph_->AddEdge(src_image, src_slot, oc_subgraph.send_to_host,
input_index);
graph_->AddEdge(src_image, src_slot, host_compute, input_index);
}
// Connect the _SendToHost node to its control edge producers in the
// Connect the _HostCompute node to its control edge producers in the
// subgraph.
for (const auto& src_node : oc_subgraph.control_inputs) {
Node* src_image = node_images.at(src_node);
graph_->AddControlEdge(src_image, oc_subgraph.send_to_host);
}
}
}
return Status::OK();
}
Status Encapsulator::Subgraph::AddRecvsFromOutsideCompilation(
const std::unordered_map<const Node*, Node*>& node_images) {
for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) {
const string& oc_subgraph_name = oc_subgraph_iter.first;
OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second;
if (!oc_subgraph.outputs_by_src.empty() ||
!oc_subgraph.control_outputs.empty()) {
// Build a _RecvFromHost node producing all the outputs of the appropriate
// types.
std::vector<DataType> dtypes(oc_subgraph.outputs_by_src.size(),
DT_INVALID);
for (const auto& output : oc_subgraph.outputs_by_src) {
DataType dtype = output.first.dtype;
int output_index = output.second;
dtypes[output_index] = dtype;
graph_->AddControlEdge(src_image, host_compute);
}
NodeDef recv_def;
NodeDefBuilder builder(
strings::StrCat("outside_compilation_", oc_subgraph_name, "_recv"),
kRecvFromHostOp);
builder.Attr("dtypes", dtypes);
Status s = builder.Finalize(&recv_def);
if (!s.ok()) return s;
Node* recv = graph_->AddNode(recv_def, &s);
if (!s.ok()) return s;
// Connect the consumers in the subgraph to the _RecvFromHost node.
// Connect the consumers in the subgraph to the _HostCompute node.
for (const auto& output : oc_subgraph.outputs_by_dst) {
const Node* dst_node = output.first.node;
Node* dst_image = node_images.at(dst_node);
int dst_slot = output.first.slot;
int output_index = output.second;
graph_->AddEdge(recv, output_index, dst_image, dst_slot);
graph_->AddEdge(host_compute, output_index, dst_image, dst_slot);
}
// Connect the control edge consumers in the subgraph to the _RecvFromHost
// Connect the control edge consumers in the subgraph to the _HostCompute
// node.
for (const auto& dst_node : oc_subgraph.control_outputs) {
Node* dst_image = node_images.at(dst_node);
graph_->AddControlEdge(recv, dst_image);
}
// Add a control edge in the subgraph so that the _SendToHost node, if
// any, is compiled before the _RecvFromHost node.
if (oc_subgraph.send_to_host != nullptr) {
graph_->AddControlEdge(oc_subgraph.send_to_host, recv);
graph_->AddControlEdge(host_compute, dst_image);
}
}
}
@ -882,6 +923,63 @@ Status Encapsulator::Subgraph::BuildFunctionDef(
return Status::OK();
}
Status Encapsulator::Subgraph::AddShapeInferenceInfo(
const string& outside_compilation_subgraph_name,
const std::vector<TensorShapeProto>& shapes, GraphDef* inference_graph) {
OutsideCompilationSubgraph& oc_subgraph =
outside_compilation_subgraphs_.at(outside_compilation_subgraph_name);
Node* host_compute = nullptr;
for (Node* n : graph_->nodes()) {
if (n->name() == oc_subgraph.host_compute_name) {
host_compute = n;
break;
}
}
if (host_compute == nullptr) {
return errors::InvalidArgument(
"After rewriting subgraph ", outside_compilation_subgraph_name,
" there is no HostCompute Op for outside compilation subgraph ",
oc_subgraph.host_compute_name);
}
if (inference_graph == nullptr) {
host_compute->AddAttr("shape_inference_graph", "");
host_compute->AddAttr("shapes", shapes);
} else {
string serialized_graph;
if (!inference_graph->SerializeToString(&serialized_graph)) {
return errors::Internal(
"Failed to serialize graph for outside compilation subgraph ",
oc_subgraph.host_compute_name);
}
host_compute->AddAttr("shape_inference_graph", serialized_graph);
host_compute->AddAttr("shapes", std::vector<TensorShapeProto>());
}
return Status::OK();
}
Status Encapsulator::Subgraph::ReplaceFunctionDef(
FunctionLibraryDefinition* library) {
const string& name = call_node_def_.name();
FunctionDef fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef));
if (VLOG_IS_ON(1)) {
VLOG(2) << "Replace function def " << name;
dump_graph::DumpGraphToFile(
strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_,
library);
dump_graph::DumpFunctionDefToFile(
strings::StrCat("replace_encapsulate_fdef_", name), fdef);
}
TF_RETURN_IF_ERROR(library->RemoveFunction(name));
TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
return Status::OK();
}
Status Encapsulator::Subgraph::BuildParallelCheckOp(
const std::unordered_map<const Node*, Node*>& node_images,
Graph* graph_out) {
@ -980,7 +1078,9 @@ Status Encapsulator::Subgraph::AddRecvAtHostNode(
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_recv"),
kRecvAtHostOp);
builder.Attr("dtypes", dtypes);
builder.Attr("Toutputs", dtypes);
builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
"_", oc_subgraph_name));
Status s = builder.Finalize(&recv_def);
if (!s.ok()) return s;
@ -1020,7 +1120,9 @@ Status Encapsulator::Subgraph::AddSendFromHostNode(
NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name,
"_", oc_subgraph_name, "_send"),
kSendFromHostOp);
builder.Attr("dtypes", dtypes);
builder.Attr("Tinputs", dtypes);
builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name,
"_", oc_subgraph_name));
builder.Input(inputs);
Status s = builder.Finalize(&send_def);
if (!s.ok()) return s;
@ -1062,6 +1164,13 @@ Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes(
return Status::OK();
}
void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
std::vector<string>* names) const {
for (auto& entry : outside_compilation_subgraphs_) {
names->push_back(entry.first);
}
}
Status Encapsulator::GetFunctionNameAttr(
Node const* node, string* attr, string* outside_compilation_attr) const {
Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
@ -1220,8 +1329,7 @@ Status Encapsulator::SplitIntoSubgraphs() {
// single input and output node for it.
for (auto& entry : subgraphs_) {
Subgraph& subgraph = entry.second;
TF_RETURN_IF_ERROR(subgraph.AddSendsToOutsideCompilation(node_images));
TF_RETURN_IF_ERROR(subgraph.AddRecvsFromOutsideCompilation(node_images));
TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images));
}
MarkGuaranteedConstants(*graph_in_, src_arg_pairs);
@ -1509,8 +1617,346 @@ Status Encapsulator::AddEdgesToOutputGraph(
return Status::OK();
}
Status Encapsulator::BuildOutputGraph(bool parallel_checking,
Graph* graph_out) {
namespace {
// Adds a dummy Const node to graph_out. The "constant" has the type of
// data_type and the shape indicated in 'shape'. The dummy node is not a valid
// Const node because it does not have any value defined, but this doesn't
// matter because it will only be used subsequently for shape inference. (It
// would be possible to add a switch statement over data_type to create a value
// for the constant, but that would entail maintaining the logic as new types
// are added, and is not necessary.)
Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape,
Graph* graph_out) {
TensorProto dummy_proto;
dummy_proto.set_dtype(data_type);
*dummy_proto.mutable_tensor_shape() = shape;
// Don't set any value field in the proto, since it is only going to be used
// for shape inference.
GraphDefBuilder::Options options(graph_out, /*status=*/nullptr);
NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const",
options.op_registry());
node_builder.Attr("dtype", data_type).Attr("value", dummy_proto);
return options.FinalizeBuilder(&node_builder);
}
// Adds a copy of node_in to graph_out and adds the mapping to
// copied_node_images.
Status CopyShapeInferenceNodeToGraph(
Node* node_in, const Node* send_node,
const std::unordered_map<Node*, Node*>& dummy_node_images,
FunctionLibraryDefinition* library,
std::unordered_map<Node*, Node*>* copied_node_images, Graph* graph_out) {
// Once all the ancestor nodes have been added to graph_out, add this node
// and connect it to its ancestors.
Node* node_out = graph_out->CopyNode(node_in);
(*copied_node_images)[node_in] = node_out;
// Don't bother to build the shape inference graph if there's a node with no
// shape inference function, since it would just result in an error later at
// compile time.
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data));
if (op_reg_data->shape_inference_fn == nullptr) {
return errors::InvalidArgument(
"Shape inference is not possible for outside_compilation "
"SendFromHost node ",
send_node->name(), " because it depends on node ", node_in->name(),
" which does not have a shape inference function registered.");
}
// Add all the edges to the newly copied node.
for (const Edge* in_edge : node_in->in_edges()) {
if (!in_edge->IsControlEdge()) {
Node* src = in_edge->src();
const auto iter = dummy_node_images.find(src);
if (iter == dummy_node_images.end()) {
// The src is a copied node so use the original output port.
graph_out->AddEdge((*copied_node_images)[in_edge->src()],
in_edge->src_output(), node_out,
in_edge->dst_input());
} else {
// The src is a dummy node so use output port 0.
graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input());
}
}
}
return Status::OK();
}
} // namespace
Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend(
const Graph& graph_in, const ShapeRefiner& shape_refiner,
const std::unordered_set<string>& recv_at_host_nodes, Node* send_node,
FunctionLibraryDefinition* library,
std::vector<TensorShapeProto>* static_shape_out,
std::unique_ptr<GraphDef>* graphdef_out) {
// Maps from nodes in graph_in to nodes in graph_out.
//
// When an edge has fully defined shape the source node in graph_in is
// replaced in graph_out by a dummy constant node. The mapping from nodes
// in graph_in to dummy nodes is stored in dummy_node_images.
//
// When a node in graph_in has at least one ancestor that doesn't have fully
// defined shape, it is copied into graph_out. The mapping from nodes in
// graph_in to copied nodes is stored in copied_node_images.
//
// The two types of node are treated differently because, when adding edges to
// graph_out, an output from a dummy node always uses port 0, whereas an
// output from a copied node uses the same port that was used in graph_in.
std::unordered_map<Node*, Node*> dummy_node_images;
std::unordered_map<Node*, Node*> copied_node_images;
std::unique_ptr<Graph> graph_out(new Graph(graph_in.op_registry()));
graph_out->set_versions(graph_in.versions());
static_shape_out->resize(send_node->num_inputs());
// We don't use the standard ReverseDFS because we want to cut off traversal
// whenever we find an output with fully defined shape.
// TODO(misard) make this work properly in the presence of control flow.
struct Work {
Node* node;
bool leave; // Are we entering or leaving node?
};
std::vector<Work> stack({{send_node, false}});
std::vector<bool> visited(graph_in.num_node_ids(), false);
while (!stack.empty()) {
Work w = stack.back();
stack.pop_back();
Node* n = w.node;
if (w.leave) {
TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph(
n, send_node, dummy_node_images, library, &copied_node_images,
graph_out.get()));
} else {
if (visited[n->id()]) continue;
visited[n->id()] = true;
// Arrange to revisit when all done with all inputs.
stack.push_back(Work{n, true});
bool has_parent_with_unknown_shape = false;
for (const Edge* in_edge : n->in_edges()) {
if (!in_edge->IsControlEdge()) {
Node* src_node = in_edge->src();
int src_port = in_edge->src_output();
shape_inference::InferenceContext* context =
shape_refiner.GetContext(src_node);
shape_inference::ShapeHandle shape = context->output(src_port);
if (context->FullyDefined(shape)) {
// This ancestor has known shape, so instead of adding it to the
// stack, add a dummy node with that shape to graph_out and
// continue.
TensorShapeProto proto;
context->ShapeHandleToProto(shape, &proto);
dummy_node_images[src_node] = AddDummyShapedNode(
src_node->output_type(src_port), proto, graph_out.get());
if (n == send_node) {
(*static_shape_out)[in_edge->dst_input()] = proto;
}
} else {
if (!visited[src_node->id()]) {
has_parent_with_unknown_shape = true;
stack.push_back({src_node, false});
}
}
}
}
if (!has_parent_with_unknown_shape) {
if (n == send_node) {
// The shapes of all the inputs to send_node are statically known. We
// won't have to do any inference at compile time so return now: the
// shapes were stored in static_shape_out above.
graphdef_out->reset();
return Status::OK();
} else {
// Any shape that is being processed is either the original send node
// or has at least one output with statically-unknown shape. If the
// latter and it doesn't have any inputs with statically-unknown
// shape, then check that it is of the recv nodes that we can fill in
// the shape of at run-time later. If it isn't one of those, then we
// won't have any additional knowledge at compile time, so we already
// know we won't be able to do shape inference and we can return an
// error now.
if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) {
return errors::InvalidArgument(
"Shape inference is not possible for outside_compilation "
"SendFromHost node ",
send_node->name(), " because shape of node ", n->name(),
" will not be known at compilation time.");
}
}
}
}
}
graphdef_out->reset(new GraphDef());
graph_out->ToGraphDef(graphdef_out->get());
return Status::OK();
}
Status Encapsulator::MakePrunedGraphCopyAndInline(
const Graph& graph, const std::vector<Node*>& sink_nodes,
std::unique_ptr<Graph>* pruned_graph,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// First copy all ancestor nodes of sink_nodes into a new graph.
pruned_graph->reset(new Graph(library));
(*pruned_graph)->set_versions(graph.versions());
ReverseDFSFrom(graph, sink_nodes,
/*enter=*/nullptr,
/*leave=*/[&](Node* n) {
if (!n->IsSource()) {
Node* copied = (*pruned_graph)->CopyNode(n);
node_images->emplace(n, copied);
}
});
// Add all the edges between copied nodes.
for (auto entry : *node_images) {
const Node* orig = entry.first;
Node* image = entry.second;
for (const Edge* out_edge : orig->out_edges()) {
auto iter = node_images->find(out_edge->dst());
if (iter != node_images->end()) {
// The source and destination are both in the copied graph.
(*pruned_graph)
->AddEdge(image, out_edge->src_output(), iter->second,
out_edge->dst_input());
}
}
}
// Find all the function call nodes, and inline them.
std::vector<Node*> function_nodes;
for (auto node : (*pruned_graph)->nodes()) {
const OpRegistrationData* op_reg_data;
TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data));
if (op_reg_data->is_function_op) {
function_nodes.push_back(node);
}
}
for (auto node : function_nodes) {
VLOG(2) << "Inlining function " << node->name();
const FunctionDef* fdef = library->Find(node->type_string());
if (fdef == nullptr) {
return errors::Internal("Failed to find function ", node->type_string(),
" in function library.");
}
FunctionBody* fbody = nullptr;
TF_RETURN_IF_ERROR(
FunctionDefToBodyHelper(*fdef, node->attrs(), library,
[library](const string& op, const OpDef** sig) {
return library->LookUpOpDef(op, sig);
},
&fbody));
InlineFunctionBody(*library, pruned_graph->get(), node, fbody);
delete fbody;
}
return Status::OK();
}
Status Encapsulator::MakeGraphForOutsideCompilationSends(
const Graph& graph, std::unique_ptr<Graph>* pruned_graph,
ShapeRefiner* shape_refiner,
std::unordered_map<const Node*, Node*>* node_images,
FunctionLibraryDefinition* library) {
// Find all the send_from_host nodes in all subgraphs, to use as roots for the
// pruning.
std::vector<Node*> send_from_host_nodes;
for (auto& subgraph_entry : subgraphs_) {
Subgraph& subgraph = subgraph_entry.second;
std::vector<string> outside_compilation_names;
subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
for (const auto& name : outside_compilation_names) {
Node* send_node = subgraph.GetSendFromHostNode(name);
if (send_node != nullptr) {
send_from_host_nodes.push_back(send_node);
}
}
}
// Make a copy of all the graph nodes needed to evaluate the send_from_host
// nodes, inlining any functions as needed.
TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline(
graph, send_from_host_nodes, pruned_graph, node_images, library));
// Perform shape inference on the pruned graph.
shape_refiner->set_require_shape_inference_fns(false);
FixupSourceAndSinkEdges(pruned_graph->get());
std::vector<Node*> post_order;
GetReversePostOrder(*(*pruned_graph), &post_order);
for (auto node : post_order) {
// Ignore the status returned by the shape_refiner. At this point we want
// the best effort shapes, even if no shape function is registered for a
// node.
Status status = shape_refiner->AddNode(node);
if (!status.ok()) {
VLOG(1) << "Shape inference failed for node: " << status;
}
}
return Status::OK();
}
Status Encapsulator::GetShapeInfoForOutsideCompilationSends(
Graph* graph_out, FunctionLibraryDefinition* library) {
std::unique_ptr<Graph> pruned_graph;
ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry());
std::unordered_map<const Node*, Node*> node_images;
TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends(
*graph_out, &pruned_graph, &shape_refiner, &node_images, library));
for (auto& subgraph_entry : subgraphs_) {
Subgraph& subgraph = subgraph_entry.second;
// Find all the recv_at_host nodes in this subgraph.
std::vector<string> outside_compilation_names;
subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names);
std::unordered_set<string> recv_at_host_names;
for (const auto& name : outside_compilation_names) {
Node* recv_node = subgraph.GetRecvAtHostNode(name);
if (recv_node != nullptr) {
recv_at_host_names.insert(recv_node->name());
}
}
// For each send_from_host node, do as much shape inference as possible
// without knowing the shape of the recv_at_host nodes, and store the
// result, along with enough information to complete the job at compile time
// once the recv_at_host shapes are known.
for (const auto& name : outside_compilation_names) {
Node* send_node = subgraph.GetSendFromHostNode(name);
std::vector<TensorShapeProto> static_shape;
std::unique_ptr<GraphDef> graphdef;
if (send_node != nullptr) {
TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend(
*pruned_graph, shape_refiner, recv_at_host_names,
node_images[send_node], library, &static_shape, &graphdef));
if (graphdef == nullptr) {
VLOG(2) << "Send node " << send_node->name() << " shapes";
for (int i = 0; i < static_shape.size(); ++i) {
VLOG(2) << static_shape[i].DebugString();
}
} else {
VLOG(2) << "Send node " << send_node->name() << " graph\n"
<< graphdef->DebugString();
}
}
TF_RETURN_IF_ERROR(
subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get()));
}
if (!outside_compilation_names.empty()) {
TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library));
}
}
return Status::OK();
}
Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out,
FunctionLibraryDefinition* library) {
// Map from nodes in the input graph to nodes in the output graph.
std::unordered_map<const Node*, Node*> node_images;
@ -1522,6 +1968,9 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking,
TF_RETURN_IF_ERROR(
AddEdgesToOutputGraph(node_images, parallel_checking, graph_out));
TF_RETURN_IF_ERROR(
GetShapeInfoForOutsideCompilationSends(graph_out, library));
return Status::OK();
}
@ -1545,7 +1994,7 @@ Status EncapsulateSubgraphsInFunctions(
std::unique_ptr<Graph> out(new Graph(library));
out->set_versions(graph_in.versions());
TF_RETURN_IF_ERROR(
encapsulator.BuildOutputGraph(parallel_checking, out.get()));
encapsulator.BuildOutputGraph(parallel_checking, out.get(), library));
*graph_out = std::move(out);
return Status::OK();

View File

@ -29,17 +29,181 @@ limitations under the License.
namespace tensorflow {
namespace {
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
// TODO(phawkins) use a more sophisticated equality test.
if (a.DebugString() != b.DebugString()) {
template <class Tkey, class Tvalue>
bool EqualProtoMap(const ::tensorflow::protobuf::Map<Tkey, Tvalue>& a,
const ::tensorflow::protobuf::Map<Tkey, Tvalue>& b,
const std::function<string(const Tkey&)>& key_to_string,
const std::function<string(const Tvalue&)>& value_to_string,
const std::function<bool(const Tkey&, const Tvalue&,
const Tvalue&)>& compare,
const string& map_name, string* diff) {
for (const auto& elt_a : a) {
const auto iter = b.find(elt_a.first);
if (iter == b.end()) {
if (diff) {
*diff = strings::StrCat(
map_name, " expected: contains element with key '",
key_to_string(elt_a.first), "' got: map has no such element");
}
return false;
}
if (!compare(elt_a.first, elt_a.second, iter->second)) {
if (diff) {
*diff = strings::StrCat(map_name, " expected: element with key '",
key_to_string(elt_a.first), " has value '",
value_to_string(elt_a.second), "' got: '",
value_to_string(iter->second), "'");
}
return false;
}
}
for (const auto& elt_b : b) {
const auto iter = a.find(elt_b.first);
if (iter == a.end()) {
if (diff) {
*diff = strings::StrCat(map_name, " got: contains element with key '",
key_to_string(elt_b.first),
"' expected: map has no such element");
}
return false;
}
}
return true;
}
bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b,
const string& diff_preamble, string* diff) {
if (a.op() != b.op()) {
if (diff) {
*diff = strings::StrCat("Definition mismatch for function ",
a.signature().name(), ", expected:\n",
a.DebugString(), "\ngot:\n", b.DebugString());
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
", expected op '", a.op(), "' got '", b.op());
}
return false;
}
if (a.device() != b.device()) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
", expected device '", a.device(), "' got '",
b.device());
}
return false;
}
if (a.input_size() != b.input_size()) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
", expected ", a.input_size(), " inputs got ",
b.input_size(), " expected:\n", a.DebugString(),
"\ngot:\n", b.DebugString());
}
return false;
}
for (int i = 0; i < a.input_size(); ++i) {
if (a.input(i) != b.input(i)) {
if (diff) {
*diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(),
" input ", i, ", expected ", a.input(i),
" got ", b.input(i), " expected:\n",
a.DebugString(), "\ngot:\n", b.DebugString());
}
return false;
}
}
return EqualProtoMap<string, AttrValue>(
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
if (key == "shape_inference_graph") {
// Default serialization of GraphDef is unstable because maps don't
// serialize deterministically. Rather than go through the hoops to
// turn on deterministic serialization of this attr just for this
// test, add logic here to compare determinstically.
GraphDef ga;
if (!ga.ParseFromString(av.s())) {
return false;
}
GraphDef gb;
if (!gb.ParseFromString(bv.s())) {
return false;
}
return EqualGraphDef(ga, gb, nullptr);
} else {
return av.DebugString() == bv.DebugString();
}
},
strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()),
diff);
}
bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
string* diff) {
if (a.signature().DebugString() != b.signature().DebugString()) {
if (diff) {
*diff = strings::StrCat("Signature mismatch for function ",
a.signature().name(), ", expected:\n",
a.signature().DebugString(), "\ngot:\n",
b.signature().DebugString());
}
return false;
}
if (!EqualProtoMap<string, AttrValue>(
a.attr(), b.attr(), [](const string& s) { return s; },
[](const AttrValue& v) { return v.DebugString(); },
[](const string& key, const AttrValue& av, const AttrValue& bv) {
return av.DebugString() == bv.DebugString();
},
strings::StrCat("attr mismatch for function ", a.signature().name()),
diff)) {
return false;
}
if (!EqualProtoMap<string, string>(
a.ret(), b.ret(), [](const string& s) { return s; },
[](const string& s) { return s; },
[](const string& key, const string& av, const string& bv) {
return av == bv;
},
strings::StrCat("ret mismatch for function ", a.signature().name()),
diff)) {
return false;
}
for (int i = 0; i < a.node_def_size(); ++i) {
bool found = false;
for (int j = 0; j < b.node_def_size(); ++j) {
if (a.node_def(i).name() == b.node_def(j).name()) {
if (!EqualFunctionNodeDef(
a.node_def(i), b.node_def(j),
strings::StrCat("Function ", a.signature().name()), diff)) {
return false;
}
found = true;
break;
}
}
if (!found) {
if (diff) {
*diff = strings::StrCat("Function ", a.signature().name(),
", expected: has node '", a.node_def(i).name(),
"' got: no node of that name");
}
return false;
}
}
for (int i = 0; i < b.node_def_size(); ++i) {
bool found = false;
for (int j = 0; j < a.node_def_size(); ++j) {
if (b.node_def(i).name() == a.node_def(j).name()) {
found = true;
break;
}
}
if (!found) {
if (diff) {
*diff = strings::StrCat("Function ", a.signature().name(),
", got: has node '", b.node_def(i).name(),
"' expected: no node of that name");
}
return false;
}
}
return true;
}
@ -84,29 +248,64 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
// TODO(misard): remove these fake registrations once there are real Ops to be
// compiled.
REGISTER_OP("_XlaSendToHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaRecvFromHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
REGISTER_OP("_XlaHostCompute")
.Input("inputs: Tinputs")
.Output("outputs: Toutputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("Toutputs: list(type) >= 0")
.Attr("key: string")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("_XlaSendFromHost")
.Input("input: dtypes")
.Attr("dtypes: list(type) >= 0");
.Input("input: Tinputs")
.Attr("Tinputs: list(type) >= 0")
.Attr("key: string")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("_XlaRecvAtHost")
.Output("output: dtypes")
.Attr("dtypes: list(type) >= 0");
.Output("output: Toutputs")
.Attr("Toutputs: list(type) >= 0")
.Attr("key: string")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("InputTest").Output("o: float");
REGISTER_OP("InputTest")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShape());
return Status::OK();
});
REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
REGISTER_OP("InputTestShaped")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
});
REGISTER_OP("UnaryTest")
.Input("a: float")
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
c->set_output(0, o);
return Status::OK();
});
REGISTER_OP("BinaryTest")
.Input("a: float")
.Input("b: float")
.Output("o: float");
.Output("o: float")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
::tensorflow::shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o));
c->set_output(0, o);
return Status::OK();
});
REGISTER_OP("BinaryTest2")
.Input("a: float")
.Input("b: float")
.Output("o: float")
.SetShapeFn(::tensorflow::shape_inference::UnknownShape);
REGISTER_OP("AddNLikeTest")
.Input("inputs: N * T")
@ -124,22 +323,48 @@ Node* Input(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTest", opts);
}
Node* RecvAtHost(const gtl::ArraySlice<DataType>& dtypes,
Node* InputShaped(const GraphDefBuilder::Options& opts) {
return ops::SourceOp("InputTestShaped", opts);
}
Node* KnownShape(const gtl::ArraySlice<int>& shape,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const",
opts.op_registry());
TensorProto value;
value.set_dtype(DT_FLOAT);
for (int dim : shape) {
value.mutable_tensor_shape()->add_dim()->set_size(dim);
}
return opts.WithAttr("value", value)
.WithAttr("dtype", DT_FLOAT)
.FinalizeBuilder(&node_builder);
}
Node* RecvAtHost(const string& key, const gtl::ArraySlice<DataType>& dtypes,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"),
"_XlaRecvAtHost", opts.op_registry());
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
return opts.WithAttr("Toutputs", dtypes)
.WithAttr("key", key)
.FinalizeBuilder(&node_builder);
}
Node* SendFromHost(const std::vector<ops::NodeOut>& inputs,
const gtl::ArraySlice<DataType>& dtypes,
Node* SendFromHost(const string& key, const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"),
"_XlaSendFromHost", opts.op_registry());
node_builder.Input(inputs);
return opts.WithAttr("dtypes", dtypes).FinalizeBuilder(&node_builder);
std::vector<DataType> dtypes;
for (const auto& node : inputs) {
dtypes.push_back(node.dt);
}
return opts.WithAttr("key", key)
.WithAttr("Tinputs", dtypes)
.FinalizeBuilder(&node_builder);
}
Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
@ -151,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b,
return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts);
}
Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b,
const GraphDefBuilder::Options& opts) {
return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts);
}
Node* AddNLike(const std::vector<ops::NodeOut>& inputs,
const GraphDefBuilder::Options& opts) {
if (opts.HaveError()) return nullptr;
@ -576,6 +806,21 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape_graph;
TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
}
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
@ -584,19 +829,18 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
{{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"C:o:0", "c:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", shape_string_expected},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"c"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
@ -612,11 +856,11 @@ TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) {
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
b2.opts().WithName("E").WithControlInputs({recv, b}));
Node* send = SendFromHost({e}, {DT_FLOAT},
Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
@ -674,37 +918,71 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
string shape_string_expected_1;
{
GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape1.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape1.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape1.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape1_graph;
TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph));
EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1));
}
string shape_string_expected_2;
{
GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately);
Node* recv1 =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
shape2.opts().WithName("E"));
Node* recv2 =
RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
shape2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H"));
SendFromHost("host_compute_channel_F1_O2", {h},
shape2.opts().WithName("outside_compilation_F1_O2_send"));
GraphDef shape2_graph;
TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph));
EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2));
}
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {},
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}},
{{"I"}, "UnaryTest", {"outside_compilation_O2_recv:output:0"}},
{{"I"},
"UnaryTest",
{"outside_compilation_O2_host_compute:outputs:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O2_send"},
"_XlaSendToHost",
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O2_host_compute"},
"_XlaHostCompute",
{"D:o:0", "F:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O2"},
{"shape_inference_graph", shape_string_expected_2},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"F"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", shape_string_expected_1},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"D"}},
{{"outside_compilation_O2_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O2_send"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"i_0_retval", "I:o:0"}});
@ -720,23 +998,24 @@ TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) {
Node* call = b2.opts().FinalizeBuilder(&node_builder);
Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* recv2 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O2_recv"));
Node* g = Binary(e, ops::NodeOut(recv2, 1),
b2.opts().WithName("G").WithControlInputs({recv2, e}));
Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H"));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O2_send"));
Node* send2 =
SendFromHost("host_compute_channel_F1_O2", {h},
b2.opts().WithName("outside_compilation_F1_O2_send"));
Node* s = NoOp(b2.opts()
.WithName("F1_sequencer")
@ -758,8 +1037,8 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = InputShaped(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1"));
@ -791,6 +1070,24 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
shape.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1),
shape.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape_graph;
TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
}
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"},
{"f_0_retval:float", "d_0_retval:float"}, {},
@ -799,19 +1096,18 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"C:o:0", "outside_compilation_O1_recv:output:0"},
{"C:o:0", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"C:o:0", "D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT, DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", shape_string_expected},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"D"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}});
@ -822,16 +1118,16 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
{{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}},
{{"I"},
"BinaryTest",
{"f_0_arg", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"G:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F2_O1"},
{"shape_inference_graph", ""},
{"shapes",
gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
},
{{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}});
@ -839,15 +1135,15 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = InputShaped(b2.opts().WithName("B"));
Node* recv1 =
RecvAtHost({DT_FLOAT, DT_FLOAT},
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1),
b2.opts().WithName("E").WithControlInputs({recv1, b}));
Node* send1 = SendFromHost({e}, {DT_FLOAT},
Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
@ -857,12 +1153,14 @@ TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) {
Node* s1 = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1}));
Node* recv2 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_recv"));
Node* recv2 =
RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT},
b2.opts().WithName("outside_compilation_F2_O1_recv"));
Node* h = Binary(ops::NodeOut(call1, 1), recv2,
b2.opts().WithName("H").WithControlInput(s1));
Node* send2 = SendFromHost(
{h}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F2_O1_send"));
Node* send2 =
SendFromHost("host_compute_channel_F2_O1", {h},
b2.opts().WithName("outside_compilation_F2_O1_send"));
NodeBuilder node_builder2("F2", "F2", lib_def.get());
node_builder2.Input(e).Input(call1);
@ -888,7 +1186,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
@ -908,6 +1206,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
@ -915,11 +1216,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"Tinputs", gtl::ArraySlice<DataType>({})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}}},
},
{{"f_0_retval", "F:o:0"}});
@ -927,12 +1233,13 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) {
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* e = Unary(a, b2.opts().WithName("E"));
Node* send1 = SendFromHost(
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
Node* send1 =
SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@ -954,7 +1261,7 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = Input(b1.opts().WithName("A"));
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
Node* d =
@ -975,6 +1282,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
TensorShapeProto shape_proto_expected;
shape_proto_expected.add_dim()->set_size(2);
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {},
{
@ -982,17 +1292,17 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"},
"BinaryTest",
{"D:o:0", "outside_compilation_O1_recv:output:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{},
{{"dtypes", gtl::ArraySlice<DataType>({})}},
{{"Tinputs", gtl::ArraySlice<DataType>({})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes",
gtl::ArraySlice<TensorShapeProto>({shape_proto_expected})}},
{"D"}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}},
{"outside_compilation_O1_send"}},
},
{{"f_0_retval", "F:o:0"}});
@ -1000,14 +1310,16 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) {
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = Input(b2.opts().WithName("A"));
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 =
RecvAtHost({}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
RecvAtHost("host_compute_channel_F1_O1", {},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1));
Node* send1 = SendFromHost(
{e}, {DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_send"));
Node* send1 =
SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts().WithName("outside_compilation_F1_O1_send"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
Node* call1 = b2.opts().FinalizeBuilder(&node_builder1);
@ -1055,10 +1367,14 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
},
{{"f_0_retval", "F:o:0"}});
@ -1069,8 +1385,9 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) {
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* recv1 =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
NodeBuilder node_builder1("F1", "F1", lib_def.get());
node_builder1.Input(a).Input(b);
@ -1118,16 +1435,19 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
{
{{"C"}, "UnaryTest", {"a_0_arg"}},
{{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}},
{{"F"}, "UnaryTest", {"D:o:0"}, {}, {"outside_compilation_O1_recv"}},
{{"outside_compilation_O1_send"},
"_XlaSendToHost",
{{"F"},
"UnaryTest",
{"D:o:0"},
{{"dtypes", gtl::ArraySlice<DataType>({DT_FLOAT})}}},
{{"outside_compilation_O1_recv"},
"_XlaRecvFromHost",
{},
{{"dtypes", gtl::ArraySlice<DataType>({})}},
{"outside_compilation_O1_send"}},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"D:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", ""},
{"shapes", gtl::ArraySlice<TensorShapeProto>({})}}},
},
{{"f_0_retval", "F:o:0"}});
@ -1138,10 +1458,11 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) {
Node* a = Input(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* recv1 = RecvAtHost(
{DT_FLOAT}, b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* recv1 =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = Unary(recv1, b2.opts().WithName("E"));
Node* send1 = SendFromHost({}, {},
Node* send1 = SendFromHost("host_compute_channel_F1_O1", {},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
@ -1215,5 +1536,110 @@ TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) {
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
// Test for shape inference of outside compilation.
TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) {
FunctionDefLibrary library;
GraphDef graphdef;
{
*library.add_function() = test::function::XTimesTwo();
GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
Node* a = InputShaped(b1.opts().WithName("A"));
Node* b = Input(b1.opts().WithName("B"));
// Give nodes 'c' and 'd' names that collide after lowercasing.
Node* c = Unary(a, b1.opts().WithName("C"));
Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr(
"_encapsulate", "F1"));
Node* e = BinaryUnknownShape(c, d,
b1.opts()
.WithName("E")
.WithControlInputs({b, d})
.WithAttr("_encapsulate", "F1")
.WithAttr("_outside", "O1"));
Node* f = Binary(c, e,
b1.opts().WithName("F").WithControlInput(e).WithAttr(
"_encapsulate", "F1"));
Binary(a, f, b1.opts().WithName("G").WithControlInput(e));
TF_EXPECT_OK(b1.ToGraphDef(&graphdef));
}
TF_EXPECT_OK(Encapsulate(&graphdef, &library));
FunctionDefLibrary library_expected;
GraphDef graphdef_expected;
string shape_string_expected;
{
GraphDefBuilder shape(GraphDefBuilder::kFailImmediately);
Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0"));
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
shape.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E"));
SendFromHost("host_compute_channel_F1_O1", {e},
shape.opts().WithName("outside_compilation_F1_O1_send"));
GraphDef shape_graph;
TF_EXPECT_OK(shape.ToGraphDef(&shape_graph));
EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected));
}
*library_expected.add_function() = test::function::XTimesTwo();
*library_expected.add_function() = FunctionDefHelper::Create(
"F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {},
{
{{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}},
{{"F"},
"BinaryTest",
{"c_0_arg", "outside_compilation_O1_host_compute:outputs:0"},
{},
{"outside_compilation_O1_host_compute"}},
{{"outside_compilation_O1_host_compute"},
"_XlaHostCompute",
{"c:o:0"},
{{"Tinputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"Toutputs", gtl::ArraySlice<DataType>({DT_FLOAT})},
{"key", "host_compute_channel_F1_O1"},
{"shape_inference_graph", shape_string_expected},
{"shapes", gtl::ArraySlice<DataType>({})}},
{"c"}},
},
{{"f_0_retval", "F:o:0"}});
{
std::unique_ptr<FunctionLibraryDefinition> lib_def(
new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
Node* a = InputShaped(b2.opts().WithName("A"));
Node* b = Input(b2.opts().WithName("B"));
Node* c = Unary(a, b2.opts().WithName("C"));
NodeBuilder node_builder("F1", "F1", lib_def.get());
node_builder.Input(b).Input(c);
Node* call =
b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder);
Node* recv =
RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT},
b2.opts().WithName("outside_compilation_F1_O1_recv"));
Node* e = BinaryUnknownShape(
c, ops::NodeOut(recv, 0),
b2.opts().WithName("E").WithControlInputs({recv, b}));
Node* send = SendFromHost("host_compute_channel_F1_O1", {e},
b2.opts()
.WithName("outside_compilation_F1_O1_send")
.WithControlInput(e));
Node* s = NoOp(
b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send}));
Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e}));
TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected));
}
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
}
} // namespace
} // namespace tensorflow

View File

@ -45,7 +45,7 @@ namespace tensorflow {
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
public:
XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context);
XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context);
~XlaAllocator() override;
xla::StatusOr<gpu::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
bool retry_on_failure) override;
@ -79,7 +79,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator {
std::unordered_map<void*, Tensor> tensors_;
};
XlaAllocator::XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context)
XlaAllocator::XlaAllocator(const gpu::Platform* platform,
OpKernelContext* op_context)
: xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
XlaAllocator::~XlaAllocator() = default;
@ -248,12 +249,16 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
// Builds an XLA allocator for the device.
XlaAllocator xla_allocator(client->platform(), ctx);
XlaCompiler::Options options;
options.client = client;
options.device_type = &cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
options.device_allocator = &xla_allocator;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
@ -264,9 +269,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
VLOG(1) << "Executing XLA Computation...";
// Builds an XLA allocator for the device.
XlaAllocator xla_allocator(client->platform(), ctx);
std::unique_ptr<xla::ShapedBuffer> output;
// Build xla::ShapedBuffers that point directly to the Tensor buffers.
std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
@ -374,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
TensorShape write_shape;
OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
@ -397,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
// Looks up the owning Tensor by buffer address.
OP_REQUIRES_OK(
ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape,
ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
variable->tensor()));
++output_num;
}

View File

@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args,
XlaCompiler::Argument& arg = (*args)[input_num];
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input.dtype();
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
arg.shape = input.shape();
arg.constant_value = input;
++input_num;
}
@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args,
arg.constant_value = input;
}
arg.type = input.dtype();
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
arg.shape = input.shape();
++input_num;
}
@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args,
if (variable_args[variable_id].present) {
const Tensor& value = variable_args[variable_id].value;
arg.type = value.dtype();
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape));
arg.shape = value.shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args,
// uninitialized variables.
arg.initialized = false;
arg.type = DT_INVALID;
arg.shape = xla::Shape();
arg.shape = TensorShape();
}
++input_num;
}
@ -223,6 +220,7 @@ Status XlaCompilationCache::BuildExecutable(
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client_->default_device_ordinal());
build_options.set_result_layout(result.xla_output_shape);
build_options.set_device_allocator(options.device_allocator);
auto compile_result =
client_->Compile(*result.computation, argument_layouts, build_options);

View File

@ -144,6 +144,21 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "matrix_triangular_solve_op_test",
size = "small",
srcs = ["matrix_triangular_solve_op_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)
tf_xla_py_test(
name = "clustering_test",
size = "small",
@ -240,6 +255,18 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "extract_image_patches_op_test",
size = "small",
srcs = ["extract_image_patches_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "fft_test",
size = "medium",
@ -326,6 +353,19 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "matrix_band_part_test",
size = "medium",
srcs = ["matrix_band_part_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "momentum_test",
size = "small",
@ -437,6 +477,18 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "reverse_sequence_op_test",
size = "small",
srcs = ["reverse_sequence_op_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "rmsprop_test",
size = "small",

View File

@ -1181,6 +1181,50 @@ class BinaryOpsTest(XLATestCase):
np.array([4, 5, 6], dtype=np.int32),
expected=None)
def testMatrixSetDiag(self):
for dtype in self.numeric_types:
# Square
self._testBinary(
array_ops.matrix_set_diag,
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
dtype=dtype),
np.array([1.0, 2.0, 3.0], dtype=dtype),
expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
dtype=dtype))
self._testBinary(
array_ops.matrix_set_diag,
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
dtype=dtype),
np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
expected=np.array(
[[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
dtype=dtype))
# Rectangular
self._testBinary(
array_ops.matrix_set_diag,
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
np.array([3.0, 4.0], dtype=dtype),
expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
self._testBinary(
array_ops.matrix_set_diag,
np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
np.array([3.0, 4.0], dtype=dtype),
expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
self._testBinary(
array_ops.matrix_set_diag,
np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
[[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
np.array([[-1.0, -2.0], [-4.0, -5.0]],
dtype=dtype),
expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
[[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
dtype=dtype))
if __name__ == "__main__":
googletest.main()

View File

@ -0,0 +1,134 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for ExtractImagePatches op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ExtractImagePatches(XLATestCase):
"""Functional tests for ExtractImagePatches op."""
def _VerifyValues(self, image, ksizes, strides, rates, padding, patches):
"""Tests input-output pairs for the ExtractImagePatches op.
Args:
image: Input tensor with shape: [batch, in_rows, in_cols, depth].
ksizes: Patch size specified as: [ksize_rows, ksize_cols].
strides: Output strides, specified as [stride_rows, stride_cols].
rates: Atrous rates, specified as [rate_rows, rate_cols].
padding: Padding type.
patches: Expected output.
"""
ksizes = [1] + ksizes + [1]
strides = [1] + strides + [1]
rates = [1] + rates + [1]
with self.test_session():
image_placeholder = array_ops.placeholder(dtypes.float32)
with self.test_scope():
out_tensor = array_ops.extract_image_patches(
image_placeholder,
ksizes=ksizes,
strides=strides,
rates=rates,
padding=padding,
name="im2col")
feed_dict = {image_placeholder: image}
self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict))
def testKsize1x1Stride1x1Rate1x1(self):
"""Verifies that for 1x1 kernel the output equals the input."""
# [2, 3, 4, 5]
image = np.reshape(range(120), [2, 3, 4, 5])
# [2, 3, 4, 5]
patches = np.reshape(range(120), [2, 3, 4, 5])
for padding in ["VALID", "SAME"]:
self._VerifyValues(
image,
ksizes=[1, 1],
strides=[1, 1],
rates=[1, 1],
padding=padding,
patches=patches)
def testKsize1x1Stride2x3Rate1x1(self):
"""Test for 1x1 kernel and strides."""
# [2, 4, 5, 3]
image = np.reshape(range(120), [2, 4, 5, 3])
# [2, 2, 2, 3]
patches = image[:, ::2, ::3, :]
for padding in ["VALID", "SAME"]:
self._VerifyValues(
image,
ksizes=[1, 1],
strides=[2, 3],
rates=[1, 1],
padding=padding,
patches=patches)
def testKsize2x2Stride1x1Rate1x1Valid(self):
"""Test for 2x2 kernel with VALID padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 1, 1, 4]
patches = [[[[1, 2, 3, 4]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="VALID",
patches=patches)
def testKsize2x2Stride1x1Rate1x1Same(self):
"""Test for 2x2 kernel with SAME padding."""
# [1, 2, 2, 1]
image = [[[[1], [2]], [[3], [4]]]]
# [1, 2, 2, 4]
patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[1, 1],
padding="SAME",
patches=patches)
def testKsize2x2Stride1x1Rate2x2Valid(self):
"""Test for 2x2 kernel with 2x2 dilation."""
# [1, 2, 2, 1]
image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32)
# [1, 2, 2, 4]
patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]],
[[4, 6, 12, 14], [5, 7, 13, 15]]]]
self._VerifyValues(
image,
ksizes=[2, 2],
strides=[1, 1],
rates=[2, 2],
padding="VALID",
patches=patches)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,64 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class MatrixBandPartTest(XLATestCase):
def _testMatrixBandPart(self, dtype, shape):
with self.test_session():
batch_shape = shape[:-2]
mat = np.ones(shape).astype(dtype)
batch_mat = np.tile(mat, batch_shape + [1, 1])
for lower in -1, 0, 1, shape[-2] - 1:
for upper in -1, 0, 1, shape[-1] - 1:
band_np = mat
if lower >= 0:
band_np = np.triu(band_np, -lower)
if upper >= 0:
band_np = np.tril(band_np, upper)
if batch_shape:
band_np = np.tile(band_np, batch_shape + [1, 1])
placeholder = array_ops.placeholder(dtype)
with self.test_scope():
band = array_ops.matrix_band_part(
placeholder,
constant_op.constant(lower, dtype=dtypes.int32),
constant_op.constant(upper, dtype=dtypes.int32))
feed_dict = {placeholder: batch_mat}
self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict))
def testMatrixBandPart(self):
for dtype in self.float_types:
for batch_shape in [[], [2,], [1, 3, 2]]:
for rows in 1, 2, 7:
for cols in 1, 2, 7:
self._testMatrixBandPart(dtype, batch_shape + [rows, cols])
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,130 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.MatrixTriangularSolve."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import itertools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
def MakePlaceholder(x):
return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape)
class MatrixTriangularSolveOpTest(XLATestCase):
def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca,
placeholder_b, a, clean_a, b, verification,
atol):
feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b}
verification_np = sess.run(verification, feed_dict)
self.assertAllClose(b, verification_np, atol=atol)
def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol):
clean_a = np.tril(a) if lower else np.triu(a)
with self.test_session() as sess:
placeholder_a = MakePlaceholder(a)
placeholder_ca = MakePlaceholder(clean_a)
placeholder_b = MakePlaceholder(b)
with self.test_scope():
x = linalg_ops.matrix_triangular_solve(
placeholder_a, placeholder_b, lower=lower, adjoint=adjoint)
verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint)
self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca,
placeholder_b, a, clean_a, b,
verification, atol)
def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4):
transp = lambda x: np.swapaxes(x, -1, -2)
for lower, adjoint in itertools.product([True, False], repeat=2):
self._VerifyTriangularSolve(
a if lower else transp(a), b, lower, adjoint, atol)
def testBasic(self):
rng = np.random.RandomState(0)
a = np.tril(rng.randn(5, 5))
b = rng.randn(5, 7)
for dtype in self.float_types:
self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
def testBasicNotActuallyTriangular(self):
rng = np.random.RandomState(0)
a = rng.randn(5, 5) # the `a` matrix is not lower-triangular
b = rng.randn(5, 7)
for dtype in self.float_types:
self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
def testBasicComplexDtypes(self):
rng = np.random.RandomState(0)
a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j)
b = rng.randn(5, 7) + rng.randn(5, 7) * 1j
for dtype in self.complex_types:
self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype))
def testBatch(self):
rng = np.random.RandomState(0)
shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)),
((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))]
tuples = itertools.product(self.float_types, shapes)
for dtype, (a_shape, b_shape) in tuples:
n = a_shape[-1]
a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n)
b = rng.randn(*b_shape)
self._VerifyTriangularSolveCombo(
a.astype(dtype), b.astype(dtype), atol=1e-3)
def testLarge(self):
n = 1024
rng = np.random.RandomState(0)
a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n)
b = rng.randn(n, n)
self._VerifyTriangularSolve(
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
def testNonSquareCoefficientMatrix(self):
rng = np.random.RandomState(0)
for dtype in self.float_types:
a = rng.randn(3, 4).astype(dtype)
b = rng.randn(4, 4).astype(dtype)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
def testWrongDimensions(self):
randn = np.random.RandomState(0).randn
for dtype in self.float_types:
lhs = constant_op.constant(randn(3, 3), dtype=dtype)
rhs = constant_op.constant(randn(4, 3), dtype=dtype)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(lhs, rhs)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(lhs, rhs)
if __name__ == "__main__":
test.main()

View File

@ -0,0 +1,93 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.reverse_sequence_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ReverseSequenceTest(XLATestCase):
def _testReverseSequence(self,
x,
batch_axis,
seq_axis,
seq_lengths,
truth,
expected_err_re=None):
with self.test_session():
p = array_ops.placeholder(dtypes.as_dtype(x.dtype))
lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype))
with self.test_scope():
ans = array_ops.reverse_sequence(
p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths)
if expected_err_re is None:
tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths})
self.assertAllClose(tf_ans, truth, atol=1e-10)
else:
with self.assertRaisesOpError(expected_err_re):
ans.eval(feed_dict={p: x, lengths: seq_lengths})
def testSimple(self):
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32)
self._testReverseSequence(
x,
batch_axis=0,
seq_axis=1,
seq_lengths=np.array([1, 3, 2], np.int32),
truth=expected)
def _testBasic(self, dtype, len_dtype):
x = np.asarray(
[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]],
[[17, 18, 19, 20], [21, 22, 23, 24]]],
dtype=dtype)
x = x.reshape(3, 2, 4, 1, 1)
x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2
# reverse dim 2 up to (0:3, none, 0:4) along dim=0
seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype)
truth_orig = np.asarray(
[
[[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3
[[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none
[[20, 19, 18, 17], [24, 23, 22, 21]]
], # reverse 0:4 (all)
dtype=dtype)
truth_orig = truth_orig.reshape(3, 2, 4, 1, 1)
truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2
seq_axis = 0 # permute seq_axis and batch_axis (originally 2 and 0, resp.)
batch_axis = 2
self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth)
def testSeqLength(self):
for dtype in self.all_types:
for seq_dtype in self.int_types:
self._testBasic(dtype, seq_dtype)
if __name__ == "__main__":
test.main()

View File

@ -154,6 +154,21 @@ class UnaryOpsTest(XLATestCase):
def testFloatOps(self):
for dtype in self.float_types:
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
math_ops.acos,
x.astype(dtype),
expected=np.arccos(x).astype(dtype))
self._assertOpOutputMatchesExpected(
math_ops.asin,
x.astype(dtype),
expected=np.arcsin(x).astype(dtype))
x = np.arange(-3, 3).reshape(1, 3, 2)
self._assertOpOutputMatchesExpected(
math_ops.atan,
x.astype(dtype),
expected=np.arctan(x).astype(dtype))
self._assertOpOutputMatchesExpected(
math_ops.acosh,
np.array([1, 2, 3, 4], dtype=dtype),

View File

@ -427,16 +427,36 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
// identity nodes are values used by the loop body or condition.
// The Identity node may have the wrong device so copy the device from
// one of its outputs instead.
std::deque<const Edge*> possible_exit;
for (const Edge* edge : arg.switch_node->out_edges()) {
if (edge->src_output() == 0 && IsExit(edge->dst())) {
if (edge->src_output() == 0) {
possible_exit.push_back(edge);
}
if (IsIdentity(edge->dst())) {
TF_RETURN_IF_ERROR(
SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
}
}
// TODO(b/67425339): Allow general graph between switch and exit.
while (!possible_exit.empty()) {
const Edge* edge = possible_exit.front();
possible_exit.pop_front();
if (IsExit(edge->dst())) {
if (arg.exit != nullptr) {
return errors::InvalidArgument("Duplicate Exit successors to ",
arg.switch_node->name());
}
arg.exit = edge->dst();
} else if (StringPiece(edge->dst()->type_string()) == "Identity") {
TF_RETURN_IF_ERROR(
SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true));
} else {
if (!IsIdentity(edge->dst())) {
return errors::Unimplemented("General graph between switch (",
arg.switch_node->name(),
") and exit node of frame ",
frame->name, " not supported yet.");
}
for (const Edge* out : edge->dst()->out_edges()) {
possible_exit.push_back(out);
}
}
}
}

View File

@ -6,6 +6,9 @@ Operator | Type Constraint
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`AdjustContrastv2` |
`AdjustHue` |
`AdjustSaturation` |
`All` | `Tidx={int32,int64}`
`Angle` | `Tout={double,float}`<br>`T={complex64}`
`Any` | `Tidx={int32,int64}`
@ -34,7 +37,7 @@ Operator | Type Constraint
`BroadcastGradientArgs` | `T={int32,int64}`
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Ceil` | `T={double,float}`
`Cholesky` | `T={complex64,double,float}`
`Cholesky` | `T={double,float}`
`Complex` | `Tout={complex64}`<br>`T={double,float}`
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -68,7 +71,11 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
`FloorMod` | `T={double,float,int32,int64}`
@ -80,6 +87,13 @@ Operator | Type Constraint
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
`HSVToRGB` | `T={double,float}`
`IFFT` |
`IFFT2D` |
`IFFT3D` |
`IRFFT` |
`IRFFT2D` |
`IRFFT3D` |
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Imag` | `Tout={double,float}`<br>`T={complex64}`
@ -105,11 +119,14 @@ Operator | Type Constraint
`MatMul` | `T={complex64,double,float}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixTriangularSolve` | `T={complex64,double,float}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@ -131,6 +148,10 @@ Operator | Type Constraint
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`QuantizeAndDequantizeV2` | `T={double,float}`
`RFFT` |
`RFFT2D` |
`RFFT3D` |
`RGBToHSV` | `T={double,float}`
`RandomStandardNormal` | `dtype={float}`
`RandomUniform` | `T={int32,int64}`<br>`dtype={double,float}`
`RandomUniformInt` | `T={int32,int64}`<br>`Tout={int32,int64}`
@ -146,6 +167,8 @@ Operator | Type Constraint
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ResizeBilinear` | `T={double,float,int32,int64}`
`ResizeBilinearGrad` | `T={double,float}`
`ResourceApplyAdagrad` | `T={double,float}`
`ResourceApplyAdam` | `T={double,float}`
`ResourceApplyFtrl` | `T={double,float}`
@ -156,6 +179,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`

View File

@ -6,6 +6,9 @@ Operator | Type Constraint
`Acosh` | `T={complex64,double,float}`
`Add` | `T={complex64,double,float,int32,int64}`
`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}`
`AdjustContrastv2` |
`AdjustHue` |
`AdjustSaturation` |
`All` | `Tidx={int32,int64}`
`Angle` | `Tout={double,float}`<br>`T={complex64}`
`Any` | `Tidx={int32,int64}`
@ -34,7 +37,7 @@ Operator | Type Constraint
`BroadcastGradientArgs` | `T={int32,int64}`
`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`<br>`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Ceil` | `T={double,float}`
`Cholesky` | `T={complex64,double,float}`
`Cholesky` | `T={double,float}`
`Complex` | `Tout={complex64}`<br>`T={double,float}`
`ComplexAbs` | `Tout={double,float}`<br>`T={complex64}`
`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -68,7 +71,11 @@ Operator | Type Constraint
`Exp` | `T={complex64,double,float}`
`ExpandDims` | `Tdim={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Expm1` | `T={complex64,double,float}`
`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}`
`FFT` |
`FFT2D` |
`FFT3D` |
`Fill` | `index_type={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Floor` | `T={double,float}`
`FloorDiv` | `T={complex64,double,float,int32,int64}`
`FloorMod` | `T={double,float,int32,int64}`
@ -80,6 +87,13 @@ Operator | Type Constraint
`GatherV2` | `Taxis={int32,int64}`<br>`Tindices={int32,int64}`<br>`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Greater` | `T={double,float,int32,int64,uint32,uint64}`
`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}`
`HSVToRGB` | `T={double,float}`
`IFFT` |
`IFFT2D` |
`IFFT3D` |
`IRFFT` |
`IRFFT2D` |
`IRFFT3D` |
`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Imag` | `Tout={double,float}`<br>`T={complex64}`
@ -105,11 +119,14 @@ Operator | Type Constraint
`MatMul` | `T={complex64,double,float}`
`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`MatrixTriangularSolve` | `T={complex64,double,float}`
`Max` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`MaxPool` | `T={double,float,int32,int64}`
`MaxPool3D` | `T={float}`
`MaxPool3DGrad` | `TInput={float}`<br>`T={float}`
`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}`
`MaxPoolV2` | `T={double,float,int32,int64}`
`Maximum` | `T={double,float,int32,int64}`
`Mean` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`Min` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
@ -131,6 +148,10 @@ Operator | Type Constraint
`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Prod` | `Tidx={int32,int64}`<br>`T={complex64,double,float,int32,int64,uint32,uint64}`
`QuantizeAndDequantizeV2` | `T={double,float}`
`RFFT` |
`RFFT2D` |
`RFFT3D` |
`RGBToHSV` | `T={double,float}`
`Range` | `Tidx={double,float,int32,int64}`
`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}`
@ -143,6 +164,8 @@ Operator | Type Constraint
`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}`
`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}`
`Reshape` | `Tshape={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ResizeBilinear` | `T={double,float,int32,int64}`
`ResizeBilinearGrad` | `T={double,float}`
`ResourceApplyAdagrad` | `T={double,float}`
`ResourceApplyAdam` | `T={double,float}`
`ResourceApplyFtrl` | `T={double,float}`
@ -153,6 +176,7 @@ Operator | Type Constraint
`ResourceGather` | `Tindices={int32,int64}`<br>`dtype={complex64,double,float,int32,int64,uint32,uint64}`
`ResourceStridedSliceAssign` | `Index={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Reverse` | `T={bool,complex64,double,float,int32,int64}`
`ReverseSequence` | `Tlen={int32,int64}`<br>`T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`<br>`Tidx={int32,int64}`
`RightShift` | `T={int32,int64,uint32,uint64}`
`Rint` | `T={double,float}`

View File

@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
for (int i = 0; i < args->size(); ++i) {
XlaCompiler::Argument& arg = (*args)[i];
arg.type = ctx->input_type(i);
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
arg.shape = ctx->InputShape(i);
if (arg.type == DT_RESOURCE) {
return errors::InvalidArgument(

View File

@ -31,6 +31,7 @@ tf_kernel_library(
"diag_op.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
"extract_image_patches_op.cc",
"fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
@ -43,6 +44,9 @@ tf_kernel_library(
"l2loss_op.cc",
"lrn_ops.cc",
"matmul_op.cc",
"matrix_band_part_op.cc",
"matrix_set_diag_op.cc",
"matrix_triangular_solve_op.cc",
"mirror_pad_op.cc",
"no_op.cc",
"one_hot_op.cc",
@ -58,6 +62,7 @@ tf_kernel_library(
"reshape_op.cc",
"retval_op.cc",
"reverse_op.cc",
"reverse_sequence_op.cc",
"scan_ops.cc",
"segment_reduction_ops.cc",
"select_op.cc",
@ -92,6 +97,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:triangular_solve",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
"//tensorflow/compiler/xla:array4d",

View File

@ -28,8 +28,9 @@ class BatchMatMulOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
auto result =
BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_);
auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1),
/*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_,
/*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_);
OP_REQUIRES_OK(ctx, result.status());
ctx->SetOutput(0, result.ValueOrDie());
}

View File

@ -33,7 +33,7 @@ class CholeskyOp : public XlaOpKernel {
}
};
REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp);
REGISTER_XLA_OP(Name("Cholesky").TypeConstraint("T", kFloatTypes), CholeskyOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,169 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace {
class ExtractImagePatchesOp : public XlaOpKernel {
public:
explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
}
void Compile(XlaOpKernelContext* ctx) override {
const TensorFormat data_format = FORMAT_NHWC;
const int num_dims = ksizes_.size();
OP_REQUIRES(
ctx, num_dims >= 3,
errors::InvalidArgument("Kernel size must have at least 3 dimensions"));
const int num_spatial_dims = num_dims - 2;
OP_REQUIRES(ctx, strides_.size() == num_dims,
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims, " dimensions"));
OP_REQUIRES(ctx, dilations_.size() == num_dims,
errors::InvalidArgument("Dilations field must "
"specify ",
num_dims, " dimensions"));
int batch_dim = GetTensorBatchDimIndex(num_dims, data_format);
int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format);
OP_REQUIRES(
ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
errors::Unimplemented("Current implementation does not yet support "
"kernel sizes > 1 in the batch and depth "
"dimensions."));
OP_REQUIRES(
ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
errors::Unimplemented("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES(
ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
errors::Unimplemented("Current implementation does not support "
"dilations in the batch and depth dimensions."));
for (int i = 0; i < num_spatial_dims; ++i) {
int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
OP_REQUIRES(
ctx, ksizes_[input_dim] >= 0,
errors::Unimplemented("Kernel size values must be non-negative; ", i,
"th spatial dimension had dilation ",
dilations_[input_dim]));
OP_REQUIRES(ctx, strides_[input_dim] >= 1,
errors::Unimplemented("Stride values must be positive; ", i,
"th spatial dimension had dilation ",
dilations_[input_dim]));
OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
errors::Unimplemented("Dilation values must be positive; ", i,
"th spatial dimension had dilation ",
dilations_[input_dim]));
}
xla::PrimitiveType type;
OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type));
const TensorShape input_shape = ctx->InputShape(0);
OP_REQUIRES(
ctx, input_shape.dims() == num_dims,
errors::InvalidArgument("input must be ", num_dims, "-dimensional",
input_shape.DebugString()));
const int64 depth = input_shape.dim_size(feature_dim);
xla::ComputationBuilder* builder = ctx->builder();
// The following code is equivalent to:
// eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD])
int64 kernel_size = 1;
std::vector<int64> lhs_shape(num_dims, 1);
for (int i = 0; i < num_spatial_dims; ++i) {
int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
lhs_shape[i] = ksizes_[input_dim];
kernel_size *= ksizes_[input_dim];
}
lhs_shape[num_spatial_dims] = depth;
lhs_shape[num_spatial_dims + 1] = 1;
// Builds an identity matrix as a broadcast equality of iotas.
// iota = np.arange(np.prod(ksize), depth)
// filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32)
xla::ComputationDataHandle iota;
TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
kernel_size * depth, &iota));
auto lhs = builder->Reshape(iota, lhs_shape);
auto filter = builder->ConvertElementType(
builder->Eq(lhs, iota, {num_spatial_dims + 1}), type);
xla::ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides(num_spatial_dims);
std::vector<int64> lhs_dilation(num_spatial_dims, 1);
std::vector<int64> rhs_dilation(num_spatial_dims);
std::vector<std::pair<int64, int64>> padding(num_spatial_dims);
dims.set_input_batch_dimension(batch_dim);
dims.set_output_batch_dimension(batch_dim);
dims.set_input_feature_dimension(feature_dim);
dims.set_output_feature_dimension(feature_dim);
dims.set_kernel_input_feature_dimension(num_spatial_dims);
dims.set_kernel_output_feature_dimension(num_spatial_dims + 1);
for (int i = 0; i < num_spatial_dims; ++i) {
const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i);
dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
dims.add_output_spatial_dimensions(dim);
window_strides[i] = strides_.at(dim);
rhs_dilation[i] = dilations_.at(dim);
int64 unused_output_size;
OP_REQUIRES_OK(
ctx, GetWindowedOutputSizeVerboseV2(
input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i],
window_strides[i], padding_, &unused_output_size,
&padding[i].first, &padding[i].second));
}
xla::ComputationDataHandle conv =
builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides,
padding, lhs_dilation, rhs_dilation, dims);
ctx->SetOutput(0, conv);
}
protected:
std::vector<int32> ksizes_;
std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
private:
TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp);
};
REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,98 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
class MatrixBandPartOp : public XlaOpKernel {
public:
explicit MatrixBandPartOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
// Preliminary validation of sizes.
OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
errors::InvalidArgument(
"input must be at least 2-dim, received shape: ",
input_shape.DebugString()));
const TensorShape num_lower_in_shape = context->InputShape(1);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape),
errors::InvalidArgument("num_lower must be scalar, got shape ",
num_lower_in_shape.DebugString()));
const TensorShape num_upper_in_shape = context->InputShape(2);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape),
errors::InvalidArgument("num_upper must be scalar, got shape ",
num_upper_in_shape.DebugString()));
xla::ComputationBuilder* builder = context->builder();
xla::ComputationDataHandle input = context->Input(0);
xla::ComputationDataHandle num_lower = context->Input(1);
xla::ComputationDataHandle num_upper = context->Input(2);
DataType input_type = context->input_type(0);
DataType index_type = context->input_type(1);
TensorShape batch_shape = input_shape;
batch_shape.RemoveLastDims(2);
const int64 m = input_shape.dim_size(input_shape.dims() - 2);
const int64 n = input_shape.dim_size(input_shape.dims() - 1);
// Compute 'offset', which is how many diagonals we are above/below the
// diagonal.
xla::ComputationDataHandle iota_m;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m));
xla::ComputationDataHandle iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n));
auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m,
/*broadcast_dimensions=*/{0});
// If num_lower or num_upper are negative, include all lower/upper
// diagonals.
auto zero_index = XlaHelpers::Zero(builder, index_type);
num_lower = builder->Select(
builder->Lt(num_lower, zero_index),
XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower);
num_upper = builder->Select(
builder->Lt(num_upper, zero_index),
XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper);
auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset),
builder->Le(offset, num_upper));
indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
auto zero_input = XlaHelpers::Zero(builder, input_type);
auto output = builder->Select(
indicator, input,
builder->Broadcast(zero_input, input_shape.dim_sizes()));
context->SetOutput(0, output);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp);
};
REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,93 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
class MatrixSetDiagOp : public XlaOpKernel {
public:
explicit MatrixSetDiagOp(OpKernelConstruction* context)
: XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const TensorShape diag_shape = context->InputShape(1);
const int rank = input_shape.dims();
// Preliminary validation of sizes.
OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape),
errors::InvalidArgument(
"input must be at least 2-dim, received shape: ",
input_shape.DebugString()));
// Check to make sure the last dimension of diag is equal to the smaller of
// the last two dimensions of input.
const int64 m = input_shape.dim_size(rank - 2);
const int64 n = input_shape.dim_size(rank - 1);
const int64 min_dim = std::min(m, n);
TensorShape batch_shape = input_shape;
batch_shape.RemoveLastDims(2);
TensorShape expected_diag_shape = batch_shape;
expected_diag_shape.AddDim(min_dim);
OP_REQUIRES(context, expected_diag_shape == diag_shape,
errors::InvalidArgument(
"must have diagonal.shape == input.shape[:-2] + "
"min(input.shape[-2:]), but received input shape: ",
input_shape.DebugString(),
" and diagonal shape: ", diag_shape.DebugString()));
xla::ComputationBuilder* builder = context->builder();
xla::ComputationDataHandle input = context->Input(0);
xla::ComputationDataHandle diag = context->Input(1);
auto zero = XlaHelpers::Zero(builder, context->input_type(0));
// Create an indicator tensor that is true only on the diagonal.
xla::ComputationDataHandle iota_m;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m));
xla::ComputationDataHandle iota_n;
OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n));
auto indicator = builder->Eq(iota_m,
builder->Broadcast(iota_n, {m}),
/*broadcast_dimensions=*/{0});
indicator = builder->Broadcast(indicator, batch_shape.dim_sizes());
// Broadcast diag up to the input shape. Use an implicit broadcast (Add)
// because we need to broadcast on the right.
std::vector<int64> diag_broadcast_dims(rank - 1);
std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0);
if (min_dim != m) {
diag_broadcast_dims.back() = rank - 1;
}
diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()),
/*broadcast_dimensions=*/diag_broadcast_dims);
auto output = builder->Select(indicator, diag, input);
context->SetOutput(0, output);
}
private:
TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
};
REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
} // namespace tensorflow

View File

@ -0,0 +1,50 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
namespace {
class MatrixTriangularSolveOp : public XlaOpKernel {
public:
explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_));
}
void Compile(XlaOpKernelContext* ctx) override {
auto result = TriangularSolve(
ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true,
/*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_);
if (!result.ok()) {
ctx->SetStatus(result.status());
return;
}
ctx->SetOutput(0, result.ValueOrDie());
}
private:
bool lower_;
bool adjoint_;
};
REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp);
} // namespace
} // namespace tensorflow

View File

@ -37,21 +37,23 @@ class PoolingOp : public XlaOpKernel {
public:
PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
: XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
std::vector<int32> ksize_int;
std::vector<int32> stride_int;
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
OP_REQUIRES(ctx, stride_int.size() == num_dims(),
errors::InvalidArgument("Sliding window stride field must "
"specify ",
num_dims(), " dimensions"));
for (int i = 0; i < num_dims(); ++i) {
ksize_.push_back(ksize_int[i]);
stride_.push_back(stride_int[i]);
if (ctx->num_inputs() == 1) {
std::vector<int32> ksize_int;
std::vector<int32> stride_int;
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
OP_REQUIRES(ctx, stride_int.size() == num_dims(),
errors::InvalidArgument("Sliding window stride field must "
"specify ",
num_dims(), " dimensions"));
for (int i = 0; i < num_dims(); ++i) {
ksize_.push_back(ksize_int[i]);
stride_.push_back(stride_int[i]);
}
}
Padding padding;
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
@ -77,6 +79,33 @@ class PoolingOp : public XlaOpKernel {
xla::ComputationDataHandle input = ctx->Input(0);
const TensorShape input_shape = ctx->InputShape(0);
std::vector<int64> ksize = ksize_;
std::vector<int64> stride = stride_;
if (ctx->num_inputs() != 1) {
const TensorShape ksize_shape = ctx->InputShape(1);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString()));
OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
ksize.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
const TensorShape stride_shape = ctx->InputShape(2);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString()));
OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
errors::InvalidArgument("Sliding window stride field must "
"specify ",
num_dims(), " dimensions"));
stride.clear();
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
}
OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
errors::InvalidArgument("Input to ", type_string(),
" operator must have ", num_dims(),
@ -84,8 +113,8 @@ class PoolingOp : public XlaOpKernel {
const DataType type = input_type(0);
xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_,
stride_, padding_);
input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
stride, padding_);
ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
}
@ -130,6 +159,10 @@ class MaxPool2DOp : public MaxPoolOp {
}
};
REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp);
REGISTER_XLA_OP(Name("MaxPoolV2")
.CompileTimeConstInput("ksize")
.CompileTimeConstInput("strides"),
MaxPool2DOp);
class MaxPool3DOp : public MaxPoolOp {
public:
@ -243,22 +276,44 @@ class MaxPoolGradOp : public XlaOpKernel {
public:
MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
: XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
OP_REQUIRES(ctx, ksize_.size() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
OP_REQUIRES(ctx, stride_.size() == num_dims(),
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims(), " dimensions"));
if (ctx->num_inputs() == 3) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
}
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
}
int num_dims() const { return num_spatial_dims_ + 2; }
void Compile(XlaOpKernelContext* ctx) override {
if (ctx->num_inputs() != 3) {
OP_REQUIRES(
ctx, ctx->num_inputs() == 5,
errors::InvalidArgument("Must supply ksize and stride arguments."));
const TensorShape ksize_shape = ctx->InputShape(3);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
errors::InvalidArgument("ksize must be a vector, not shape ",
ksize_shape.DebugString()));
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
const TensorShape stride_shape = ctx->InputShape(4);
// Validate input sizes.
OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
errors::InvalidArgument("stride must be a vector, not shape ",
stride_shape.DebugString()));
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
}
OP_REQUIRES(ctx, ksize_.size() == num_dims(),
errors::InvalidArgument("Sliding window ksize field must "
"specify ",
num_dims(), " dimensions"));
OP_REQUIRES(ctx, stride_.size() == num_dims(),
errors::InvalidArgument("Sliding window strides field must "
"specify ",
num_dims(), " dimensions"));
const TensorShape tensor_in_shape = ctx->InputShape(0);
const TensorShape tensor_out_shape = ctx->InputShape(1);
const TensorShape out_backprop_shape = ctx->InputShape(2);
@ -315,6 +370,10 @@ class MaxPool2DGradOp : public MaxPoolGradOp {
}
};
REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp);
REGISTER_XLA_OP(Name("MaxPoolGradV2")
.CompileTimeConstInput("ksize")
.CompileTimeConstInput("strides"),
MaxPool2DGradOp);
class MaxPool3DGradOp : public MaxPoolGradOp {
public:

View File

@ -0,0 +1,182 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace {
class ReverseSequenceOp : public XlaOpKernel {
public:
explicit ReverseSequenceOp(OpKernelConstruction* context)
: XlaOpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_));
OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_));
}
void Compile(XlaOpKernelContext* context) override {
const TensorShape input_shape = context->InputShape(0);
const TensorShape seq_lens_shape = context->InputShape(1);
OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape),
errors::InvalidArgument("seq_lens input must be 1-dim, not ",
seq_lens_shape.dims()));
OP_REQUIRES(context, batch_dim_ != seq_dim_,
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_));
OP_REQUIRES(
context, seq_dim_ < input_shape.dims(),
errors::InvalidArgument("seq_dim must be < input.dims()", "( ",
seq_dim_, " vs. ", input_shape.dims(), ")"));
OP_REQUIRES(
context, batch_dim_ < input_shape.dims(),
errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
batch_dim_, " vs. ", input_shape.dims(), ")"));
OP_REQUIRES(
context,
seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_,
"), ", "(", seq_lens_shape.num_elements(),
" vs. ", input_shape.dim_size(batch_dim_)));
xla::ComputationBuilder* builder = context->builder();
const auto input = context->Input(0);
const auto seq_lens = context->Input(1);
const int64 batch_size = input_shape.dim_size(batch_dim_);
const DataType input_type = context->input_type(0);
const DataType seq_lens_type = context->input_type(1);
const int64 max_seq_len = input_shape.dim_size(seq_dim_);
xla::Shape input_xla_shape;
OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape,
&input_xla_shape));
xla::Shape seq_lens_xla_shape;
OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape,
&seq_lens_xla_shape));
const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({
xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}),
seq_lens_xla_shape,
input_xla_shape,
});
// For each entry in the batch, reverse the sequence.
// TODO(b/65689298): generalize the Map() operator to non-scalar cases and
// use it here, instead of a While loop.
// Condition: lambda (i, _, _): i < batch_size
auto condition_builder =
builder->CreateSubBuilder("reverse_sequence_condition");
{
auto param = condition_builder->Parameter(0, tuple_shape, "param");
auto i = condition_builder->GetTupleElement(param, 0);
condition_builder->Lt(
i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type,
batch_size));
}
auto condition = condition_builder->Build();
OP_REQUIRES_OK(context, condition.status());
auto body_builder = builder->CreateSubBuilder("reverse_sequence_body");
{
auto param = body_builder->Parameter(0, tuple_shape, "param");
auto i = body_builder->GetTupleElement(param, 0);
auto seq_lens = body_builder->GetTupleElement(param, 1);
auto output = body_builder->GetTupleElement(param, 2);
// seq_len is the sequence length of the current batch element (rank 1)
auto seq_len = body_builder->DynamicSlice(
seq_lens, body_builder->Reshape(i, {1}), {1});
// Indices is the offset of the batch element in the input.
auto indices = body_builder->Broadcast(
XlaHelpers::Zero(body_builder.get(), seq_lens_type),
{input_shape.dims()});
indices = body_builder->DynamicUpdateSlice(
indices, body_builder->Reshape(i, {1}),
body_builder->Reshape(
XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
batch_dim_),
{1}));
// slice_indices is the offset of the start of the reversed sequence in
// the input.
auto slice_indices = body_builder->DynamicUpdateSlice(
indices,
body_builder->Sub(XlaHelpers::IntegerLiteral(
body_builder.get(), seq_lens_type, max_seq_len),
seq_len),
body_builder->Reshape(
XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type,
seq_dim_),
{1}));
// Slice out the reversed sequence. The slice will overflow the end of the
// sequence, and the contents of the overflow are implementation-defined.
// However, we will mask off these elements and replace them with elements
// from the original input so their values do not matter.
TensorShape slice_shape = input_shape;
slice_shape.set_dim(batch_dim_, 1);
auto slice = body_builder->DynamicSlice(output, slice_indices,
slice_shape.dim_sizes());
// Shift the reversed sequence to the left.
output = body_builder->DynamicUpdateSlice(output, slice, indices);
body_builder->Tuple(
{body_builder->Add(
i, XlaHelpers::One(body_builder.get(), seq_lens_type)),
seq_lens, output});
}
auto body = body_builder->Build();
OP_REQUIRES_OK(context, body.status());
auto loop_output = builder->While(
condition.ValueOrDie(), body.ValueOrDie(),
builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens,
builder->Rev(input, {seq_dim_})}));
auto output = builder->GetTupleElement(loop_output, 2);
// Mask out elements after the sequence length.
xla::ComputationDataHandle iota;
OP_REQUIRES_OK(
context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota));
std::vector<int64> dims(input_shape.dims(), 1);
dims[batch_dim_] = batch_size;
auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_});
// Broadcast the mask up to the input shape.
mask =
builder->Or(mask, builder->Broadcast(builder->ConstantR0<bool>(false),
input_shape.dim_sizes()));
output = builder->Select(mask, output, input);
context->SetOutput(0, output);
}
private:
int32 batch_dim_;
int32 seq_dim_;
};
REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp);
} // namespace
} // namespace tensorflow

View File

@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
// Stack has not been initialized.
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
TF_RETURN_IF_ERROR(resource->SetValue(
dtype,
builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
builder->ConstantR0<int32>(0)})));
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the expected shape matches the actual shape.
TensorShape actual_shape;
@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel {
string name = strings::StrCat("Stack: ", stack_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
value, &resource));
resource->set_tensor_array_size(size);
TensorShape(), value, /*tensor_array_size=*/size,
/*tensor_array_gradients=*/{}, &resource));
ctx->SetResourceOutput(0, resource);
}
@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel {
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
OP_REQUIRES_OK(
ctx,
resource->SetValue(
dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
b->Add(index, b->ConstantR0<int32>(1))})));
OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
{b->DynamicUpdateSlice(ta, update, start_indices),
b->Add(index, b->ConstantR0<int32>(1))})));
ctx->SetOutput(0, value);
}
@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel {
xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
index = b->Sub(index, b->ConstantR0<int32>(1));
OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index})));
OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =

View File

@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
&strides_tensor));
DataType lhs_type;
TensorShape lhs_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
xla::ComputationDataHandle lhs;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
const TensorShape rhs_shape = ctx->InputShape(4);
@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel {
" does not match r-value shape ", rhs_shape.DebugString(),
". Automatic broadcasting not yet implemented."));
xla::ComputationDataHandle lhs;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
xla::ComputationDataHandle rhs = ctx->Input(4);
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel {
lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
}
private:
int32 begin_mask_, end_mask_;
int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
DataType index_type_;
DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")

View File

@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
TF_RET_CHECK(resource->tensor_array_size() >= 0)
<< resource->name() << " size " << resource->tensor_array_size();
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
ta_shape.AppendShape(elem_shape);
if (!resource->initialized()) {
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
TF_RETURN_IF_ERROR(resource->SetValue(
dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the elem_shape matches the TensorArray shape.
auto shape_or_status = builder->GetShape(resource->value());
@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
TensorShape shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
TensorShape ta_shape;
ta_shape.AddDim(resource->tensor_array_size());
ta_shape.AppendShape(elem_shape);
if (ta_shape != shape) {
return errors::InvalidArgument(
"Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
if (shape->dims() < 1) {
return errors::InvalidArgument("TensorArray rank must be >= 1");
}
*shape = resource->shape();
shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel {
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
xla::ComputationDataHandle value;
TensorShape shape;
if (element_shape_.IsFullyDefined()) {
TensorShape shape;
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel {
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
dtype_, value, &var));
var->set_tensor_array_size(size);
dtype_, shape, value, /*tensor_array_size=*/size,
/*tensor_array_gradients=*/{}, &var));
ctx->SetResourceOutput(0, var);
Tensor flow(DT_FLOAT, TensorShape({}));
@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
OP_REQUIRES_OK(ctx, resource->SetValue(written));
ctx->SetOutput(0, flow);
}
@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
}
OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
OP_REQUIRES_OK(ctx, resource->SetValue(ta));
ctx->SetOutput(0, flow);
}
@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
OP_REQUIRES_OK(
ctx, resource->SetValue(
dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
ta, b->Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}

View File

@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
xla::ComputationBuilder* b = ctx->builder();
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
DataType type = ctx->input_type(1);
TensorShape var_shape;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
TensorShape alpha_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
errors::InvalidArgument("alpha is not a scalar: ",
alpha_shape.DebugString()));
TensorShape delta_shape = ctx->InputShape(2);
OP_REQUIRES(
ctx, var_shape.IsSameSize(delta_shape),
errors::InvalidArgument("var and delta do not have the same shape: ",
var_shape.DebugString(), " vs ",
delta_shape.DebugString()));
handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel {
DataType type = ctx->input_type(2);
DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
OP_REQUIRES(
ctx, type == var_type && type == accum_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyMomentum must match: ",
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
DataTypeString(accum_type)));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle momentum = ctx->Input(4);
@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel {
DataType type = ctx->input_type(2);
DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
OP_REQUIRES(
ctx, type == var_type && type == accum_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyAdagrad must match: ",
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
DataTypeString(accum_type)));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
DataType var_type, m_type, v_type;
TensorShape var_shape, m_shape, v_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape));
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape));
OP_REQUIRES(
ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyRMSProp must match: ",
DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ",
DataTypeString(m_type), " vs. ", DataTypeString(v_type)));
xla::ComputationDataHandle var, m, v;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
TensorShape beta1_power_shape = ctx->InputShape(3);
TensorShape beta2_power_shape = ctx->InputShape(4);
@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
xla::ComputationDataHandle var, m, v;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v));
xla::ComputationDataHandle beta1_power = ctx->Input(3);
xla::ComputationDataHandle beta2_power = ctx->Input(4);
xla::ComputationDataHandle lr = ctx->Input(5);
@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
DataType type = ctx->input_type(3);
DataType var_type, ms_type, mom_type;
TensorShape var_shape, ms_shape, mom_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape));
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape));
OP_REQUIRES(
ctx, type == var_type && type == ms_type && type == mom_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyRMSProp must match: ",
DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ",
DataTypeString(ms_type), " vs. ", DataTypeString(mom_type)));
xla::ComputationDataHandle var, ms, mom;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
TensorShape lr_shape = ctx->InputShape(3);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
xla::ComputationDataHandle var, ms, mom;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom));
xla::ComputationDataHandle lr = ctx->Input(3);
xla::ComputationDataHandle rho = ctx->Input(4);
xla::ComputationDataHandle momentum = ctx->Input(5);
@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
xla::ComputationBuilder* b = ctx->builder();
DataType var_type, accum_type, linear_type;
TensorShape var_shape, accum_shape, linear_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));
OP_REQUIRES(
ctx, dtype == var_type && dtype == accum_type && dtype == linear_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyFtrlV2 must match: ",
DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ",
DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
xla::ComputationDataHandle var, accum, linear;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
errors::InvalidArgument("lr_power is not a scalar: ",
lr_power_shape.DebugString()));
xla::ComputationDataHandle var, accum, linear;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle lr = ctx->Input(4);
xla::ComputationDataHandle l1 = ctx->Input(5);

View File

@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x));
// Return x if x>0, otherwise -x.
XLAJIT_MAKE_UNARY(Abs, b->Abs(x));
// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x))
XLAJIT_MAKE_UNARY(
Acos,
b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
b->Mul(x, x)),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)),
b->Add(XlaHelpers::One(b, input_type(0)), x))));
// acosh(x) = log(x + sqrt(x^2 - 1))
XLAJIT_MAKE_UNARY(
Acosh,
b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x),
XlaHelpers::One(b, input_type(0))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
XLAJIT_MAKE_UNARY(
Asin,
b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0),
b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)),
b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)),
b->Mul(x, x)),
XlaHelpers::FloatLiteral(b, input_type(0),
0.5))))));
// asinh(x) = log(x + sqrt(x^2 + 1))
XLAJIT_MAKE_UNARY(
Asinh,
b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x),
XlaHelpers::One(b, input_type(0))),
XlaHelpers::FloatLiteral(b, input_type(0), 0.5)))));
XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0))));
// atanh(x) = 0.5 * log((1 + x) / (1 - x))
XLAJIT_MAKE_UNARY(
Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x),

View File

@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel {
public:
explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
bool initialized = ctx->ReadVariableInput(0, &handle).ok();
ctx->SetOutput(0, ctx->builder()->ConstantR0<bool>(initialized));
XlaResource* variable;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
ctx->SetOutput(0,
ctx->builder()->ConstantR0<bool>(variable->initialized()));
}
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
class ReadVariableOp : public XlaOpKernel {
public:
explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
OP_REQUIRES_OK(
ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
ctx->SetOutput(0, handle);
}
private:
DataType dtype_;
};
REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel {
public:
explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(1);
xla::ComputationDataHandle handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Add(handle, ctx->Input(1));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel {
public:
explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
DataType type = ctx->input_type(1);
xla::ComputationDataHandle handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
OP_REQUIRES_OK(ctx,
ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Sub(handle, ctx->Input(1));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
@ -95,28 +107,19 @@ class ResourceGatherOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* builder = ctx->builder();
// Get the shape of the resource tensor.
DataType type = ctx->expected_output_dtype(0);
TensorShape resource_shape;
DataType resource_dtype;
OP_REQUIRES_OK(
ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape));
DataType expected_output_dtype = ctx->expected_output_dtype(0);
OP_REQUIRES(ctx, resource_dtype == expected_output_dtype,
errors::InvalidArgument(
"Variable dtype is ", DataTypeString(resource_dtype),
" but expected output dtype is ",
DataTypeString(expected_output_dtype), "."));
xla::ComputationDataHandle resource_handle;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
&resource_handle));
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
DataType index_type = ctx->input_type(1);
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
ctx, resource_handle, resource_shape, indices, indices_shape, 0,
resource_dtype, index_type, builder);
ctx, resource_handle, resource_shape, indices, indices_shape, 0, type,
index_type, builder);
ctx->SetOutput(0, gather);
}
};

View File

@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
}
arg.type = resource->type();
if (arg.initialized) {
TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
} else {
arg.shape = resource->shape();
if (!arg.initialized) {
*has_uninitialized_vars = true;
}
arg.tensor_array_size = resource->tensor_array_size();
@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs(
arg.name = resource->name();
VLOG(2) << " resource " << resource->name()
<< " type: " << DataTypeString(arg.type)
<< " shape: " << xla::ShapeUtil::HumanString(arg.shape)
<< " shape: " << arg.shape.DebugString()
<< " initialized: " << arg.initialized;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = ctx->input_type(i);
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
arg.shape = ctx->InputShape(i);
}
}
return Status::OK();
@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler::Argument& arg = arguments[update.input_index];
if (!arg.initialized) {
VLOG(2) << "Update shape for argument " << update.input_index << " "
<< xla::ShapeUtil::HumanString(update.shape);
<< update.shape.DebugString();
arg.initialized = true;
xla::Shape shape = update.shape;
if (!update.tensor_array_gradients_accessed.empty()) {
shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
}
std::unique_ptr<xla::Literal> zero =
xla::Literal::CreateFromShape(shape);
OP_REQUIRES_OK(ctx, resource->SetValue(
update.type, builder->ConstantLiteral(*zero)));
arg.shape = update.shape;
OP_REQUIRES_OK(ctx,
resource->SetTypeAndShape(update.type, update.shape));
OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
}
// Add any TensorArray gradients touched by the body to the enclosing
@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
// Recompute the argument shape.
OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
}
// Recompile the body with the "correct" resource shapes.
VLOG(1) << "Recompiling body with corrected resource shapes";
@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
builder->GetTupleElement(while_result, pos),
/*reset_initial_values=*/false, builder));
builder->GetTupleElement(while_result, pos), builder));
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name() << " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
<< " shape: " << xla::ShapeUtil::HumanString(update.shape);
<< " shape: " << update.shape.DebugString();
// Copies the identity of the resource variable from input to output
// unchanged, even if the variable was not modified.
ctx->op_kernel_context()->set_output(

View File

@ -60,6 +60,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/core:lib",

View File

@ -25,11 +25,10 @@ limitations under the License.
namespace tensorflow {
// The current implementation simply unrolls the computation along the batch
// dimension.
xla::StatusOr<xla::ComputationDataHandle> BatchDot(
xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) {
xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
bool conjugate_x, bool conjugate_y) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> x_shape,
builder->GetShape(x));
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> y_shape,
@ -89,10 +88,10 @@ xla::StatusOr<xla::ComputationDataHandle> BatchDot(
dimensions);
}
if (x_shape->element_type() == xla::C64 && transpose_x) {
if (x_shape->element_type() == xla::C64 && conjugate_x) {
x = builder->Conj(x);
}
if (y_shape->element_type() == xla::C64 && transpose_y) {
if (y_shape->element_type() == xla::C64 && conjugate_y) {
y = builder->Conj(y);
}

View File

@ -27,7 +27,10 @@ namespace tensorflow {
// viewed as an element of a batch), and arranges the individual results
// in a single output tensor of the same batch size. Each of the
// individual slices can optionally be transposed before multiplication by
// setting the `transpose_x` or `transpose_y` flag to `true`.
// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each
// can be elementwise-complex-conjugated by setting the `conjugate_x` or
// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both
// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`.
//
// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]`
// and `[..., r_y, c_y]`.
@ -40,11 +43,10 @@ namespace tensorflow {
// It is computed as:
//
// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :])
// TODO(phawkins): add an option to take the complex conjugate of the LHS or
// RHS.
xla::StatusOr<xla::ComputationDataHandle> BatchDot(
xla::ComputationBuilder* builder, xla::ComputationDataHandle x,
xla::ComputationDataHandle y, bool transpose_x, bool transpose_y);
xla::ComputationDataHandle y, bool transpose_x, bool transpose_y,
bool conjugate_x = false, bool conjugate_y = false);
} // namespace tensorflow

View File

@ -71,11 +71,14 @@ xla::StatusOr<xla::ComputationDataHandle> CholeskyUnblocked(
SliceInMinorDims(builder, l, {j + 1, 0}, {n, j}));
TF_ASSIGN_OR_RETURN(auto r_squared,
BatchDot(builder, r, r, /*transpose_x=*/false,
/*transpose_y=*/true));
/*transpose_y=*/true, /*conjugate_x=*/false,
/*conjugate_y=*/false));
new_d_squared = builder->Sub(new_d_squared, r_squared);
TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false,
/*transpose_y=*/true));
/*transpose_y=*/true,
/*conjugate_x=*/false,
/*conjugate_y=*/false));
}
auto new_d_inv = builder->Pow(
new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5));
@ -134,7 +137,8 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
SliceInMinorDims(builder, l, {i, 0}, {i + k, i}));
TF_ASSIGN_OR_RETURN(auto delta,
BatchDot(builder, lhs, rhs, /*transpose_x=*/false,
/*transpose_y=*/true));
/*transpose_y=*/true, /*conjugate_x=*/false,
/*conjugate_y=*/false));
TF_ASSIGN_OR_RETURN(auto before,
SliceInMinorDims(builder, a, {i, i}, {n, i + k}));
TF_ASSIGN_OR_RETURN(
@ -155,6 +159,10 @@ xla::StatusOr<xla::ComputationDataHandle> Cholesky(
SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
TF_ASSIGN_OR_RETURN(auto update,
TriangularSolve(builder, factorized, panel,
/*left_side=*/false,
/*lower=*/true,
/*transpose_a=*/true,
/*conjugate_a=*/false,
/*block_size=*/8));
TF_ASSIGN_OR_RETURN(
l, UpdateSliceInMinorDims(builder, l, update, {i + k, i}));

View File

@ -29,6 +29,7 @@ namespace tensorflow {
// the block size to use.
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(mattjj): handle the complex Hermitian case
xla::StatusOr<xla::ComputationDataHandle> Cholesky(
xla::ComputationBuilder* builder, xla::ComputationDataHandle a,
int64 block_size = 256);

View File

@ -24,13 +24,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
xla::ComputationDataHandle b, int64 block_size) {
xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
bool conjugate_a, int64 block_size) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
builder->GetShape(a));
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
@ -60,14 +62,15 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
batch_dimensions.push_back(a_size);
}
const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1);
const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) {
if (xla::ShapeUtil::GetDimension(*a_shape, -1) !=
xla::ShapeUtil::GetDimension(*a_shape, -2)) {
return errors::InvalidArgument(
"The 'a' arguments to TriangularSolve must be square matrices: ",
xla::ShapeUtil::HumanString(*a_shape));
}
if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) {
const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) {
return errors::InvalidArgument(
"Arguments to TriangularSolve have incompatible matrix shapes: ",
xla::ShapeUtil::HumanString(*a_shape), " vs ",
@ -89,6 +92,14 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
return output;
};
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
auto maybe_conj = [&](xla::ComputationBuilder* builder,
xla::ComputationDataHandle x) {
auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};
std::map<int, xla::Computation> base_computations;
auto get_base_triangular_solve =
[&](int k) -> xla::StatusOr<xla::Computation*> {
@ -103,19 +114,35 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
prepend_batch_dims({k, k})),
"a");
std::array<int64, 2> b_lastd;
if (left_side) {
b_lastd = {k, n};
} else {
b_lastd = {m, k};
}
auto b_param =
sub->Parameter(1,
xla::ShapeUtil::MakeShape(b_shape->element_type(),
prepend_batch_dims({m, k})),
prepend_batch_dims(b_lastd)),
"b");
// TODO(phawkins): it might make sense to use a while loop here, rather
// than unrolling.
// TODO(phawkins): the left-looking variant of the algorithm might be more
// efficient at block size 1.
TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
/*block_size=*/1)
.status());
// We use a left-looking subroutine on the block diagonal in some common
// cases, while falling back to a recursive call in unsupported cases. The
// left-looking subroutine is written with a While loop and so yields much
// faster compile times. Moreover, the left-looking variant can give
// higher performance on smaller (sub)problems.
if (left_side && lower) {
TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param,
b_param, transpose_a,
conjugate_a)
.status());
} else {
TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param,
left_side, lower, transpose_a,
conjugate_a,
/*block_size=*/1)
.status());
}
TF_ASSIGN_OR_RETURN(computation, sub->Build());
}
@ -129,47 +156,396 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
// Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation
// of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1
// (2008): 4.
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
// if k > 1:
// output[..., :, i:i+k] = triangular_solve(
// a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right',
// kind='Lower', transpose=True, block_size=1)
// else:
// output[..., :, i] = b[..., :, i] / a[..., i, i]
TF_ASSIGN_OR_RETURN(auto a_slice,
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
xla::ComputationDataHandle update;
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::Computation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, a_slice);
// In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if
// conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if
// conjugate_a is True.
if (!left_side && lower == transpose_a) {
// for i in range(0, a.shape[-1], block_size):
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
// output[..., :, i:i+k] = triangular_solve(
// a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
TF_ASSIGN_OR_RETURN(auto a_slice,
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
xla::ComputationDataHandle update;
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::Computation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
// if i + k < a.shape[-1]:
// a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2)
if (i + k < n) {
xla::ComputationDataHandle a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
} else {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n}));
}
TF_ASSIGN_OR_RETURN(auto b_update,
BatchDot(builder, update, a_slice_2,
/*transpose_x=*/false,
/*transpose_y=*/transpose_a,
/*conjugate_x=*/false,
/*conjugate_y=*/conjugate_a));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
b_update = builder->Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
}
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
// b[..., :, i+k:] -= np.dot(output[..., :, i:i+k],
// np.transpose(..., a[i+k:, i:i+k]))
if (i + k < n) {
TF_ASSIGN_OR_RETURN(auto a_slice_2,
SliceInMinorDims(builder, a, {i + k, i}, {n, i + k}));
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2,
/*transpose_x=*/false,
/*transpose_y=*/true));
} else if (left_side && lower != transpose_a) {
// for i in range(0, a.shape[-1], block_size):
for (int64 i = 0; i < m; i += block_size) {
int64 k = std::min(block_size, m - i);
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, i + k}, {m, n}));
b_update = builder->Sub(b_slice_2, b_update);
// output[..., i:i+k, :] = triangular_solve(
// a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
TF_ASSIGN_OR_RETURN(auto a_slice,
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
xla::ComputationDataHandle update;
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::Computation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
}
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k}));
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
// if i + k < a.shape[-1]:
// a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:]
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :])
if (i + k < m) {
xla::ComputationDataHandle a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k}));
} else {
TF_ASSIGN_OR_RETURN(
a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m}));
}
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
/*transpose_x=*/transpose_a,
/*transpose_y=*/false,
/*conjugate_x=*/conjugate_a,
/*conjugate_y=*/false));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {i + k, 0}, {m, n}));
b_update = builder->Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0}));
}
}
} else if (!left_side && lower != transpose_a) {
// for i in reversed(range(0, a.shape[-1], block_size)):
const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size;
for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
int64 k = std::min(block_size, n - i);
// output[..., :, i:i+k] triangular_solve(
// a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1)
TF_ASSIGN_OR_RETURN(auto a_slice,
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {0, i}, {m, i + k}));
xla::ComputationDataHandle update;
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::Computation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {0, i}));
// if i - k >= 0:
// a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2)
if (i - k >= 0) {
xla::ComputationDataHandle a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
} else {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
}
TF_ASSIGN_OR_RETURN(auto b_update,
BatchDot(builder, update, a_slice_2,
/*transpose_x=*/false,
/*transpose_y=*/transpose_a,
/*conjugate_x=*/false,
/*conjugate_y=*/conjugate_a));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, 0}, {m, i}));
b_update = builder->Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
}
}
} else { // left_side && lower == transpose_a
// for i in reversed(range(0, a.shape[-1], block_size)):
const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size;
for (int64 i = last_blk_ix; i >= 0; i -= block_size) {
int64 k = std::min(block_size, m - i);
// output[..., i:i+k, :] triangular_solve(
// a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1)
TF_ASSIGN_OR_RETURN(auto a_slice,
SliceInMinorDims(builder, a, {i, i}, {i + k, i + k}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + k, n}));
xla::ComputationDataHandle update;
if (k > 1) {
TF_ASSIGN_OR_RETURN(xla::Computation * solve,
get_base_triangular_solve(k));
update = builder->Call(*solve, {a_slice, b_slice});
} else {
update = builder->Div(b_slice, maybe_conj(builder, a_slice));
}
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
// if i - k >= 0:
// a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k]
// a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2
// b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :])
if (i - k >= 0) {
xla::ComputationDataHandle a_slice_2;
if (lower) {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {i, 0}, {i + k, i}));
} else {
TF_ASSIGN_OR_RETURN(a_slice_2,
SliceInMinorDims(builder, a, {0, i}, {i, i + k}));
}
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update,
/*transpose_x=*/transpose_a,
/*transpose_y=*/false,
/*conjugate_x=*/conjugate_a,
/*conjugate_y=*/false));
TF_ASSIGN_OR_RETURN(auto b_slice_2,
SliceInMinorDims(builder, b, {0, 0}, {i, n}));
b_update = builder->Sub(b_slice_2, b_update);
TF_ASSIGN_OR_RETURN(
b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0}));
}
}
}
return output;
}
xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> a_shape,
builder->GetShape(a));
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> b_shape,
builder->GetShape(b));
const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2);
const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1);
const int64 ndims = xla::ShapeUtil::Rank(*a_shape);
std::vector<int64> batch_dimensions;
for (int i = 0; i < ndims - 2; ++i) {
int64 a_size = a_shape->dimensions(i);
batch_dimensions.push_back(a_size);
}
auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
std::vector<int64> output(ndims);
std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
std::copy(indices.begin(), indices.end(),
output.begin() + batch_dimensions.size());
return output;
};
auto maybe_conj = [&](xla::ComputationBuilder* builder,
xla::ComputationDataHandle x) {
auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
return perform_conj ? builder->Conj(x) : x;
};
// The main computation is performed in a While loop.
// Allocate the output and set its first or last row,
// output = np.zeros_like(b)
// if transpose_a:
// output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:]
// else:
// output[..., :1, :] = b[..., :1, :] / a[..., :1, :1]
xla::ComputationDataHandle output = Zeros(builder, *b_shape);
{
auto i = transpose_a ? m - 1 : 0;
TF_ASSIGN_OR_RETURN(auto a_slice,
SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1}));
TF_ASSIGN_OR_RETURN(auto b_slice,
SliceInMinorDims(builder, b, {i, 0}, {i + 1, n}));
auto update = builder->Div(b_slice, maybe_conj(builder, a_slice));
TF_ASSIGN_OR_RETURN(
output, UpdateSliceInMinorDims(builder, output, update, {i, 0}));
}
// Construct the initial loop carry tuple,
// if transpose_a:
// init = (m-2, output, a, b)
// else:
// init = (1, output, a, b)
std::vector<xla::Shape> tuple_shapes = {
// The loop iteration counter is a scalar, incremented each iteration.
xla::ShapeUtil::MakeShape(xla::S32, {}),
// The output has the shape of b, with one row updated each iteration.
*b_shape,
// The coefficient matrix a is a loop invariant.
*a_shape,
// The right-hand-side matrix b is a loop invariant.
*b_shape};
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
auto init_i = builder->ConstantR0<int32>(transpose_a ? m - 2 : 1);
auto init = builder->Tuple({init_i, output, a, b});
// Construct the loop condition function,
// def cond_fun(loop_carry):
// i, output, a, b = loop_carry
// return i >= 0 if transpose_a else i < m
std::unique_ptr<xla::ComputationBuilder> condb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond");
{
auto i = condb->GetTupleElement(
condb->Parameter(0, tuple_shape,
"TriangularSolveLeftLookingWhileTuple"),
0);
if (transpose_a) {
condb->Ge(i, condb->ConstantR0<int32>(0));
} else {
condb->Lt(i, condb->ConstantR0<int32>(m));
}
}
TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
// Construct the loop body function,
// def body_fun(loop_carry):
// i, output, a, b = loop_carry
// if transpose_a:
// a_row = np.swapaxes(a[..., i+1:, i:i+1], -1 -2)
// else:
// a_row = a[..., i:i+1, :i]
// result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :])
// output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
// if transpose_a:
// return (i - 1, output, a, b)
// else:
// return (i + 1, output, a, b)
// We have to do some extra FLOPs propagating zeros in the matrix multiply
// because we can't have the size of its arguments depend on the loop counter.
std::unique_ptr<xla::ComputationBuilder> bodyb =
builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody");
{
auto input_tuple = bodyb->Parameter(0, tuple_shape,
"TriangularSolveLeftLookingWhileTuple");
// i, output, a, b = loop_carry
auto i = bodyb->GetTupleElement(input_tuple, 0);
auto body_out = bodyb->GetTupleElement(input_tuple, 1);
auto body_a = bodyb->GetTupleElement(input_tuple, 2);
auto body_b = bodyb->GetTupleElement(input_tuple, 3);
auto zero = bodyb->ConstantR0<int32>(0);
// Set up some helper functions.
auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
return bodyb->ConcatInDim(padded_starts, 0);
};
auto dynamic_slice = [&](xla::ComputationDataHandle x,
std::array<xla::ComputationDataHandle, 2> starts,
std::array<int64, 2> sizes) {
auto padded_starts = prepend_zeros(starts);
auto padded_sizes = prepend_batch_dims(sizes);
return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
};
auto update = [&](xla::ComputationDataHandle x,
xla::ComputationDataHandle update,
std::array<xla::ComputationDataHandle, 2> starts) {
auto padded_starts = prepend_zeros(starts);
return bodyb->DynamicUpdateSlice(x, update, padded_starts);
};
// We'd like to implement this:
// if transpose_a:
// a_row = T(a[..., i+1:, i:i+1])
// result_row = (b[..., i:i+1, :]
// - np.matmul(a_row, body_out[..., i+1:, :]))
// else:
// result_row = (b[..., i:i+1, :]
// - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :]))
// But since we can't have intermediate array sizes depend on the loop
// counter, we instead exploit the fact that we initialized the output to
// all zeros and use that as zero-padding (doing unnecessary FLOPs).
xla::ComputationDataHandle a_row;
if (transpose_a) {
a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
} else {
a_row = dynamic_slice(body_a, {i, zero}, {1, m});
}
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
/*transpose_x=*/transpose_a,
/*transpose_y=*/false,
/*conjugate_x=*/conjugate_a,
/*conjugate_y=*/false));
auto result_row =
bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);
// body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
body_out = update(body_out, div_result, {i, zero});
// if transpose_a:
// return (i - 1, body_out, a, b)
// else:
// return (i + 1, body_out, a, b)
auto next_i = bodyb->Add(i, bodyb->ConstantR0<int32>(transpose_a ? -1 : 1));
bodyb->Tuple({next_i, body_out, body_a, body_b});
}
TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
// Construct the While loop and return the result,
// return while_loop(cond_fun, body_fun, init)[1]
auto triangular_solve_left_looking_while = builder->While(cond, body, init);
return builder->GetTupleElement(triangular_solve_left_looking_while, 1);
}
} // namespace tensorflow

View File

@ -21,25 +21,50 @@ limitations under the License.
namespace tensorflow {
// Solves systems of linear equations with upper or lower triangular matrices by
// backsubstitution.
// Solves systems of linear equations with lower or upper triangular coefficient
// matrices by forward- or back-substitution. Broadcasting along leading
// dimensions, this routine solves one of the matrix systems
// `op(a) * x = b`, or `x * op(a) = b`,
// for the variable `x` given `a` and `b`, where `op(a)` is either
// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`.
// That is, the innermost matrices in the output satisfy a scalar system
// depending on the value of the value of (left_side, transpose_a, conjugate_a)
// according to:
// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`,
// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`,
// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`,
// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`,
// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`,
// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`,
// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`,
// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`,
// where * denotes complex conjugation and where the index `k` is summed over.
//
// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
// square matrices. The strictly upper triangular part of each inner-most matrix
// is assumed to be zero and not accessed.
// `b` is a tensor of shape `[..., M, K]`.
//
// The innermost matrices in the output satisfy matrix equations
// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`.
// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
// square matrices. If lower is true (false), then the strictly upper (lower)
// triangular part of each innermost matrix in `a` is assumed to be zero and is
// not accessed.
// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a
// tensor of shape `[..., K, M]`.
// `left_side` is a boolean, indicating whether to solve a system of the form
// op(a) * x = b (true) or x * op(a) = b (false).
// `lower` is a boolean, indicating whether the argument `a` is lower-triangular
// (true) or upper-triangular (false).
// `transpose_a` is a boolean indicating whether the matrix `a` is transposed.
// `conjugate_a` is a boolean indicating whether the entries of `a` are complex
// conjugated (independently of whether they are transposed), so that when both
// transpose_a and conjugate_a are true the effect is a Hermitian adjoint.
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right,
// kind=lower, and transposed_a=true. Implement the other possible combinations
// of side, kind and transposed_a.
xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
xla::ComputationDataHandle b, int64 block_size = 256);
xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a,
bool conjugate_a, int64 block_size = 256);
xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a,
const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a);
} // namespace tensorflow

View File

@ -27,32 +27,68 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
using TriangularSolveTest = xla::ClientLibraryTestBase;
using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase;
using complex64 = xla::complex64;
XLA_TEST_F(TriangularSolveTest, Simple) {
xla::Array2D<float> AValsLower() {
return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}};
}
xla::Array2D<float> AValsUpper() {
return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}};
}
xla::Array2D<float> BValsRight() {
return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
}
xla::Array2D<float> BValsLeft() {
return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
}
xla::Array2D<complex64> AValsLowerComplex() {
return {{2, 0, 0, 0},
{complex64(3, 1), 6, 0, 0},
{4, complex64(7, 2), 9, 0},
{5, 8, complex64(10, 3), 11}};
}
xla::Array2D<complex64> AValsUpperComplex() {
return {{2, 3, complex64(4, 3), 5},
{0, 6, complex64(7, 2), 8},
{0, 0, complex64(9, 1), 10},
{0, 0, 0, 11}};
}
xla::Array2D<complex64> BValsRightComplex() {
return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
}
xla::Array2D<complex64> BValsLeftComplex() {
return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
}
xla::Array2D<float> AValsFull() {
return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::Array2D<float> a_vals({
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
{5, 8, 10, 11},
});
xla::Array2D<float> b_vals({
{1, 2, 3, 4},
{5, 6, 7, 8},
{9, 10, 11, 12},
});
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(b_vals, 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b, /*block_size=*/2);
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/false, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
@ -62,7 +98,267 @@ XLA_TEST_F(TriangularSolveTest, Simple) {
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(2e-3, 2e-3));
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/false, /*lower=*/true,
/*transpose_a=*/false, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
{0.64393939, 0.06565657, -0.03030303, 0.72727273},
{1.4520202, 0.2003367, 0.01010101, 1.09090909},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/false, /*lower=*/false,
/*transpose_a=*/true, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{-0.16414141, -0.06902357, -0.07070707, 0.36363636},
{0.64393939, 0.06565657, -0.03030303, 0.72727273},
{1.4520202, 0.2003367, 0.01010101, 1.09090909},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsRight(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/false, /*lower=*/false,
/*transpose_a=*/false, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{0.5, 0.08333334, 0.04629629, 0.03367003},
{2.5, -0.25, -0.1388889, -0.1010101},
{4.5, -0.58333331, -0.32407406, -0.23569024},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/true, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
{-0.27441077, -0.24074074, -0.20707071},
{-0.23232323, -0.22222222, -0.21212121},
{0.90909091, 1., 1.09090909},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/true, /*lower=*/true,
/*transpose_a=*/false, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
{0.41666667, 0.33333333, 0.25},
{0.23148148, 0.18518519, 0.13888889},
{0.16835017, 0.13468013, 0.1010101},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/true, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
{0.41666667, 0.33333333, 0.25},
{0.23148148, 0.18518519, 0.13888889},
{0.16835017, 0.13468013, 0.1010101},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsUpper(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/false, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{-0.89646465, -0.69444444, -0.49242424},
{-0.27441077, -0.24074074, -0.20707071},
{-0.23232323, -0.22222222, -0.21212121},
{0.90909091, 1., 1.09090909},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data =
CreateR2Parameter<complex64>(AValsLowerComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsRightComplex(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/false, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/true,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<complex64> expected({
{0.5, complex64(0.08333333, 0.08333333),
complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)},
{2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963),
complex64(0.08670034, -0.02104377)},
{4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296),
complex64(0.11026936, -0.03114478)},
});
ComputeAndCompareR2<complex64>(&builder, expected,
{a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data =
CreateR2Parameter<complex64>(AValsUpperComplex(), 0, "a", &builder, &a);
auto b_data =
CreateR2Parameter<complex64>(BValsLeftComplex(), 1, "b", &builder, &b);
auto result = TriangularSolve(&builder, a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/true, /*conjugate_a=*/false,
/*block_size=*/2);
TF_ASSERT_OK(result.status());
xla::Array2D<complex64> expected({
{0.5, 1., 1.5},
{0.41666667, 0.33333333, 0.25},
{complex64(0.20020325, -2.81504065e-01),
complex64(0.13821138, -4.22764228e-01),
complex64(0.07621951, -5.64024390e-01)},
{complex64(0.19678492, 2.55912786e-01),
complex64(0.17738359, 3.84331116e-01),
complex64(0.15798226, 5.12749446e-01)},
});
ComputeAndCompareR2<complex64>(&builder, expected,
{a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsLower(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolveLeftLooking(&builder, a, b,
/*transpose_a=*/false,
/*conjugate_a=*/false);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
{0.41666667, 0.33333333, 0.25},
{0.23148148, 0.18518519, 0.13888889},
{0.16835017, 0.13468013, 0.1010101},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) {
xla::ComputationBuilder builder(client_, TestName());
xla::ComputationDataHandle a, b;
auto a_data = CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
auto result = TriangularSolveLeftLooking(&builder, a, b,
/*transpose_a=*/false,
/*conjugate_a=*/false);
TF_ASSERT_OK(result.status());
xla::Array2D<float> expected({
{0.5, 1.0, 1.5},
{0.41666667, 0.33333333, 0.25},
{0.23148148, 0.18518519, 0.13888889},
{0.16835017, 0.13468013, 0.1010101},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
xla::ErrorSpec(1e-2, 1e-2));
}
} // namespace

View File

@ -107,4 +107,15 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
return UpdateSlice(builder, x, update, padded_start);
}
xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
const int64 n_dims = xla::ShapeUtil::Rank(*shape);
TF_RET_CHECK(n_dims >= 2);
std::vector<int64> permutation(n_dims);
std::iota(permutation.begin(), permutation.end(), 0);
std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
return builder->Transpose(x, permutation);
}
} // namespace tensorflow

View File

@ -49,6 +49,10 @@ xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start);
// Transposes a stack of matrices `x` by swapping the last two dimensions.
xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_

View File

@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph,
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
TensorShape shape;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}

View File

@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types,
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
if (std::tie(kind, resource_kind, type, name, tensor_array_size,
if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
tensor_array_gradients) !=
std::tie(other.kind, other.resource_kind, other.type, other.name,
other.tensor_array_size, other.tensor_array_gradients)) {
other.initialized, other.tensor_array_size,
other.tensor_array_gradients)) {
return false;
}
if (!xla::ShapeUtil::Equal(shape, other.shape)) {
if (shape != other.shape) {
return false;
}
if (constant_value.shape() != other.constant_value.shape()) {
@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
return Status::OK();
}
// Computes the XLA shape for argument 'arg'.
/*static*/ Status XlaCompiler::XLAShapeForArgument(
const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
xla_shape);
case XlaCompiler::Argument::kParameter:
return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
switch (arg.resource_kind) {
case XlaResource::kVariable:
return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
case XlaResource::kTensorArray: {
if (arg.tensor_array_size < 0) {
return errors::InvalidArgument(
"Negative tensor_array_size in XLAShapeForArgument");
}
TensorShape shape;
shape.AddDim(arg.tensor_array_size);
shape.AppendShape(arg.shape);
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
if (!arg.tensor_array_gradients.empty()) {
std::vector<xla::Shape> tuple_shape(
arg.tensor_array_gradients.size() + 1, *xla_shape);
*xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
}
return Status::OK();
}
case XlaResource::kStack: {
if (arg.tensor_array_size < 0) {
return errors::InvalidArgument(
"Negative tensor_array_size in XLAShapeForArgument");
}
TensorShape shape;
shape.AddDim(arg.tensor_array_size);
shape.AppendShape(arg.shape);
xla::Shape buffer_shape;
TF_RETURN_IF_ERROR(
TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
*xla_shape = xla::ShapeUtil::MakeTupleShape(
{buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
return Status::OK();
}
case XlaResource::kInvalid:
return errors::Internal(
"Invalid resource type in XLAShapeForArgument()");
}
}
case XlaCompiler::Argument::kInvalid:
return errors::Internal("Invalid argument type in XLAShapeForArgument()");
}
}
namespace {
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph,
// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
std::vector<int> parameters, resources;
parameters.reserve(args.size());
input_mapping->clear();
input_mapping->reserve(args.size());
std::vector<int> resources;
resources.reserve(args.size());
// Fills in constant arguments, and computes non-constant argument order.
@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph,
// TODO(phawkins): this code assumes that resource arguments do not
// alias.
XlaResource* resource;
TF_RETURN_IF_ERROR(
context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
xla::ComputationDataHandle(), &resource));
resource->set_tensor_array_size(arg.tensor_array_size);
TF_RETURN_IF_ERROR(context->CreateResource(
arg.resource_kind, i, arg.name, arg.type, arg.shape,
xla::ComputationDataHandle(),
/*tensor_array_size=*/arg.tensor_array_size,
/*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
if (arg.initialized) {
resources.push_back(i);
}
break;
case XlaCompiler::Argument::kParameter:
parameters.push_back(i);
case XlaCompiler::Argument::kParameter: {
input_mapping->push_back(i);
break;
}
case XlaCompiler::Argument::kConstant:
arg_expression.set_constant_value(arg.constant_value);
break;
@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph,
// Append parameters containing variable values after the other runtime
// parameters.
parameters.insert(parameters.end(), resources.begin(), resources.end());
if (parameters.empty()) {
input_mapping->insert(input_mapping->end(), resources.begin(),
resources.end());
if (input_mapping->empty()) {
return Status::OK();
}
std::vector<xla::Shape> arg_shapes;
arg_shapes.reserve(parameters.size());
input_mapping->resize(parameters.size());
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const XlaCompiler::Argument& arg = args[parameters[i]];
std::vector<xla::Shape> arg_shapes(input_mapping->size());
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments.
arg_shapes.push_back(arg.shape);
(*input_mapping)[i] = parameters[i];
TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
args[(*input_mapping)[i]], &arg_shapes[i]));
}
if (use_tuple_arg) {
@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph,
}
// Build parameter handles for non-constant arguments.
std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
if (use_tuple_arg) {
xla::ComputationDataHandle tuple;
if (is_entry_computation) {
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
for (int64 parameter : parameters) {
for (int64 parameter : *input_mapping) {
const int core = (*arg_cores)[parameter];
const int root_device = 0;
*tuple_sharding.add_tuple_shardings() =
@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph,
} else {
tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
}
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const int core = (*arg_cores)[parameters[i]];
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = builder->GetTupleElement(tuple, i);
}
} else {
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const int core = (*arg_cores)[parameters[i]];
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const int core = (*arg_cores)[input_mapping->at(i)];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph,
// Fill in the handles in non-constant arguments.
VLOG(2) << "XLA computation inputs:";
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const XlaCompiler::Argument& arg = args[parameters[i]];
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
VLOG(2) << " XLA arg " << i
<< " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
<< " name: " << arg.name << " TF arg " << parameters[i];
XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
<< " name: " << arg.name << " TF arg " << input_mapping->at(i);
XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
switch (arg.kind) {
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
XlaResource* resource = arg_expression.resource();
TF_RETURN_IF_ERROR(
resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i],
/*reset_initial_values=*/true, builder));
TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
arg_handles[i], builder));
VLOG(2) << " resource: num_gradients: "
<< arg.tensor_array_gradients.size();
break;
@ -486,6 +545,7 @@ Status BuildComputation(
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
++computation_output;
}
}
for (std::vector<ResourceUpdate>::size_type i = 0;
i < result->resource_updates.size(); ++i) {
result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
result->xla_output_shape, computation_output);
++computation_output;
}
return Status::OK();
}

View File

@ -104,9 +104,17 @@ class XlaCompiler {
// is the type of the variable's value, not DT_RESOURCE.
DataType type;
// The shape of the argument. If the argument is a resource, this is the
// shape of the resource's value.
xla::Shape shape;
// The shape of the argument. For:
// * a parameter: the shape of the parameter.
// * a constant: ignored; the shape given by constant_value is used
// instead.
// * an uninitialized resource: ignored. We don't yet know the shape of an
// uninitialized resource (otherwise we would have initialized it!)
// * an initialized variable: the shape of the variable's value.
// * an initialized TensorArray or Stack resource: the shape of an entry in
// the TensorArray/Stack. Note this is the size of a single entry, not the
// XLA data structure that represents the complete stack/array.
TensorShape shape;
// The value of the argument, if it is a compile-time constant. Must be a
// host-memory tensor.
@ -175,8 +183,9 @@ class XlaCompiler {
int input_index;
// Type and shape of the tensor to be written back.
// The `shape` field has the same meaning as the Argument::shape field.
DataType type;
xla::Shape shape;
TensorShape shape;
// Was the value of the variable modified by the computation?
// (Always true, unless `return_updated_values_for_all_resources` is true.)
@ -235,6 +244,19 @@ class XlaCompiler {
// device is created, and can be used to create metadata objects
// that can be accessed by XLA op kernels.
std::function<Status(ResourceMgr*)>* populate_resource_manager = nullptr;
// If not nullptr, this memory allocator can be used by the compiler for
// temporary allocations it might want to make during compilation.
//
// For example, the compiler may want to try out different algorithms and
// choose the fastest one, and it might run those algorithms over buffers
// created using this allocator.
//
// The compiler can function correctly without an explicit allocator given
// here, but on some devices (notably, GPUs), TensorFlow tends to eagerly
// allocate most or all available memory on the device, leaving none for the
// compiler to access, unless it can use TensorFlow's allocator.
xla::DeviceMemoryAllocator* device_allocator = nullptr;
};
explicit XlaCompiler(Options options);
@ -253,11 +275,10 @@ class XlaCompiler {
const std::vector<Argument>& args,
CompilationResult* result);
Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func,
const std::vector<DataType>& types,
const std::vector<TensorShape>& shapes,
const std::vector<const XlaExpression*>& expressions,
std::vector<Argument>* args);
// Returns the shape of the XLA parameter for an argument 'arg'.
// See the class comment for more details about the argument passing
// convention.
static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.

View File

@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) {
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[0].shape = TensorShape({2});
XlaCompiler::Options options = DefaultOptions();
XlaCompiler compiler(options);
@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[0].shape = TensorShape({2});
DummyResourceForTest* resource = new DummyResourceForTest();
@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
args[0].shape = TensorShape({2});
// Compiles the graph.
auto options = DefaultOptions();
@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2}),
xla::ShapeUtil::MakeShape(xla::S32, {2})});
args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad2"};
@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2}),
xla::ShapeUtil::MakeShape(xla::S32, {2})});
args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
args[0].shape = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2}),
xla::ShapeUtil::MakeShape(xla::S32, {2})});
args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};

View File

@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
string name, DataType type,
const xla::ComputationDataHandle& handle,
XlaResource** resource) {
Status XlaContext::CreateResource(
XlaResource::Kind kind, int arg_num, string name, DataType type,
TensorShape shape, const xla::ComputationDataHandle& handle,
int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
XlaResource** resource) {
resources_.emplace_back(
new XlaResource(kind, arg_num, std::move(name), type, handle));
new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
handle, tensor_array_size, tensor_array_gradients));
*resource = resources_.back().get();
return Status::OK();
}

View File

@ -71,11 +71,15 @@ class XlaContext : public ResourceBase {
Status AddConstRetval(int retval_index, DataType dtype,
const xla::Literal& literal);
// Creates a resource with resource `kind` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
// Creates a resource with resource `kind` and initial value `handle`. `name`
// is a descriptive name for use in error messages. See the `XlaResource`
// constructor for a description of the remaining arguments.
// Fails if the resource already exists.
Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
DataType type, const xla::ComputationDataHandle& handle,
DataType type, TensorShape shape,
const xla::ComputationDataHandle& handle,
int64 tensor_array_size,
const std::set<string>& tensor_array_gradients,
XlaResource** resource);
const std::vector<std::unique_ptr<XlaResource>>& resources() {

View File

@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList(
}
Status XlaOpKernelContext::ReadVariableInput(
int index, xla::ComputationDataHandle* value) {
int index, DataType type, TensorShape* shape,
xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput(
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name());
}
if (variable->type() != type) {
return errors::InvalidArgument(
"Type mismatch for read of variable ", variable->name(), ". Expected ",
DataTypeString(type), "; got ", DataTypeString(variable->type()));
}
*value = variable->value();
if (shape) {
*shape = variable->shape();
}
return Status::OK();
}
@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
variable->name());
}
*type = variable->type();
auto shape_or_status = builder()->GetShape(variable->value());
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
*shape = variable->shape();
return Status::OK();
}
@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable(
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
return variable->SetValue(type, handle);
auto shape_or_status = builder()->GetShape(handle);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
TensorShape shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
return variable->SetValue(handle);
}
XlaCompiler* XlaOpKernelContext::compiler() const {

View File

@ -164,11 +164,16 @@ class XlaOpKernelContext {
TensorShape* shape) const;
// Reads the current value of the resouce variable referred to by input
// 'index'.
Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
// 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
// variable. Returns an error if the variable has not been initialized, or if
// its type does not match `type`.
Status ReadVariableInput(int index, DataType type, TensorShape* shape,
xla::ComputationDataHandle* value);
// Assigns the value `handle` to the variable referenced by input
// `input_index`. Marks the operator as having side effects.
// `input_index`. The variable must be of `type`. Returns an error if the
// variable has been initialized with a different type or with a
// different shape.
Status AssignVariable(int input_index, DataType type,
const xla::ComputationDataHandle& handle);

View File

@ -25,51 +25,99 @@ limitations under the License.
namespace tensorflow {
XlaResource::XlaResource(Kind kind, int arg_num, string name,
DataType initial_type,
const xla::ComputationDataHandle& initial_value)
XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
TensorShape shape,
const xla::ComputationDataHandle& initial_value,
int64 tensor_array_size,
const std::set<string>& tensor_array_gradients)
: kind_(kind),
arg_num_(arg_num),
name_(std::move(name)),
type_(initial_type),
type_(type),
shape_(std::move(shape)),
value_(initial_value),
initial_value_(initial_value) {
initial_value_(initial_value),
tensor_array_size_(tensor_array_size) {
CHECK(kind_ != kInvalid);
for (const string& gradient : tensor_array_gradients) {
tensor_array_gradients_[gradient].reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/strings::StrCat("TensorArrayGrad: ", name_),
type_, shape_, xla::ComputationDataHandle(),
tensor_array_size_, /*tensor_array_gradients=*/{}));
}
}
Status XlaResource::SetValue(DataType type,
const xla::ComputationDataHandle& value) {
if (type_ == DT_INVALID && type == DT_INVALID) {
return errors::InvalidArgument("Attempted to initialized resource ", name_,
" to an invalid type");
Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
if (type == DT_INVALID) {
return errors::InvalidArgument("Attempted to set type of resource '", name_,
"'' to an invalid type");
}
if (type_ != DT_INVALID && type_ != type) {
if (initialized() && type_ != type) {
return errors::InvalidArgument("Type of resource ", name_,
" cannot be changed after initialization: "
"old type was ",
DataTypeString(type_), ", new type is ",
DataTypeString(type));
}
if (initialized() && shape_ != shape) {
return errors::InvalidArgument("Shape of resource ", name_,
" cannot be changed after initialization: "
"old shape was ",
shape_.DebugString(), ", new shape is ",
shape.DebugString());
}
type_ = type;
shape_ = shape;
return Status::OK();
}
Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
if (type_ == DT_INVALID) {
return errors::InvalidArgument(
"Resource '", name_,
"' must be initialized with a valid type before use.");
}
value_ = value;
return Status::OK();
}
Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
xla::Shape* shape) const {
auto shape_or_status = builder->GetShape(value_);
if (!shape_or_status.ok()) {
return shape_or_status.status();
Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
if (type_ == DT_INVALID) {
return errors::InvalidArgument(
"Resource '", name_,
"' must be initialized with a valid type before use.");
}
*shape = *shape_or_status.ValueOrDie();
return Status::OK();
}
switch (kind_) {
case kVariable: {
value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
shape_.dim_sizes());
break;
}
case kTensorArray: {
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
ta_shape.dim_sizes());
break;
}
case kStack: {
TensorShape ta_shape;
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
value_ =
builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
ta_shape.dim_sizes()),
builder->ConstantR0<int32>(0)});
break;
}
Status XlaResource::GetShape(xla::ComputationBuilder* builder,
TensorShape* shape) const {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
case kInvalid:
default:
LOG(FATAL) << "Invalid resource type";
}
return Status::OK();
}
@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
if (!gradient) {
TensorShape ta_shape;
TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
ta_shape.AddDim(tensor_array_size_);
ta_shape.AppendShape(shape_);
xla::ComputationDataHandle gradient_value = builder->Broadcast(
XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/strings::StrCat("TensorArrayGrad: ", name_),
type_, gradient_value));
gradient->tensor_array_size_ = tensor_array_size_;
type_, shape_, gradient_value, tensor_array_size_,
/*tensor_array_gradients=*/{}));
}
*gradient_out = gradient.get();
return Status::OK();
}
Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
xla::Shape* packed_shape) const {
if (tensor_array_gradients_.empty()) {
return GetXlaShape(builder, packed_shape);
}
TF_RET_CHECK(kind_ == kTensorArray);
std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients_.size());
int pos = 0;
TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
for (const auto& gradient : tensor_array_gradients_) {
TF_RETURN_IF_ERROR(
gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
}
*packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
return Status::OK();
}
Status XlaResource::Pack(xla::ComputationDataHandle* pack,
xla::ComputationBuilder* builder) const {
if (tensor_array_gradients_.empty()) {
@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack,
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
const xla::ComputationDataHandle& pack,
bool reset_initial_values,
xla::ComputationBuilder* builder) {
if (gradient_sources.empty()) {
if (!initialized()) {
initial_value_ = pack;
}
value_ = pack;
} else {
TF_RET_CHECK(kind_ == kTensorArray);
int pos = 0;
value_ = builder->GetTupleElement(pack, pos++);
auto v = builder->GetTupleElement(pack, pos++);
if (!initialized()) {
initial_value_ = v;
}
value_ = v;
for (const auto& source : gradient_sources) {
XlaResource* gradient;
TF_RETURN_IF_ERROR(
GetOrCreateTensorArrayGradient(source, builder, &gradient));
gradient->value_ = builder->GetTupleElement(pack, pos++);
if (reset_initial_values) {
gradient->initial_value_ = gradient->value_;
auto v = builder->GetTupleElement(pack, pos++);
if (!gradient->initialized()) {
gradient->initial_value_ = v;
}
gradient->value_ = v;
}
}
if (reset_initial_values) {
initial_value_ = value_;
}
return Status::OK();
}

View File

@ -36,8 +36,11 @@ class XlaResource {
kStack,
};
XlaResource(Kind kind, int arg_num, string name, DataType initial_type,
const xla::ComputationDataHandle& initial_value);
XlaResource(Kind kind, int arg_num, string name, DataType type,
TensorShape shape,
const xla::ComputationDataHandle& initial_value,
int64 tensor_array_size,
const std::set<string>& tensor_array_gradients);
XlaResource(const XlaResource&) = delete;
XlaResource(XlaResource&&) = delete;
@ -60,6 +63,12 @@ class XlaResource {
// a resource is first initialized we do not yet know its type, so we keep
// track of its type dynamically.
DataType type() const { return type_; }
// Shape of the resource. For an uninitialized resource, this is ignored.
// For a Variable, this is the shape of the value. For a TensorArray or Stack
// this is the shape of each entry in the TensorArray/Stack.
const TensorShape& shape() const { return shape_; }
const xla::ComputationDataHandle& value() const { return value_; }
// Value of the resource at computation entry. Used to detect which
@ -68,17 +77,19 @@ class XlaResource {
return initial_value_;
}
// A variable is initialized if it has a value.
bool initialized() const { return value_.handle() > 0; }
// Sets the current type/value of the resource.
Status SetValue(DataType type, const xla::ComputationDataHandle& value);
// Sets the type and shape of the resource. The type and shape of a resource
// must not change once the variable has been initialized.
Status SetTypeAndShape(DataType type, const TensorShape& shape);
// Returns the shape of the resource as an xla::Shape.
Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
// Sets the current value of the resource. Returns an error if the type is not
// set to a valid value.
Status SetValue(const xla::ComputationDataHandle& value);
// Returns the shape of the resource as an TensorShape. Fails if the shape is
// not representable as a TensorShape.
Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
// Sets the current value of the resource to an all-zero value.
Status SetZeroValue(xla::ComputationBuilder* builder);
// Looks up the gradient for `source`, or creates it if it does not already
// exist. The call target must be an initialized TensorArray resource. A
@ -96,10 +107,6 @@ class XlaResource {
Status Pack(xla::ComputationDataHandle* pack,
xla::ComputationBuilder* builder) const;
// Returns the shape of the `pack` value computed by `Pack()`.
Status PackedShape(xla::ComputationBuilder* builder,
xla::Shape* packed_shape) const;
// Updates the resource with values from `pack`. If `gradient_sources` is
// non-empty, treats `pack` as a tuple that represents a TensorArray and
// its gradients, and unpacks and updates the gradient resources.
@ -108,14 +115,14 @@ class XlaResource {
// Opposite of Pack().
Status SetFromPack(const std::set<string>& gradient_sources,
const xla::ComputationDataHandle& pack,
bool reset_initial_values,
xla::ComputationBuilder* builder);
// TensorArray-specific fields
// TensorArray and Stack specific fields
// 'tensor_array_size' stores the expected size of the TensorArray or Stack.
// We need to store this since sometimes TensorArrays must be initialized
// lazily since we do not know the element shape at construction time.
// Used by both TensorArrays and Stacks.
int64 tensor_array_size() const { return tensor_array_size_; }
void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
@ -136,6 +143,7 @@ class XlaResource {
const string name_;
DataType type_;
TensorShape shape_;
xla::ComputationDataHandle value_;
xla::ComputationDataHandle initial_value_;

View File

@ -88,7 +88,6 @@ cc_library(
visibility = [":friends"],
deps = [
"//tensorflow/core:framework_lite",
"//tensorflow/core:lib",
"//third_party/eigen3",
],
)

View File

@ -80,6 +80,18 @@ cc_library(
],
)
cc_library(
name = "executable_build_options",
srcs = ["executable_build_options.cc"],
hdrs = ["executable_build_options.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/core:lib",
],
)
cc_library(
name = "local_client",
srcs = ["local_client.cc"],
@ -87,6 +99,7 @@ cc_library(
deps = [
":client",
":computation",
":executable_build_options",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",

View File

@ -0,0 +1,79 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator(
DeviceMemoryAllocator* allocator) {
device_allocator_ = allocator;
return *this;
}
DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const {
return device_allocator_;
}
ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
int device_ordinal) {
CHECK_GE(device_ordinal, 0);
device_ordinal_ = device_ordinal;
return *this;
}
int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
const Shape& shape_with_layout) {
result_layout_set_ = true;
result_layout_ = shape_with_layout;
return *this;
}
const Shape* ExecutableBuildOptions::result_layout() const {
return result_layout_set_ ? &result_layout_ : nullptr;
}
string ExecutableBuildOptions::ToString() const {
string result_layout = "nullopt";
if (result_layout_set_) {
result_layout = ShapeUtil::HumanStringWithLayout(result_layout_);
}
string generate_hlo_graph = "nullopt";
if (generate_hlo_graph_.has_value()) {
generate_hlo_graph = generate_hlo_graph_.value();
}
return tensorflow::strings::Printf(
"ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, "
"generate_hlo_graph=%s}",
device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str());
}
ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph(
string regex) {
generate_hlo_graph_ = std::move(regex);
return *this;
}
const tensorflow::gtl::optional<string>&
ExecutableBuildOptions::generate_hlo_graph() const {
return generate_hlo_graph_;
}
} // namespace xla

View File

@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
// Class containing options for building an LocalExecutable with
// LocalClient::Compile.
class ExecutableBuildOptions {
public:
// If set, this is the device to build the computation for. Valid
// device_ordinal values are: 0 to # of devices - 1. These values are
// identical to the device ordinal values used by StreamExecutor. The built
// executable will be executable on any device equivalent to the specified
// device as determined by Backend::devices_equivalent(). A value of -1
// indicates this option has not been set.
ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
int device_ordinal() const;
// If set, this specifies the layout of the result of the computation. If not
// set, the service will chose the layout of the result. A Shape is used to
// store the layout to accommodate tuple result shapes. A value of nullptr
// indicates the option has not been set.
ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
const Shape* result_layout() const;
// If set, this specifies an allocator that can be used to allocate temporary
// space on the device during compilation. For example, the compiler might
// want to run various algorithms on the device and pick the fastest one -- it
// might allocate buffers for use by these algorithms using this allocator.
//
// This does not need to be the same as the DeviceMemoryAllocator passed when
// running the executable.
ExecutableBuildOptions& set_device_allocator(
DeviceMemoryAllocator* allocator);
DeviceMemoryAllocator* device_allocator() const;
// If set, specifies a regexp of HLO graphs to dump (as in DebugOptions).
ExecutableBuildOptions& set_generate_hlo_graph(string regex);
const tensorflow::gtl::optional<string>& generate_hlo_graph() const;
// Returns a string representation of the build options, suitable for
// debugging.
string ToString() const;
private:
int device_ordinal_ = -1;
Shape result_layout_;
bool result_layout_set_ = false;
tensorflow::gtl::optional<string> generate_hlo_graph_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_

View File

@ -30,25 +30,6 @@ using xla::source_map_util::InvalidParameterArgument;
namespace xla {
ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
int device_ordinal) {
device_ordinal_ = device_ordinal;
return *this;
}
int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
const Shape& shape_with_layout) {
result_layout_set_ = true;
result_layout_ = shape_with_layout;
return *this;
}
const Shape* ExecutableBuildOptions::result_layout() const {
return result_layout_set_ ? &result_layout_ : nullptr;
}
namespace {
StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
Backend* backend) {
@ -60,16 +41,18 @@ StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
} // namespace
LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
Backend* backend, int device_ordinal,
const ExecutableBuildOptions& build_options)
Backend* backend,
ExecutableBuildOptions build_options)
: executable_(std::move(executable)),
backend_(backend),
build_device_ordinal_(device_ordinal),
build_options_(build_options) {}
build_options_(std::move(build_options)) {
CHECK_GE(build_options_.device_ordinal(), 0)
<< "Must have a valid device ordinal that the executable was built for.";
}
tensorflow::Status LocalExecutable::ValidateExecutionOptions(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& options, const Backend& backend) {
const ExecutableRunOptions& run_options, const Backend& backend) {
const ComputationLayout& computation_layout =
executable_->module_config().entry_computation_layout();
@ -93,14 +76,14 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
}
}
if (options.stream() != nullptr) {
if (!options.stream()->ok()) {
if (run_options.stream() != nullptr) {
if (!run_options.stream()->ok()) {
return InvalidArgument("stream is uninitialized or in an error state");
}
// Check stream matches service platform.
const se::Platform* stream_platform =
options.stream()->parent()->platform();
run_options.stream()->parent()->platform();
if (stream_platform != backend_->platform()) {
return InvalidArgument(
"stream is for platform %s, but service targets platform %s",
@ -110,7 +93,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
// Cannot specify device_ordinal with a stream. The stream determines these
// values.
if (options.device_ordinal() != -1) {
if (run_options.device_ordinal() != -1) {
return InvalidArgument(
"cannot set both device ordinal and stream options in "
"ExecutableRunOptions; the stream determines the device ordinal");
@ -119,34 +102,34 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
// Verify that the device the executable was built for is equivalent to the
// device it will run on.
int run_device_ordinal = options.device_ordinal() == -1
int run_device_ordinal = run_options.device_ordinal() == -1
? backend_->default_device_ordinal()
: options.device_ordinal();
TF_ASSIGN_OR_RETURN(
bool devices_equivalent,
backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_));
: run_options.device_ordinal();
TF_ASSIGN_OR_RETURN(bool devices_equivalent,
backend_->devices_equivalent(
run_device_ordinal, build_options_.device_ordinal()));
if (!devices_equivalent) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
backend_->stream_executor(run_device_ordinal));
TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
backend_->stream_executor(build_device_ordinal_));
backend_->stream_executor(build_device_ordinal()));
return InvalidArgument(
"executable is built for device %s of type \"%s\"; cannot run it on "
"device %s of type \"%s\"",
backend_->device_name(build_device_ordinal_).c_str(),
backend_->device_name(build_device_ordinal()).c_str(),
build_executor->GetDeviceDescription().name().c_str(),
backend_->device_name(run_device_ordinal).c_str(),
run_executor->GetDeviceDescription().name().c_str());
}
if (!options.allocator()) {
if (!run_options.allocator()) {
return InvalidArgument("an allocator must be provided to ExecuteLocally");
}
if (options.allocator()->platform() != backend.platform()) {
if (run_options.allocator()->platform() != backend.platform()) {
return InvalidArgument(
"allocator platform (%s) does not match service platform (%s)",
options.allocator()->platform()->Name().c_str(),
run_options.allocator()->platform()->Name().c_str(),
backend.platform()->Name().c_str());
}
@ -155,23 +138,22 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions(
StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& options) {
TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_));
ExecutableRunOptions actual_options = options;
ExecutableRunOptions run_options) {
TF_RETURN_IF_ERROR(
ValidateExecutionOptions(arguments, run_options, *backend_));
Backend::StreamPtr stream;
if (options.stream() == nullptr) {
if (run_options.stream() == nullptr) {
// NB! The lifetime of `stream` needs to match the lifetime of
// `actual_options` (otherwise we will end up using a returned stream in
// ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if"
// scope.
TF_ASSIGN_OR_RETURN(
stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
actual_options.set_stream(stream.get());
stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_));
run_options.set_stream(stream.get());
}
if (options.allocator() == nullptr) {
actual_options.set_allocator(backend_->memory_allocator());
if (run_options.allocator() == nullptr) {
run_options.set_allocator(backend_->memory_allocator());
}
// For local client execution on CPU backends:
@ -180,7 +162,7 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
// *) The thread pool used for XLA CPU ops is from
// backend_->eigen_intra_op_thread_pool().
ServiceExecutableRunOptions service_options(
actual_options, backend_->StreamBorrower(),
run_options, backend_->StreamBorrower(),
backend_->eigen_intra_op_thread_pool());
if (executable_->dumping()) {
@ -189,9 +171,8 @@ StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::Run(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<ShapedBuffer> result,
executable_->ExecuteOnStreamWrapper(
&service_options, options.execution_profile(), arguments));
return ScopedShapedBuffer::MakeScoped(result.get(),
actual_options.allocator());
&service_options, run_options.execution_profile(), arguments));
return ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator());
}
StatusOr<std::unique_ptr<ScopedShapedBuffer>> LocalExecutable::ExecuteAndDump(
@ -267,16 +248,19 @@ StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const Computation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
const ExecutableBuildOptions& options) {
int device_ordinal = options.device_ordinal() == -1
? default_device_ordinal()
: options.device_ordinal();
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
local_service_->CompileExecutable(
computation.handle(), argument_layouts,
options.result_layout(), device_ordinal));
ExecutableBuildOptions updated_options = options;
if (options.device_ordinal() == -1) {
updated_options.set_device_ordinal(default_device_ordinal());
VLOG(3) << "Set device ordinal to default value of: "
<< updated_options.device_ordinal();
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
local_service_->CompileExecutable(computation.handle(), argument_layouts,
updated_options));
return WrapUnique(new LocalExecutable(std::move(executable),
local_service_->mutable_backend(),
device_ordinal, options));
updated_options));
}
StatusOr<std::unique_ptr<ScopedShapedBuffer>>

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@ -33,39 +34,13 @@ limitations under the License.
namespace xla {
// Class containing options for building an LocalExecutable with
// LocalClient::Compile.
class ExecutableBuildOptions {
public:
// If set, this is the device to build the computation for. Valid
// device_ordinal values are: 0 to # of devices - 1. These values are
// identical to the device ordinal values used by StreamExecutor. The built
// executable will be executable on any device equivalent to the specified
// device as determined by Backend::devices_equivalent(). A value of -1
// indicates this option has not been set.
ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
int device_ordinal() const;
// If set, this specifies the layout of the result of the computation. If not
// set, the service will chose the layout of the result. A Shape is used to
// store the layout to accommodate tuple result shapes. A value of nullptr
// indicates the option has not been set.
ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
const Shape* result_layout() const;
private:
int device_ordinal_ = -1;
Shape result_layout_;
bool result_layout_set_ = false;
};
class LocalExecutable {
public:
// Run the compiled computation with the given arguments and options and
// return the result.
StatusOr<std::unique_ptr<ScopedShapedBuffer>> Run(
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& options);
ExecutableRunOptions run_options);
// Return the layout (contained in a shape) of the result produced by the
// computation.
@ -88,8 +63,7 @@ class LocalExecutable {
// Constructor invoked by LocalClient.
LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
int device_ordinal,
const ExecutableBuildOptions& build_options);
ExecutableBuildOptions build_options);
// Validates that the given arguments and options satisfy various constraints
// of the computation.
@ -117,19 +91,19 @@ class LocalExecutable {
StatusOr<std::unique_ptr<Literal>> LiteralFromShapedBuffer(
const ShapedBuffer& shaped_buffer);
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
// Backend::devices_equivalent).
int build_device_ordinal() const { return build_options_.device_ordinal(); }
// Compiled computation.
std::unique_ptr<Executable> executable_;
// Execution backend.
Backend* backend_;
// The ordinal of the device which this executable was compiled for. The
// executable can run on all equivalent devices (as determined by
// Backend::devices_equivalent).
int build_device_ordinal_;
Backend* backend_ = nullptr;
// Options used to build the executable.
const ExecutableBuildOptions& build_options_;
const ExecutableBuildOptions build_options_;
};
// An XLA Client specialization for use when the client and service run in

View File

@ -221,13 +221,19 @@ void AllocateFlags() {
flag_values->xla_gpu_disable_multi_streaming(),
"If true, multi-streaming in the GPU backend is disabled."),
tensorflow::Flag(
"xla_dump_hlo_proto_to", flag_values->mutable_xla_dump_hlo_proto_to(),
"Dump compilation artifacts as proto binary into this directory."),
"xla_dump_optimized_hlo_proto_to",
flag_values->mutable_xla_dump_optimized_hlo_proto_to(),
"Dump Hlo after all hlo passes are executed as proto binary into "
"this directory."),
tensorflow::Flag(
"xla_dump_prepass_hlo_proto_to",
flag_values->mutable_xla_dump_prepass_hlo_proto_to(),
"Dump compilation artifacts, before hlo passes are executed, as "
"proto binary into this directory."),
"xla_dump_unoptimized_hlo_proto_to",
flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(),
"Dump HLO before any hlo passes are executed as proto binary into "
"this directory."),
tensorflow::Flag("xla_dump_per_pass_hlo_proto_to",
flag_values->mutable_xla_dump_per_pass_hlo_proto_to(),
"Dump HLO after each pass as an HloProto in binary file "
"format into this directory."),
tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),

View File

@ -486,6 +486,7 @@ class Literal {
std::vector<std::unique_ptr<Literal>> elements);
// Returns a string representation of the literal value.
// Warning: this function can take minutes for multi-million element Literals.
string ToString(bool print_layout = false) const;
// Invokes the "per cell" callback for each element in the provided

View File

@ -49,6 +49,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/core:framework_lite",

View File

@ -98,15 +98,25 @@ const std::unique_ptr<ScopedShapedBuffer>& LocalShapedBuffer::shaped_buffer()
return shaped_buffer_;
}
static StatusOr<std::unique_ptr<ScopedShapedBuffer>> ToBuffer(
LocalClient* client, int device_ordinal, const Literal& arg) {
return client->LiteralToShapedBuffer(arg, device_ordinal,
client->backend().memory_allocator());
}
/* static */
LocalShapedBuffer* LocalShapedBuffer::FromLiteral(const Literal& argument) {
LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
const Literal& argument,
const tensorflow::gtl::optional<Shape>& shape_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
std::unique_ptr<ScopedShapedBuffer> buf =
client
->LiteralToShapedBuffer(argument,
/*device_ordinal=*/0,
client->backend().memory_allocator())
.ConsumeValueOrDie();
std::unique_ptr<ScopedShapedBuffer> buf;
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie();
} else {
buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
}
return new LocalShapedBuffer(std::move(buf));
}
@ -120,7 +130,8 @@ CompiledLocalComputation::CompiledLocalComputation(
: executable_(std::move(executable)) {}
StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
const std::vector<Literal>& arguments) {
const std::vector<Literal>& arguments,
const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
LocalClient* client = GetOrCreateLocalClient();
VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas.";
@ -133,7 +144,8 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
GetReplicaCount());
for (int replica = 0; replica < GetReplicaCount(); ++replica) {
pool.Schedule([this, client, replica, &arguments, &results] {
pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
&results] {
StatusOr<int> device_ordinal_status =
client->ReplicaNumberToDeviceOrdinal(replica);
if (!device_ordinal_status.ok()) {
@ -144,18 +156,28 @@ StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
VLOG(3) << "Replica " << replica
<< " mapped to device ordinal for execution: "
<< device_ordinal;
// Transfer arguments in
std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
scoped_buffers.reserve(arguments.size());
for (const Literal& argument : arguments) {
StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed =
client->LiteralToShapedBuffer(
argument, device_ordinal,
client->backend().memory_allocator());
for (int i = 0; i < arguments.size(); ++i) {
const Literal& argument = arguments[i];
const tensorflow::gtl::optional<Shape>& shape_with_layout =
shapes_with_layout[i];
StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed;
if (shape_with_layout) {
std::unique_ptr<Literal> relaid =
argument.Relayout(shape_with_layout.value());
pushed = ToBuffer(client, device_ordinal, *relaid);
} else {
pushed = ToBuffer(client, device_ordinal, argument);
}
if (!pushed.ok()) {
results[replica] = pushed.status();
return;
}
scoped_buffers.push_back(std::move(pushed).ValueOrDie());
}
@ -233,7 +255,8 @@ LocalComputation::LocalComputation(Computation computation)
: computation_(std::move(computation)) {}
StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
const std::vector<Shape>& argument_shapes) {
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options) {
std::vector<const Shape*> argument_shape_pointers;
argument_shape_pointers.reserve(argument_shapes.size());
for (auto& argument_shape : argument_shapes) {
@ -242,6 +265,9 @@ StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
LocalClient* client = GetOrCreateLocalClient();
ExecutableBuildOptions options;
if (build_options != nullptr) {
options = *build_options;
}
TF_ASSIGN_OR_RETURN(
auto local_executable,
client->Compile(computation_, argument_shape_pointers, options));
@ -363,12 +389,6 @@ LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
source, init_value, scatter.computation());
}
ComputationDataHandle LocalComputationBuilder::Select(
const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
const ComputationDataHandle& on_false) {
return builder_.Select(pred, on_true, on_false);
}
ComputationDataHandle LocalComputationBuilder::Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
return builder_.Tuple(elements);
@ -384,6 +404,12 @@ ComputationDataHandle LocalComputationBuilder::Dot(
return builder_.Dot(lhs, rhs);
}
ComputationDataHandle LocalComputationBuilder::DotGeneral(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
const DotDimensionNumbers& dimension_numbers) {
return builder_.DotGeneral(lhs, rhs, dimension_numbers);
}
ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
@ -483,6 +509,15 @@ ComputationDataHandle LocalComputationBuilder::While(
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
(lhs, rhs, broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
_FORWARD( \
method_name, ComputationDataHandle, \
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
const ComputationDataHandle& ehs), \
(lhs, rhs, ehs))
_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
@ -503,6 +538,7 @@ _FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
@ -519,6 +555,7 @@ _FORWARD_UNOP(Sort)
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP
void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
delete local_shaped_buffer;

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -58,7 +59,9 @@ StatusOr<std::unique_ptr<Literal> > TransferFromOutfeedLocalReplica(
// client.
class LocalShapedBuffer {
public:
static LocalShapedBuffer* FromLiteral(const Literal& argument);
static LocalShapedBuffer* FromLiteral(
const Literal& argument,
const tensorflow::gtl::optional<Shape>& shape_with_layout);
LocalShapedBuffer(std::unique_ptr<ScopedShapedBuffer> shaped_buffer);
const std::unique_ptr<ScopedShapedBuffer>& shaped_buffer() const;
std::unique_ptr<Literal> ToLiteral() const;
@ -76,8 +79,15 @@ class LocalShapedBuffer {
class CompiledLocalComputation {
public:
CompiledLocalComputation(std::unique_ptr<LocalExecutable> executable);
// Execute the computation with the given argument literals, and
// with optionally-specified argument layouts. The literals will be
// re-laid out according to the corresponding elements of
// shapes_with_layout.
StatusOr<std::unique_ptr<Literal> > Execute(
const std::vector<Literal>& arguments);
const std::vector<Literal>& arguments,
const std::vector<tensorflow::gtl::optional<Shape> >& shapes_with_layout);
LocalShapedBuffer* ExecuteWithShapedBuffers(
tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles);
@ -93,7 +103,8 @@ class LocalComputation {
public:
LocalComputation(Computation computation);
StatusOr<CompiledLocalComputation*> Compile(
const std::vector<Shape>& argument_shapes);
const std::vector<Shape>& argument_shapes,
const ExecutableBuildOptions* build_options);
const Computation& computation() const;
private:
@ -172,10 +183,6 @@ class LocalComputationBuilder {
const ComputationDataHandle& source,
const ComputationDataHandle& init_value, const LocalComputation& scatter);
ComputationDataHandle Select(const ComputationDataHandle& pred,
const ComputationDataHandle& on_true,
const ComputationDataHandle& on_false);
ComputationDataHandle Tuple(
tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
@ -185,6 +192,10 @@ class LocalComputationBuilder {
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
ComputationDataHandle DotGeneral(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
const DotDimensionNumbers& dimension_numbers);
ComputationDataHandle ConvGeneralDilated(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
tensorflow::gtl::ArraySlice<int64> window_strides,
@ -252,6 +263,14 @@ class LocalComputationBuilder {
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions))
#define _FORWARD_TRIOP(method_name) \
_FORWARD( \
method_name, ComputationDataHandle, \
(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
const ComputationDataHandle& ehs))
_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
@ -272,6 +291,7 @@ class LocalComputationBuilder {
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
@ -288,6 +308,7 @@ class LocalComputationBuilder {
#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP
private:
ComputationBuilder builder_;

View File

@ -27,12 +27,14 @@ limitations under the License.
// ArraySlice<ComputationDataHandle> <- sequence of int
// Literal <-> (nested tuple of) numpy ndarray
// std::vector<Literal> <- sequence of (nested tuple of) ndarray
// Shape <-> pair holding (dtype, dimensions)
// std::vector<Shape> <- sequence of shape information pairs
// Shape -> pair holding (dtype, dimensions)
// <- object duck-typed as xla_client.Shape
// std::vector<Shape> <- sequence of xla_client.Shape objects
// PrimitiveType <- int
// ArraySlice<pair<int64, in64>> <- sequence of int pairs
// PaddingConfig proto <- corresponding Python proto
// ConvolutionDimensionNumbers proto <- corresponding Python proto
// DotDimensionNumbers proto <- corresponding Python proto
//
// Arrows indicate whether a conversion only ever occurs in one
// direction, or whether it is maintained bidirectionally.
@ -55,7 +57,7 @@ limitations under the License.
// translates to a tuple-shaped XLA Literal, whose component subshapes
// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
//
// The Python objects corresponding to C++ Shapes have the type:
// Shapes output by C++ become Python objects with the type:
//
// T = (dtype, S)
// S = DIMENSIONS | TUPLE_SHAPES
@ -176,6 +178,16 @@ tensorflow::ImportNumpy();
}
}
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
if ($1.ok()) {
std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
$result = numpy::PyObjectFromXlaLiteral(*value);
} else {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
return NULL;
}
}
%typemap(out) StatusOr<xla::swig::LocalComputation*> {
if ($1.ok()) {
auto* value = $1.ValueOrDie();
@ -343,15 +355,31 @@ tensorflow::ImportNumpy();
// Shape
%typemap(in) const Shape& (Shape temp) {
Status shape_status = numpy::CheckPyShapeInfo($input);
if (!shape_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temp = numpy::XlaShapeFromPyShapeInfo($input);
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
}
%typemap(in) const tensorflow::gtl::optional<Shape>& (
tensorflow::gtl::optional<Shape> temp) {
if ($input == Py_None) {
temp = tensorflow::gtl::nullopt;
$1 = &temp;
} else {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape($input);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temp = std::move(statusor).ValueOrDie();
$1 = &temp;
}
}
%typemap(out) std::unique_ptr<Shape> {
$result = numpy::PyShapeInfoFromXlaShape(*$1);
}
@ -364,14 +392,37 @@ tensorflow::ImportNumpy();
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
Status shape_status = numpy::CheckPyShapeInfo(o);
if (!shape_status.ok()) {
PyErr_SetString(PyExc_RuntimeError, shape_status.ToString().c_str());
Py_DECREF(o);
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
Py_DECREF(o);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temps.push_back(numpy::XlaShapeFromPyShapeInfo(o));
Py_DECREF(o);
temps.push_back(statusor.ConsumeValueOrDie());
}
$1 = &temps;
}
%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
std::vector<tensorflow::gtl::optional<Shape> > temps) {
if (!PySequence_Check($input)) {
PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
return NULL;
}
const int size = PySequence_Size($input);
for (int i = 0; i < size; ++i) {
PyObject* o = PySequence_GetItem($input, i);
if (o == Py_None) {
temps.push_back(tensorflow::gtl::nullopt);
} else {
StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
Py_DECREF(o);
if (!statusor.ok()) {
PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
return NULL;
}
temps.push_back(statusor.ConsumeValueOrDie());
}
}
$1 = &temps;
}
@ -461,6 +512,135 @@ tensorflow::ImportNumpy();
$1 = temps;
}
// DotDimensionNumbers
%typemap(in) const DotDimensionNumbers&
(DotDimensionNumbers dimension_numbers) {
int length;
/* lhs_contracting_dimensions */
PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
$input, "lhs_contracting_dimensions");
if (!lhs_contracting_dimensions) {
return NULL;
}
length = PySequence_Size(lhs_contracting_dimensions);
if (length == -1) {
Py_DECREF(lhs_contracting_dimensions);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
if (!item) {
Py_DECREF(lhs_contracting_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(lhs_contracting_dimensions);
return NULL;
}
dimension_numbers.add_lhs_contracting_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(lhs_contracting_dimensions);
/* rhs_contracting_dimensions */
PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
$input, "rhs_contracting_dimensions");
if (!lhs_contracting_dimensions) {
return NULL;
}
length = PySequence_Size(rhs_contracting_dimensions);
if (length == -1) {
Py_DECREF(rhs_contracting_dimensions);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
if (!item) {
Py_DECREF(rhs_contracting_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(rhs_contracting_dimensions);
return NULL;
}
dimension_numbers.add_rhs_contracting_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(rhs_contracting_dimensions);
/* lhs_batch_dimensions */
PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
$input, "lhs_batch_dimensions");
if (!lhs_batch_dimensions) {
return NULL;
}
length = PySequence_Size(lhs_batch_dimensions);
if (length == -1) {
Py_DECREF(lhs_batch_dimensions);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
if (!item) {
Py_DECREF(lhs_batch_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(lhs_batch_dimensions);
return NULL;
}
dimension_numbers.add_lhs_batch_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(lhs_batch_dimensions);
/* rhs_batch_dimensions */
PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
$input, "rhs_batch_dimensions");
if (!rhs_batch_dimensions) {
return NULL;
}
length = PySequence_Size(rhs_batch_dimensions);
if (length == -1) {
Py_DECREF(rhs_batch_dimensions);
return NULL;
}
for (int i = 0; i < length; ++i) {
PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
if (!item) {
Py_DECREF(rhs_batch_dimensions);
return NULL;
}
const int64 dimension = numpy::PyIntOrPyLongToLong(item);
if (dimension == -1 && PyErr_Occurred()) {
Py_DECREF(item);
Py_DECREF(rhs_batch_dimensions);
return NULL;
}
dimension_numbers.add_rhs_batch_dimensions(dimension);
Py_DECREF(item);
}
Py_DECREF(rhs_batch_dimensions);
$1 = &dimension_numbers;
}
// PaddingConfig
%typemap(in) const PaddingConfig&
@ -623,6 +803,30 @@ tensorflow::ImportNumpy();
$1 = &dimension_numbers;
}
// ExecutableBuildOptions
%typemap(in) const ExecutableBuildOptions*
(ExecutableBuildOptions build_options) {
if ($input == Py_None) {
$1 = NULL;
} else {
PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
if (!o) {
return NULL;
}
if (o != Py_None) {
if (!PyString_Check(o)) {
PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
return NULL;
}
build_options.set_generate_hlo_graph(PyString_AsString(o));
}
Py_DECREF(o);
$1 = &build_options;
}
}
%ignoreall
%unignore xla;
%unignore xla::swig;
@ -667,6 +871,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Call;
%unignore xla::swig::LocalComputationBuilder::Transpose;
%unignore xla::swig::LocalComputationBuilder::Rev;
%unignore xla::swig::LocalComputationBuilder::Clamp;
%unignore xla::swig::LocalComputationBuilder::Map;
%unignore xla::swig::LocalComputationBuilder::Reduce;
%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
@ -681,6 +886,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Lt;
%unignore xla::swig::LocalComputationBuilder::Le;
%unignore xla::swig::LocalComputationBuilder::Dot;
%unignore xla::swig::LocalComputationBuilder::DotGeneral;
%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
%unignore xla::swig::LocalComputationBuilder::Add;
%unignore xla::swig::LocalComputationBuilder::Sub;
@ -696,6 +902,7 @@ tensorflow::ImportNumpy();
%unignore xla::swig::LocalComputationBuilder::Exp;
%unignore xla::swig::LocalComputationBuilder::Floor;
%unignore xla::swig::LocalComputationBuilder::Ceil;
%unignore xla::swig::LocalComputationBuilder::Round;
%unignore xla::swig::LocalComputationBuilder::Log;
%unignore xla::swig::LocalComputationBuilder::Sign;
%unignore xla::swig::LocalComputationBuilder::Cos;

View File

@ -176,85 +176,107 @@ static string PyObjectCppRepr(PyObject* o) {
return ExtractStringAndDecref(r);
}
Status CheckPyShapeInfo(PyObject* o) {
StatusOr<Shape> XlaShapeFromPyShape(PyObject* o) {
auto error = [o](const string& prefix) {
return InvalidArgument("%s; got %s", prefix.c_str(),
PyObjectCppRepr(o).c_str());
};
// The object is a tuple (a pair)
if (!PyTuple_Check(o)) {
return error("Shape record must be a tuple");
}
if (PyTuple_Size(o) != 2) {
return error("Shape record tuple must be of length 2");
}
// It has a first element, which is a numpy dtype object
PyObject* first = PyTuple_GetItem(o, 0);
if (first == nullptr) {
return error("Tuple has no item 0 (shape dtype)");
}
if (first->ob_type != &PyArrayDescr_Type) {
return error(
"Shape record does not have a numpy dtype as its first element");
}
const int np_type = NumpyTypenum(first);
if (!NumpyTypeIsValid(np_type)) {
return error("Shape record has an invalid integer dtype");
}
// It has a second element, which is a tuple, either of shape
// records or of Python ints
PyObject* second = PyTuple_GetItem(o, 1);
if (!second) {
return error("Tuple has no item 0 (shape dimensions)");
}
if (!PyTuple_Check(second)) {
return error("Shape record does not have a tuple as its second element");
}
const int length = PyTuple_Size(second);
const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
for (int i = 0; i < length; i++) {
PyObject* dimension = PyTuple_GetItem(second, i);
if (element_type == TUPLE) {
VLOG(3) << "element_type is tuple, checking member: " << i;
Status result = CheckPyShapeInfo(dimension);
if (!result.ok()) {
return AddStatus(
result, tensorflow::strings::StrCat("Validating tuple member ", i,
" of ", PyObjectCppRepr(o)));
}
} else if (!CheckPyIntOrLong(dimension)) {
return error("Non-tuple shape record has a non-integer dimension");
auto get_attr = [o, &error](const string& field) -> StatusOr<PyObject*> {
PyObject* result =
PyObject_GetAttrString(o, const_cast<char*>(field.c_str()));
if (result == nullptr) {
return error(tensorflow::strings::StrCat(
"Failed to get attribute of Shape object:", field));
}
return result;
};
auto call_method = [o, &error](const string& method) -> StatusOr<PyObject*> {
PyObject* result =
PyObject_CallMethod(o, const_cast<char*>(method.c_str()), nullptr);
if (result == nullptr) {
return error(tensorflow::strings::StrCat(
"Failed to call method of shape object:", method));
}
return result;
};
PyObject* np_type;
TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype"));
if (np_type->ob_type != &PyArrayDescr_Type) {
return error("Shape attribute np_dtype is not an integer numpy dtype");
}
if (!NumpyTypeIsValid(NumpyTypenum(np_type))) {
return error("Shape attribute np_dtype is not a valid integer numpy dtype");
}
const PrimitiveType element_type =
NumpyTypeToPrimitiveType(NumpyTypenum(np_type));
Py_DECREF(np_type);
return Status::OK();
}
// Precondition: CheckPyShapeInfo(o)
Shape XlaShapeFromPyShapeInfo(PyObject* o) {
const int np_type = NumpyTypenum(PyTuple_GetItem(o, 0));
const PrimitiveType element_type = NumpyTypeToPrimitiveType(np_type);
PyObject* py_dimensions = PyTuple_GetItem(o, 1);
const int length = PyTuple_Size(py_dimensions);
if (element_type == TUPLE) {
PyObject* py_subshapes;
TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes"));
if (!PyTuple_Check(py_subshapes)) {
return error(
"Return value of Shape method tuple_shapes() is not a tuple");
}
const int length = PyTuple_Size(py_subshapes);
std::vector<Shape> subshapes;
subshapes.reserve(length);
for (int i = 0; i < length; i++) {
subshapes.push_back(
XlaShapeFromPyShapeInfo(PyTuple_GetItem(py_dimensions, i)));
TF_ASSIGN_OR_RETURN(
const Shape& subshape,
XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i)));
subshapes.push_back(subshape);
}
Py_DECREF(py_subshapes);
return ShapeUtil::MakeTupleShape(subshapes);
} else {
PyObject* py_dimensions;
PyObject* py_minor_to_major;
TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions"));
TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major"));
if (!PyTuple_Check(py_dimensions)) {
return error("Return value of Shape method dimensions() is not a tuple");
}
if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) {
return error(
"Return value of Shape method minor_to_major() is neither a tuple "
"nor None");
}
const int length = PyTuple_Size(py_dimensions);
if (py_minor_to_major != Py_None &&
length != PyTuple_Size(py_minor_to_major)) {
return error(
"Shape methods dimensions() and minor_to_major() return "
"different-length tuples");
}
std::vector<int64> dimensions(length);
std::vector<int64> minor_to_major(length);
for (int i = 0; i < length; i++) {
dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i));
if (dimensions[i] == -1) {
CHECK(!PyErr_Occurred());
if (dimensions[i] == -1 && PyErr_Occurred()) {
return error("Dimension is not an int");
}
if (py_minor_to_major != Py_None) {
minor_to_major[i] =
PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i));
if (minor_to_major[i] == -1 && PyErr_Occurred()) {
return error("Minor-to-major value is not an int");
}
}
}
return ShapeUtil::MakeShape(element_type, dimensions);
bool with_layout = py_minor_to_major != Py_None;
Py_DECREF(py_dimensions);
Py_DECREF(py_minor_to_major);
if (with_layout) {
return ShapeUtil::MakeShapeWithLayout(element_type, dimensions,
minor_to_major);
} else {
return ShapeUtil::MakeShape(element_type, dimensions);
}
}
}

View File

@ -56,15 +56,11 @@ bool NumpyTypeIsValid(int np_type);
// The return value is a new reference.
PyObject* PyShapeInfoFromXlaShape(const Shape& shape);
// Returns the outcome of a best-effort check that the Python object
// is a pair of the form (numpy dtype, dimensions), as produced by
// PyShapeInfoFromXlaShape.
Status CheckPyShapeInfo(PyObject* o);
// Performs the inverse conversion to that of PyShapeInfoFromXlaShape.
// Converts a Python object with a method interface mathing that of
// xla_client.Shape into an XLA Shape object.
//
// The return value is a new reference.
Shape XlaShapeFromPyShapeInfo(PyObject* o);
StatusOr<Shape> XlaShapeFromPyShape(PyObject* o);
// Converts a PyObject that represents operation metadata into protocol buffer
// form.

View File

@ -89,6 +89,7 @@ _UNARY_OPS = [
'Abs',
'Exp',
'Floor',
'Round',
'Ceil',
'Log',
'Sign',
@ -155,9 +156,14 @@ class LocalBuffer(object):
self._delete = c_api.DeleteLocalShapedBuffer
@staticmethod
def from_py(npval):
def from_py(npval, layout_fn=None):
npval = require_numpy_array_layout(npval)
return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval))
if layout_fn:
shape = Shape.from_numpy(npval)
shape = shape.map_leaves(layout_fn)
else:
shape = None
return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape))
def to_py(self):
return self.c_local_shaped_buffer.ToLiteral()
@ -182,13 +188,17 @@ class Shape(object):
represents an XLA tuple.
"""
def __init__(self, np_dtype, dimensions):
def __init__(self, np_dtype, dimensions, minor_to_major=None):
assert isinstance(dimensions, tuple)
self.np_dtype = np_dtype
self._dimensions = dimensions
self._minor_to_major = minor_to_major
self._check_minor_to_major()
def __repr__(self):
return 'xla_client.Shape(np_dtype={!r}, dimensions={!r})'.format(
self.np_dtype, self._dimensions)
return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, '
'minor_to_major={!r})').format(self.np_dtype, self._dimensions,
self._minor_to_major)
def element_type(self):
return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)]
@ -201,11 +211,49 @@ class Shape(object):
raise ValueError('Tuple shape has no dimensions')
return self._dimensions
def minor_to_major(self):
return self._minor_to_major
def tuple_shapes(self):
if not self.is_tuple():
raise ValueError('Shape is not a tuple shape')
return self._dimensions
def rank(self):
return len(self.dimensions())
def map_leaves(self, f):
"""Map f over each leaf-level array subshape.
Args:
f: The function to apply. Whenever f returns None, the identity is
applied instead.
Returns:
A new Shape with the mapped leaves.
"""
if self.is_tuple():
children = tuple(child.map_leaves(f) for child in self.tuple_shapes())
return Shape(np.dtype('O'), children)
else:
mapped = f(self)
return self if mapped is None else mapped
def _check_minor_to_major(self):
mtm = self._minor_to_major
if self.is_tuple():
assert mtm is None, self
if mtm is not None:
assert self.rank() == len(mtm), self
assert sorted(mtm) == range(len(mtm)), self
def update_minor_to_major(self, minor_to_major):
if not isinstance(minor_to_major, tuple):
raise TypeError('minor_to_major must be a tuple')
updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major)
updated._check_minor_to_major() # pylint: disable=protected-access
return updated
@staticmethod
def from_numpy(npval):
@ -222,23 +270,10 @@ def _wrap_shape(shape_info):
dtype, dims = shape_info
element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)]
if element_type == xla_data_pb2.TUPLE:
dims = [_wrap_shape(subshape_info) for subshape_info in dims]
dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims)
return Shape(dtype, dims)
def _unwrap_shape(shape):
if shape.is_tuple():
components = tuple(
_unwrap_shape(subshape) for subshape in shape.tuple_shapes())
else:
components = shape.dimensions()
return (shape.np_dtype, components)
def _unwrap_shapes(shapes):
return [_unwrap_shape(shape) for shape in shapes]
def _wrap_data_handle(handle):
cdh = xla_data_pb2.ComputationDataHandle()
cdh.handle = handle
@ -260,6 +295,17 @@ def require_numpy_array_layout(value):
return np.require(value, requirements=['C', 'A'])
class CompileOptions(object):
"""Python object for XLA compile options.
These options can be passed to the 'compile' step when using a local XLA
client.
"""
def __init__(self):
self.generate_hlo_graph = None
def transfer_to_infeed(value, replica_number=None):
"""Transfers the given value into the XLA infeed queue.
@ -291,8 +337,7 @@ def transfer_from_outfeed(shape, replica_number=None):
Returns:
The literal value that is produced from the outfeed queue.
"""
return c_api.TransferFromOutfeedLocalReplica(
_unwrap_shape(shape), replica_number or 0)
return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0)
class LocalComputation(object):
@ -313,22 +358,39 @@ class LocalComputation(object):
else:
self._delete = c_api.DeleteLocalComputation
def Compile(self, argument_shapes=()):
def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None):
if self.is_compiled:
raise ValueError('Attempt to compile a compiled local XLA computation.')
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
return LocalComputation(
self.c_local_computation.Compile(_unwrap_shapes(argument_shapes)),
self.c_local_computation.Compile(argument_shapes, compile_options),
is_compiled=True)
def CompileWithExampleArguments(self, arguments=()):
def CompileWithExampleArguments(self,
arguments=(),
compile_options=None,
layout_fn=None):
return self.Compile(
argument_shapes=[Shape.from_numpy(arg) for arg in arguments])
argument_shapes=[Shape.from_numpy(arg) for arg in arguments],
compile_options=compile_options,
layout_fn=layout_fn)
def Execute(self, arguments=()):
def Execute(self, arguments=(), layout_fn=None):
"""Execute with Python values as arguments and return value."""
if not self.is_compiled:
raise ValueError('Cannot execute an uncompiled local XLA computation.')
argument_shapes = [Shape.from_numpy(arg) for arg in arguments]
if layout_fn:
argument_shapes = [
shape.map_leaves(layout_fn) for shape in argument_shapes
]
else:
argument_shapes = [None for shape in argument_shapes]
arguments = tuple(map(require_numpy_array_layout, arguments))
return self.c_local_computation.Execute(arguments)
return self.c_local_computation.Execute(arguments, argument_shapes)
def ExecuteWithLocalBuffers(self, arguments=()):
"""Execute with LocalBuffer arguments and return value."""
@ -384,7 +446,7 @@ class ComputationBuilder(object):
Returns:
A ComputationDataHandle message.
"""
return _wrap_data_handle(self._client.Infeed(_unwrap_shape(shape)))
return _wrap_data_handle(self._client.Infeed(shape))
def Outfeed(self, operand):
"""Enqueues an outfeed op onto the computation.
@ -393,7 +455,7 @@ class ComputationBuilder(object):
outfeed queue for subsequent dequeue via the client API.
"""
self._client.Outfeed(
_unwrap_data_handle(operand), _unwrap_shape(self.GetShape(operand)),
_unwrap_data_handle(operand), self.GetShape(operand),
''.encode('utf-8'))
def Constant(self, value):
@ -484,8 +546,7 @@ class ComputationBuilder(object):
parameter_num = next(self._parameter_numbering)
return _wrap_data_handle(
self._client.Parameter(
parameter_num, _unwrap_shape(shape), name.encode('utf8')))
self._client.Parameter(parameter_num, shape, name.encode('utf8')))
def ParameterFromNumpy(self, value, name=None, parameter_num=None):
"""Enqueues a Parameter op onto the computation.
@ -606,6 +667,13 @@ class ComputationBuilder(object):
return _wrap_data_handle(
self._client.Rev(_unwrap_data_handle(operand), dimensions))
def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin
"""Clamp op."""
return _wrap_data_handle(
self._client.Clamp(_unwrap_data_handle(min),
_unwrap_data_handle(operand),
_unwrap_data_handle(max)))
def SelectAndScatter(self, operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter):
"""Select and scatter op, used by the gradient of ReduceWindow.
@ -825,8 +893,7 @@ class ComputationBuilder(object):
shape = Shape(self.GetShape(mu).np_dtype, dims)
return _wrap_data_handle(
self._client.RngNormal(
_unwrap_data_handle(mu), _unwrap_data_handle(sigma),
_unwrap_shape(shape)))
_unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape))
def RngUniform(self, a, b, dims):
"""Enqueues an RngUniform operation onto the computation.
@ -846,8 +913,7 @@ class ComputationBuilder(object):
shape = Shape(self.GetShape(a).np_dtype, dims)
return _wrap_data_handle(
self._client.RngUniform(
_unwrap_data_handle(a), _unwrap_data_handle(b),
_unwrap_shape(shape)))
_unwrap_data_handle(a), _unwrap_data_handle(b), shape))
def While(self, cond, body, init):
"""Enqueues a While operation onto the computation.
@ -865,10 +931,37 @@ class ComputationBuilder(object):
_unwrap_data_handle(init)))
def Dot(self, lhs, rhs):
"""Matrix multiplication between lhs and rhs."""
"""Enqueues a dot operation onto the computation.
Args:
lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array.
rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array.
Returns: a ComputationDataHandle representing the Dot operation.
"""
return _wrap_data_handle(
self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs)))
def DotGeneral(self, lhs, rhs, dimension_numbers):
"""Enqueues a general dot operation onto the computation.
Args:
lhs: ComputationDataHandle for the left-hand-side array.
rhs: ComputationDataHandle for the right-hand-side array.
dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested
tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of
integers representing the dimensions to treat as contracting dimensions
and batch dimensions on each input operand.
Returns: a ComputationDataHandle representing the DotGeneral operation.
"""
if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers):
dimension_numbers = GetDotDimensionsFromLists(dimension_numbers)
return _wrap_data_handle(
self._client.DotGeneral(
_unwrap_data_handle(lhs), _unwrap_data_handle(rhs),
dimension_numbers))
def Conv(self, lhs, rhs, window_strides, padding):
"""Enqueues a Conv operation onto the computation.
@ -979,7 +1072,7 @@ def initialize_replica_count(replica_count):
Args:
replica_count: number of replicas that are desired for set up during XLA
initalization.
initialization.
Raises:
A runtime exception if the XLA service has already been initialized.
@ -1005,3 +1098,13 @@ def GetPaddingConfigFromTriples(triples):
dimension.edge_padding_high = hi
dimension.interior_padding = interior
return padding_config
def GetDotDimensionsFromLists(dimension_numbers):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
dot_dims_proto = xla_data_pb2.DotDimensionNumbers()
dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract)
dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract)
dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch)
dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch)
return dot_dims_proto

View File

@ -444,6 +444,30 @@ class SingleOpTest(LocalComputationTest):
c.Dot(c.Constant(lhs), c.Constant(rhs))
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs))
def testDotGeneral(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))
dimension_numbers = (([2], [1]), ([0], [0]))
c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
def testDotGeneralWithDotDimensionNumbersProto(self):
c = self._NewComputation()
rng = np.random.RandomState(0)
lhs = NumpyArrayF32(rng.randn(10, 3, 4))
rhs = NumpyArrayF32(rng.randn(10, 4, 5))
dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers()
dimension_numbers.lhs_contracting_dimensions.append(2)
dimension_numbers.rhs_contracting_dimensions.append(1)
dimension_numbers.lhs_batch_dimensions.append(0)
dimension_numbers.rhs_batch_dimensions.append(0)
c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers)
self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs))
def testConvF32Same(self):
c = self._NewComputation()
a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32")
@ -496,6 +520,12 @@ class SingleOpTest(LocalComputationTest):
c.Exp(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.exp(arr))
def testRound(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
c.Round(c.Constant(arr))
self._ExecuteAndCompareClose(c, expected=np.round(arr))
def testLog(self):
c = self._NewComputation()
arr = NumpyArrayF32([3.3, 12.1])
@ -699,6 +729,23 @@ class SingleOpTest(LocalComputationTest):
self._ExecuteAndCompareExact(
c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]])
def testClampF32(self):
c = self._NewComputation()
c.Clamp(
c.Constant(NumpyArrayF32(-1)),
c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])),
c.Constant(NumpyArrayF32(2)))
self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2])
# TODO(b/72689392): re-enable when bug S32 resolved
def DISABLED_testClampS32(self):
c = self._NewComputation()
c.Clamp(
c.Constant(NumpyArrayS32(-1)),
c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])),
c.Constant(NumpyArrayS32(2)))
self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2])
def testSelect(self):
c = self._NewComputation()
c.Select(

View File

@ -509,6 +509,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
@ -1110,8 +1111,6 @@ cc_library(
":hlo",
":hlo_evaluator",
":hlo_pass",
":tuple_util",
":while_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
],
@ -1156,6 +1155,34 @@ tf_cc_test(
],
)
cc_library(
name = "implicit_broadcast_remover",
srcs = ["implicit_broadcast_remover.cc"],
hdrs = ["implicit_broadcast_remover.h"],
deps = [
":hlo",
":hlo_dce",
":hlo_pass",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "implicit_broadcast_remover_test",
srcs = ["implicit_broadcast_remover_test.cc"],
deps = [
":hlo_matchers",
":implicit_broadcast_remover",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/tests:hlo_verified_test_base",
],
)
cc_library(
name = "dot_decomposer",
srcs = ["dot_decomposer.cc"],
@ -1825,7 +1852,9 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tools/parser:hlo_parser",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
@ -1856,6 +1885,7 @@ cc_library(
":hlo",
":hlo_graph_dumper",
":hlo_pass",
":hlo_proto_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",

View File

@ -1618,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
}
// A Transpose feeding a reduce can simply permute the reduction dimensions
// field.
if (arg->opcode() == HloOpcode::kTranspose) {
// field if the output of the reduce is a vector or scalar. Higher ranked
// result may require a transpose of the output.
if (ShapeUtil::Rank(reduce->shape()) <= 1 &&
arg->opcode() == HloOpcode::kTranspose) {
auto transpose_dimensions = arg->dimensions();
std::vector<int64> new_reduce_dimensions;
for (auto dim : dimensions) {

View File

@ -997,14 +997,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
auto color = single_colored_set.first;
VLOG(2) << "Simulating heap for color " << color;
int64 alignment = assignment->color_alignment_(color);
HeapSimulator::Options options;
options.buffers_to_assign = &single_colored_set.second;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
MakeUnique<LazyBestFitHeap>(alignment)),
assignment->module(), module_sequence,
assignment->points_to_analysis(),
assignment->buffer_size_,
&single_colored_set.second));
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}
@ -1024,14 +1025,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
auto color = single_colored_set.first;
VLOG(2) << "Simulating heap for color " << color;
int64 alignment = assignment->color_alignment_(color);
HeapSimulator::Options options;
options.buffers_to_assign = &single_colored_set.second;
TF_ASSIGN_OR_RETURN(
const HeapSimulator::Result result,
HeapSimulator::Run(MakeUnique<DecreasingSizeRunsHeap>(
MakeUnique<LazyBestFitHeap>(alignment)),
*computation, *instruction_sequence,
assignment->points_to_analysis(),
assignment->buffer_size_,
&single_colored_set.second));
assignment->buffer_size_, options));
AssignBuffersFromHeapSimulator(result, assignment,
single_colored_set.first);
}

View File

@ -72,8 +72,18 @@ class AotCompilationOptions {
// Returns the ID of the platform to which these options apply.
virtual perftools::gputools::Platform::Id PlatformId() const = 0;
// Optional allocator that may be used for allocating temp space on the device
// during compilation.
DeviceMemoryAllocator* device_allocator() const { return device_allocator_; }
void set_device_allocator(DeviceMemoryAllocator* device_allocator) {
device_allocator_ = device_allocator;
}
protected:
AotCompilationOptions() = default;
private:
DeviceMemoryAllocator* device_allocator_ = nullptr;
};
// Abstract compiler interface that is subclassed for compilation on a
@ -99,9 +109,16 @@ class Compiler {
// Runs Hlo passes to optimize the given Hlo module, returns the optimized
// module.
//
// If device_allocator is not null, the compiler may use it to allocate temp
// space on the device for use during compilation. For example, the compiler
// may allocate buffers on the device and then run variants of a given
// algorithm over those buffers, to see which variant is fastest. Any space
// allocated should be deallocated before this function returns.
virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* executor) = 0;
perftools::gputools::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator) = 0;
// Compiles the HLO module for execution on a device given by the executor,
// and returns an executable object or an error status. No HLO passes are
@ -112,21 +129,27 @@ class Compiler {
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// device_allocator is optional; see RunHloPasses.
//
// Use the overload below to compile computations that run in parallel.
virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* executor) = 0;
perftools::gputools::StreamExecutor* executor,
DeviceMemoryAllocator* device_allocator) = 0;
// Compiles a set of HLO modules that can run in parallel, potentially
// communicating data between the modules, and returns a corresponding
// sequence of executable objects.
//
// device_allocator is optional; see RunHloPasses.
//
// TODO(b/68666782): Remove this method after adding support for multiple
// modules to RunHloPasses and RunBackends.
virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
std::vector<std::unique_ptr<HloModule>> modules,
std::vector<std::vector<perftools::gputools::StreamExecutor*>>
stream_exec) = 0;
stream_exec,
DeviceMemoryAllocator* device_allocator) = 0;
// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.

View File

@ -437,7 +437,8 @@ Status VerifyLlvmModule(const llvm::Module& llvm_module) {
StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* /*stream_exec*/) {
perftools::gputools::StreamExecutor* /*stream_exec*/,
DeviceMemoryAllocator* /*device_allocator*/) {
VLOG(2) << "Before optimization:";
XLA_VLOG_LINES(2, module->ToString());
@ -450,7 +451,8 @@ StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) {
perftools::gputools::StreamExecutor* stream_exec,
DeviceMemoryAllocator* /*device_allocator*/) {
const string timer_message =
"Compiling [" + module->name() + "] for CPU using JIT";
XLA_SCOPED_LOGGING_TIMER(timer_message);
@ -517,8 +519,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// ownership is std::moved.
const bool embed_ir_in_executable =
module->config().debug_options().xla_embed_ir_in_executable();
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
const string xla_dump_optimized_hlo_proto_to =
module->config().debug_options().xla_dump_optimized_hlo_proto_to();
if (options::CpuParallelBackendRequested(module->config())) {
VLOG(1) << "Using parallel cpu backend";
@ -538,10 +540,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
if (!xla_dump_hlo_proto_to.empty()) {
if (!xla_dump_optimized_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
proto, xla_dump_optimized_hlo_proto_to, module->name()));
}
// If we are using the parallel CPU backend, we need to create map from
@ -647,10 +649,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
if (!xla_dump_hlo_proto_to.empty()) {
if (!xla_dump_optimized_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
proto, xla_dump_optimized_hlo_proto_to, module->name()));
}
// Each computation is a single function. Emit all embedded computations
@ -826,12 +828,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (!xla_dump_hlo_proto_to.empty()) {
const string xla_dump_optimized_hlo_proto_to =
module->config().debug_options().xla_dump_optimized_hlo_proto_to();
if (!xla_dump_optimized_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
proto, xla_dump_optimized_hlo_proto_to, module->name()));
}
IrEmitter ir_emitter(*module, *assignment, &llvm_module,

View File

@ -118,11 +118,13 @@ class CpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) override;
perftools::gputools::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) override;
StatusOr<std::unique_ptr<Executable>> RunBackend(
std::unique_ptr<HloModule> module,
perftools::gputools::StreamExecutor* stream_exec) override;
perftools::gputools::StreamExecutor* stream_exec,
DeviceMemoryAllocator* device_allocator) override;
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,

View File

@ -479,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
Status IrEmitter::HandleSort(HloInstruction* sort) {
// TODO(b/26783907): Implement sort on CPU.
return Unimplemented("Sort is not supported on CPU (b/26783907).");
return Unimplemented("Sort is not implemented on CPU.");
}
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
@ -522,7 +522,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
// TODO(b/31410564): Implement dilation for reduce-window.
if (window_util::HasDilation(window)) {
return Unimplemented(
"Dilation for reduce-window not implemented on CPU. See b/31410564.");
"Dilation for ReduceWindow is not implemented on CPU.");
}
// The called computation should have been emitted previously.
@ -625,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
// TODO(b/31410564): Implement dilation for select-and-scatter.
if (window_util::HasDilation(window)) {
return Unimplemented(
"Dilation for select-and-scatter not implemented on CPU. "
"See b/31410564.");
"Dilation for SelectAndScatter is not implemented on CPU. ");
}
// The select and scatter computations should have been emitted previously.
@ -1196,8 +1195,7 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
}
// TODO(b/33011107): Support cross replica sum on CPU.
return Unimplemented(
"Cross replica sum is not implemented on CPU. See b/33011107.");
return Unimplemented("CrossReplicaSum is not implemented on CPU.");
}
// Fills up the free variables in 'index_with_free_var' with values from
@ -1811,12 +1809,12 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) {
Status IrEmitter::HandleSend(HloInstruction* send) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send is not implemented on CPU. See b/33942983.");
return Unimplemented("Send is not implemented on CPU.");
}
Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
return Unimplemented("Send-done is not implemented on CPU.");
}
Status IrEmitter::HandleSlice(HloInstruction* slice) {
@ -1981,12 +1979,12 @@ Status IrEmitter::HandleDynamicUpdateSlice(
Status IrEmitter::HandleRecv(HloInstruction* recv) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
return Unimplemented("Recv is not implemented on CPU.");
}
Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
// TODO(b/33942983): Support Send/Recv on CPU.
return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
return Unimplemented("Recv-done is not implemented on CPU.");
}
Status IrEmitter::HandlePad(HloInstruction* pad) {
@ -1995,10 +1993,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
for (auto& padding_dimension : pad->padding_config().dimensions()) {
if (padding_dimension.edge_padding_low() < 0 ||
padding_dimension.edge_padding_high() < 0) {
return Unimplemented(
"Negative padding not supported in the CPU backend (b/34628603); "
"this should have been eliminated at the HLO level: %s",
pad->padding_config().ShortDebugString().c_str());
return InternalErrorStrCat(
"Encountered negative padding in IrEmitter on CPU. "
"This should have been eliminated at the HLO level. ",
pad->ToString());
}
}

View File

@ -24,7 +24,7 @@ limitations under the License.
namespace xla {
StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
perftools::gputools::Platform* platform,
const perftools::gputools::Platform* platform,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors)
: DeviceMemoryAllocator(platform),

View File

@ -33,7 +33,7 @@ class DeviceMemoryAllocator {
public:
// Parameter platform indicates which platform the allocator allocates memory
// on. Must be non-null.
explicit DeviceMemoryAllocator(perftools::gputools::Platform* platform)
explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform)
: platform_(platform) {}
virtual ~DeviceMemoryAllocator() {}
@ -49,14 +49,14 @@ class DeviceMemoryAllocator {
int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0;
// Return the platform that the allocator allocates memory on.
perftools::gputools::Platform* platform() const { return platform_; }
const perftools::gputools::Platform* platform() const { return platform_; }
// Can we call Deallocate() as soon as a computation has been scheduled on
// a stream, or do we have to wait for the computation to complete first?
virtual bool AllowsAsynchronousDeallocation() const = 0;
protected:
perftools::gputools::Platform* platform_;
const perftools::gputools::Platform* platform_;
};
// Default memory allocator for a platform which uses
@ -64,7 +64,7 @@ class DeviceMemoryAllocator {
class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
public:
StreamExecutorMemoryAllocator(
perftools::gputools::Platform* platform,
const perftools::gputools::Platform* platform,
tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
stream_executors);

View File

@ -428,7 +428,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
ir_builder_);
case HloOpcode::kSign: {
// TODO(b/32151903): Ensure consistent sign behavior for -0.0
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
@ -870,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) const {
if (prim_type != F32) {
return Unimplemented("inverse erf only implemented for F32 (b/34339814)");
// TODO(b/34339814): Implement inverse erf for F64.
return Unimplemented(
"Inverse erf is only implemented for element "
"type F32.");
}
auto getFloat = [&](const float f) {
return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
@ -1040,17 +1043,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
lhs_value, rhs_value, ir_builder_);
case HloOpcode::kMinimum:
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
lhs_value, rhs_value);
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
case HloOpcode::kMaximum:
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
lhs_value, rhs_value);
return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
return ir_builder_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
@ -1067,6 +1062,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
}
}
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* rhs_value,
bool is_signed) const {
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
lhs_value, rhs_value),
lhs_value, rhs_value);
}
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm::Value* rhs_value,
bool is_signed) const {
return ir_builder_->CreateSelect(
ir_builder_->CreateICmp(
is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
lhs_value, rhs_value),
lhs_value, rhs_value);
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
int64 operand_no) const {
@ -1363,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
PrimitiveType prim_type = hlo->shape().element_type();
if (primitive_util::IsFloatingPointType(prim_type)) {
return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
} else if (primitive_util::IsIntegralType(prim_type)) {
bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
return EmitIntegralMin(
max_value, EmitIntegralMax(min_value, arg_value, is_signed),
is_signed);
} else {
return Unimplemented("Clamp unimplemented for %s",
PrimitiveType_Name(prim_type).c_str());
}
};
case HloOpcode::kReducePrecision:
return [this, hlo, &operand_to_generator](

View File

@ -86,6 +86,12 @@ class ElementalIrEmitter {
virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) const;
llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const;
llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const;
virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
llvm::Value* value) const;

View File

@ -131,6 +131,7 @@ cc_library(
"ir_emitter_context.h",
],
deps = [
":cudnn_convolution_runner",
":elemental_ir_emitter",
":gpu_constants",
":gpu_executable",
@ -262,6 +263,7 @@ cc_library(
],
deps = [
":buffer_allocations",
":cudnn_convolution_runner",
":infeed_manager",
":ir_emission_utils",
":partition_assignment",
@ -309,9 +311,41 @@ cc_library(
)
cc_library(
name = "convolution_folding",
srcs = ["convolution_folding.cc"],
hdrs = ["convolution_folding.h"],
name = "cudnn_convolution_algorithm_picker",
srcs = ["cudnn_convolution_algorithm_picker.cc"],
hdrs = ["cudnn_convolution_algorithm_picker.h"],
deps = [
":cudnn_convolution_runner",
":gpu_executable",
":ir_emission_utils",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_library(
name = "cudnn_convolution_runner",
srcs = ["cudnn_convolution_runner.cc"],
hdrs = ["cudnn_convolution_runner.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:stream_executor_no_cuda",
],
)
cc_library(
name = "cudnn_convolution_rewriter",
srcs = ["cudnn_convolution_rewriter.cc"],
hdrs = ["cudnn_convolution_rewriter.h"],
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla:literal_util",
@ -325,15 +359,18 @@ cc_library(
)
tf_cc_test(
name = "convolution_folding_test",
srcs = ["convolution_folding_test.cc"],
name = "cudnn_convolution_rewriter_test",
srcs = ["cudnn_convolution_rewriter_test.cc"],
deps = [
":convolution_folding",
":cudnn_convolution_rewriter",
":ir_emission_utils",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:test",
],
)
@ -446,7 +483,8 @@ cc_library(
srcs = ["gpu_compiler.cc"],
hdrs = ["gpu_compiler.h"],
deps = [
":convolution_folding",
":cudnn_convolution_algorithm_picker",
":cudnn_convolution_rewriter",
":fusion_merger",
":gpu_constants",
":gpu_copy_insertion",
@ -514,7 +552,6 @@ cc_library(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"@llvm//:core",
],
)

Some files were not shown because too many files have changed in this diff Show More