Automated g4 rollback of changelist 196691101

PiperOrigin-RevId: 196879933
This commit is contained in:
Peter Hawkins 2018-05-16 13:34:10 -07:00 committed by TensorFlower Gardener
parent c9e4705d62
commit 41af9782f4
22 changed files with 507 additions and 256 deletions

View File

@ -551,14 +551,16 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
"%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto add_profile_line = HasSubstr(
"%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto tuple_profile_line = HasSubstr(
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
"%dot.0.2, f32[2,2]{1,0} %add.0.5)");
"%dot.0.4, f32[2,2]{1,0} %add.0.6)");
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,

View File

@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
const XlaDevice::Metadata* metadata;
const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
// TODO(b/77671268): We don't set variable_representation_shape_fn here. This
// is restricted to Variables, but we need something like this to apply to
// normal Tensors too.
if (metadata) {
options.shape_representation_fn = metadata->shape_representation_fn();
}
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
@ -164,9 +164,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
variables, ctx, &kernel, &executable,
/*compile_options=*/nullptr));
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
&kernel, &executable, &compile_options));
VLOG(1) << "Executing XLA Computation...";

View File

@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile(
options.client = metadata.client();
options.flib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
options.shape_representation_fn = metadata.shape_representation_fn();
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
result, executable,
/*compile_options=*/nullptr);
result, executable, &compile_options);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {

View File

@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
(void)registrations;
std::unique_ptr<XlaDevice> device;
TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
DEVICE_CPU_XLA_JIT, options, name_prefix,
registration,
/*transfer_as_literal=*/false, &device));
TF_RETURN_IF_ERROR(
XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false,
/*shape_representation_fn=*/{}, &device));
devices->push_back(device.release());
return Status::OK();
}

View File

@ -110,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
@ -129,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
device->reset(new XlaDevice(options, attrs, device_ordinal,
DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal));
device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type)
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform) {}
platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@ -170,17 +174,20 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK();
}
XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceAttributes& attrs, int device_ordinal,
const DeviceType& jit_device_name, se::Platform* platform,
bool transfer_as_literal)
XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name),
xla_metadata_(device_ordinal, platform, jit_device_name,
shape_representation_fn),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
transfer_as_literal_(transfer_as_literal) {
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name;
}
@ -232,8 +239,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
gpu_device_info_->default_context =
new XlaDeviceContext(stream, client(), transfer_as_literal_);
gpu_device_info_->default_context = new XlaDeviceContext(
stream, client(), transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
@ -247,7 +254,8 @@ Status XlaDevice::FillContextMap(const Graph* graph,
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
shape_representation_fn_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@ -294,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
XlaTransferManager manager(stream, client(), transfer_as_literal_);
XlaTransferManager manager(stream, client(), transfer_as_literal_,
shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;

View File

@ -17,8 +17,7 @@ limitations under the License.
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
// is managed by XLA.
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
@ -27,6 +26,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@ -50,7 +50,8 @@ class XlaDevice : public LocalDevice {
class Metadata {
public:
Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type);
const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
// The index of the device on this host.
int device_ordinal() const;
@ -58,11 +59,15 @@ class XlaDevice : public LocalDevice {
se::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
}
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@ -76,16 +81,19 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
static Status Create(const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
std::unique_ptr<XlaDevice>* device);
static Status Create(
const string& platform_name, const string& device_name,
int device_ordinal, const string& jit_device_name,
const SessionOptions& options, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
std::unique_ptr<XlaDevice>* device);
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal);
se::Platform* platform, bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
@ -116,8 +124,8 @@ class XlaDevice : public LocalDevice {
// The name of the device that is used to compile Ops for this XlaDevice.
DeviceType jit_device_name_;
// Memory allocator associated with this device.
Allocator* xla_allocator_; // Not owned.
se::Platform* platform_; // Not owned.
Allocator* xla_allocator_; // Not owned.
se::Platform* platform_; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
@ -126,6 +134,7 @@ class XlaDevice : public LocalDevice {
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// If set, holds default device context (that we must Unref)
// and its stream.

View File

@ -47,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(se::Stream* stream,
xla::LocalClient* client,
bool transfer_as_literal)
XlaTransferManager::XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: stream_(stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal) {}
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {}
Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const {
@ -76,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice(
transfer_manager_->TransferLiteralFromDevice(
stream_->parent(), shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString();
return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
// Reshape the tensor back to its declared shape.
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
return errors::Internal(
"Tensor::CopyFrom failed when copying from XLA device to CPU");
}
return Status::OK();
}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@ -96,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
TensorShape shape;
if (shape_representation_fn_) {
shape = shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype());
} else {
shape = device_tensor->shape();
}
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), device_tensor->shape(), client_,
device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
@ -106,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
}
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
Status status;
if (transfer_as_literal_) {
status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
@ -171,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done(Status::OK());
}
XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
bool transfer_as_literal)
: manager_(stream, client, transfer_as_literal) {}
XlaDeviceContext::XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: manager_(stream, client, transfer_as_literal,
std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h"
@ -45,8 +46,9 @@ class XlaDeviceAllocator : public Allocator {
// Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager {
public:
explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client,
bool transfer_as_literal);
explicit XlaTransferManager(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@ -69,7 +71,8 @@ class XlaTransferManager {
// Transfer manager, for marshalling data to and from the device.
xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
bool transfer_as_literal_;
const bool transfer_as_literal_;
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@ -77,8 +80,9 @@ class XlaTransferManager {
// wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
public:
explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
bool transfer_as_literal);
explicit XlaDeviceContext(
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,

View File

@ -48,7 +48,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
Status status =
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
/*transfer_as_literal=*/false, &device);
/*transfer_as_literal=*/false,
/*shape_representation_fn=*/{}, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status;

View File

@ -195,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs(
OP_REQUIRES_OK(
ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
const_tensor.dtype(), const_tensor.shape(),
client_, stream->parent()->device_ordinal()));
}
Device* device = dynamic_cast<Device*>(ctx->device());
OP_REQUIRES(ctx, device != nullptr,

View File

@ -42,7 +42,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
"//tensorflow/python:random_seed",
"//tensorflow/python:session",
@ -58,7 +58,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -72,7 +72,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -93,7 +93,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -111,7 +111,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@ -127,7 +127,7 @@ tf_xla_py_test(
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@ -141,7 +141,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -156,7 +156,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -170,7 +170,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -184,7 +184,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
@ -209,7 +209,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
@ -225,7 +225,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -241,7 +241,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -263,7 +263,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -291,7 +291,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -307,7 +307,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -326,7 +326,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
@ -346,7 +346,7 @@ tf_xla_py_test(
"//tensorflow/contrib/signal:signal_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops",
],
@ -360,7 +360,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -372,7 +372,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -388,7 +388,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -403,7 +403,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:image_ops",
"//tensorflow/python:platform_test",
],
@ -431,7 +431,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@ -446,7 +446,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -458,7 +458,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -472,7 +472,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -485,7 +485,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -498,7 +498,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@ -513,7 +513,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@ -530,7 +530,7 @@ tf_xla_py_test(
],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@ -545,7 +545,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -561,7 +561,7 @@ tf_xla_py_test(
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -574,7 +574,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
],
)
@ -586,7 +586,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -598,7 +598,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@ -613,7 +613,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -626,7 +626,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:platform_test",
@ -641,7 +641,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -657,7 +657,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -670,7 +670,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/contrib/stateless",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -684,7 +684,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@ -703,7 +703,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@ -716,7 +716,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@ -730,7 +730,7 @@ tf_xla_py_test(
srcs = ["fused_batchnorm_test.py"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn",
@ -749,7 +749,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@ -768,7 +768,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
@ -783,7 +783,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -795,7 +795,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -808,21 +808,34 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "xla_device_test",
name = "xla_device_gpu_test",
size = "small",
srcs = ["xla_device_test.py"],
srcs = ["xla_device_gpu_test.py"],
additional_deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
],
)
@ -839,7 +852,6 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
@ -887,7 +899,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
@ -902,7 +914,7 @@ cuda_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
@ -940,7 +952,7 @@ tf_xla_py_test(
srcs = ["fake_quant_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
@ -952,7 +964,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)

View File

@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for XLA devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class XlaDeviceGpuTest(test.TestCase):
def testCopiesToAndFromGpuWork(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():
return
with session_lib.Session() as sess:
x = array_ops.placeholder(dtypes.float32, [2])
with ops.device("GPU"):
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
if __name__ == "__main__":
test.main()

View File

@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,30 +18,33 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class XlaDeviceTest(test.TestCase):
class XlaDeviceTest(XLATestCase):
def testCopies(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():
return
"""Tests that copies onto and off XLA devices work."""
shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
[16384, 1], [1, 16384], [1, 20000, 1, 1]]
for dtype in self.numeric_types:
for shape in shapes:
with self.test_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(dtype, shape)
with self.test_scope():
y = x + x
with ops.device("CPU"):
z = array_ops.identity(y)
with session_lib.Session() as sess:
x = array_ops.placeholder(dtypes.float32, [2])
with ops.device("GPU"):
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
inputs = np.random.randint(-100, 100, shape).astype(dtype)
result = sess.run(z, {x: inputs})
self.assertAllCloseAccordingToType(result, inputs + inputs)
if __name__ == "__main__":

View File

@ -325,6 +325,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",

View File

@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
TF_RETURN_IF_ERROR(
PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false;
XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(),
func, arguments, &result));
TF_RETURN_IF_ERROR(
compiler->CompileFunction(compile_options, func, arguments, &result));
TF_RET_CHECK(arguments.size() == expressions.size());

View File

@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel {
}
XlaContext& tc = XlaContext::Get(ctx);
if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) {
if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
// The core from which a return value is returned depends on the core
// assignment of the input to the retval .Since we can't change the core
// assignment of <input> as this point, create a tuple/get-tuple-element
// combination so that the core will be set on them.
auto tuple_elem =
ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
tc.AddRetval(index_, dtype_, tuple_elem);
TensorShape shape = ctx->InputShape(0);
TensorShape representation_shape =
tc.is_entry_computation()
? tc.RepresentationShape(shape, ctx->input_type(0))
: shape;
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
output =
ctx->builder()->Reshape(input, representation_shape.dim_sizes());
} else {
// The core from which a return value is returned depends on the
// device assignment of the input to the retval. Since we can't change
// the device assignment of "input" at this point, we must always
// introduce an operator here, even if the shape does not change.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
output = ctx->builder()->GetTupleElement(
ctx->builder()->Tuple({output}), 0);
}
tc.AddRetval(index_, dtype_, shape, output);
}
}
}

View File

@ -15,10 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include <deque>
#include <numeric>
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
@ -28,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
@ -40,7 +38,6 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
flib_runtime_ = pflr_->GetFLR(device_->name());
// The default variable representation shape is the identity function.
if (!options_.variable_representation_shape_fn) {
options_.variable_representation_shape_fn =
[](const TensorShape& shape, DataType type) { return shape; };
// The default shape representation function is the identity.
if (!options_.shape_representation_fn) {
options_.shape_representation_fn = [](const TensorShape& shape,
DataType type) { return shape; };
}
}
@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// Computes the XLA shape for argument 'arg'.
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
bool is_entry_computation,
xla::Shape* xla_shape) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
xla_shape);
case XlaCompiler::Argument::kParameter:
return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: {
TensorShape shape =
is_entry_computation
? options_.shape_representation_fn(arg.shape, arg.type)
: arg.shape;
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
switch (arg.resource_kind) {
case XlaResource::kVariable: {
TensorShape representation_shape =
options_.variable_representation_shape_fn(arg.shape, arg.type);
options_.shape_representation_fn(arg.shape, arg.type);
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
}
@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores,
const std::vector<XlaExpression>& retvals,
const std::vector<XlaContext::Retval>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
xla::XlaComputation* computation, int* num_computation_outputs,
int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
for (const XlaExpression& retval : retvals) {
if (!retval.has_constant_value()) {
for (int i = 0; i < retvals.size(); ++i) {
XlaCompiler::OutputDescription& output = (*outputs)[i];
output.type = retvals[i].type;
output.shape = retvals[i].shape;
const XlaExpression& retval = retvals[i].expression;
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
} else {
output.is_constant = false;
elems.push_back(retval.handle());
}
}
@ -490,8 +501,8 @@ Status XlaCompiler::BuildArguments(
std::vector<xla::Shape> arg_shapes(input_mapping->size());
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments.
TF_RETURN_IF_ERROR(
XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
TF_RETURN_IF_ERROR(XLAShapeForArgument(
args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i]));
}
if (use_tuple_arg) {
@ -567,7 +578,8 @@ Status XlaCompiler::BuildArguments(
builder->ClearOpMetadata();
// Fill in the handles in non-constant arguments.
// Fill in the handles in non-constant arguments, and reshape parameters
// back to their correct shapes.
VLOG(2) << "XLA computation inputs:";
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
@ -586,7 +598,15 @@ Status XlaCompiler::BuildArguments(
break;
}
case XlaCompiler::Argument::kParameter:
arg_expression.set_handle(arg_handles[i]);
// Reshape parameters back to their correct shapes.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
if (is_entry_computation) {
arg_expression.set_handle(
builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
} else {
arg_expression.set_handle(arg_handles[i]);
}
break;
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
@ -661,10 +681,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
xla::XlaBuilder builder(name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants,
&options_.variable_representation_shape_fn);
XlaContext* context = new XlaContext(
this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants, options.is_entry_computation,
&options_.shape_representation_fn);
core::ScopedUnref context_unref(context);
std::vector<XlaExpression> arg_expressions;
@ -681,35 +701,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_computation_outputs,
&num_nonconst_outputs, &result->resource_updates));
&num_nonconst_outputs, &result->outputs, &result->resource_updates));
VLOG(2) << "Outputs: total: " << context->retvals().size()
<< " nonconstant: " << num_nonconst_outputs;
result->outputs.resize(context->retvals().size());
for (std::vector<XlaExpression>::size_type i = 0;
i < context->retvals().size(); ++i) {
const XlaExpression& retval = context->retvals()[i];
if (retval.has_constant_value()) {
OutputDescription& output = result->outputs[i];
output.shape = retval.constant_value().shape();
output.is_constant = true;
output.constant_value = retval.constant_value();
}
}
// Compute the output shapes, if there is a computation with non-constant
// Compute the XLA output shape, if there is a computation with non-constant
// outputs.
auto computation_shape = client()->GetComputationShape(*result->computation);
if (!computation_shape.ok()) {
return computation_shape.status();
}
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape,
client()->GetComputationShape(*result->computation));
result->xla_output_shape.Swap(
computation_shape.ValueOrDie()->mutable_result());
result->xla_output_shape.Swap(computation_shape->mutable_result());
VLOG(2) << "XLA output shape: "
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
@ -724,23 +731,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
// Tensorflow expects a major-to-minor order of results.
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
// Converts the output shapes to TensorShapes.
int computation_output = 0;
for (std::vector<XlaExpression>::size_type i = 0;
i < context->retvals().size(); ++i) {
const XlaExpression& retval = context->retvals()[i];
if (!retval.has_constant_value()) {
TF_RET_CHECK(computation_output < num_computation_outputs)
<< "Computation has more outputs than expected";
OutputDescription& output = result->outputs[i];
output.is_constant = false;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
computation_output),
&output.shape));
++computation_output;
}
}
return Status::OK();
}

View File

@ -67,6 +67,15 @@ class XlaContext;
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
// If a shape representation function is provided as part of
// XlaCompiler::CompileOptions, kParameter arguments and return values to an
// entry computation will be reshaped in accordance to the shape function.
// Arguments and return values to a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep on-device
// tensors with a different shape to their representation inside the XLA
// computation.
//
// In both inputs and outputs, kResource values are placed the end. When
// emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By moving variable values
@ -171,7 +180,7 @@ class XlaCompiler {
};
struct OutputDescription {
// Type and shape of the output.
// Type and shape of the output. The shape is the unflattened shape.
DataType type;
TensorShape shape;
@ -206,10 +215,12 @@ class XlaCompiler {
// original arguments, and are not necessarily in the same order.)
std::vector<int> input_mapping;
// Input shapes of the computation.
// Input shapes of the computation. If we are flattening inputs, these are
// the flattened shapes.
std::vector<xla::Shape> xla_input_shapes;
// Output shape in XLA format. The output shape is always a tuple.
// Output shape in XLA format. The output shape is always a tuple. If we
// are flattening outputs, these are the flattened shapes.
xla::Shape xla_output_shape;
// TensorFlow shapes of outputs, together with the values of any
@ -230,6 +241,8 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};
typedef std::function<TensorShape(const TensorShape&, DataType)>
ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. Needs to be live only during
// XlaCompiler's constructor.
@ -250,8 +263,7 @@ class XlaCompiler {
// If set, the XLA representation of variables represented to XLA as the
// shape given by this shape function. Variables are reshaped to this shape
// on write, and reshaped to their original shape on read.
std::function<TensorShape(const TensorShape&, DataType)>
variable_representation_shape_fn;
ShapeRepresentationFn shape_representation_fn;
// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
@ -300,7 +312,8 @@ class XlaCompiler {
// Returns the shape of the XLA parameter for an argument 'arg'.
// See the class comment for more details about the argument passing
// convention.
Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
@ -750,10 +751,7 @@ TEST_F(XlaCompilerTest, Variables) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests a simple graph that reads and writes a variable, with a
// variable_representation_shape_fn passed to the compiler that flattens all
// variable tensors to vectors.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
@ -764,7 +762,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get()));
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
return std::move(graph);
}
// Tests a simple graph that reads and writes a variable, with a
// shape_representation_fn passed to the compiler that flattens all
// variable tensors to vectors.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
@ -779,15 +785,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.variable_representation_shape_fn = [](const TensorShape& shape,
DataType type) {
options.shape_representation_fn = [](const TensorShape& shape,
DataType type) {
return TensorShape({shape.num_elements()});
};
XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false; // Only reshape variables.
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
std::move(graph), args, &result));
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(
xla::ShapeUtil::Compatible(program_shape->parameters(0),
xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->result(),
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
@ -815,5 +839,74 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 2});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2, 2});
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.shape_representation_fn = [](const TensorShape& shape,
DataType type) {
return TensorShape({shape.num_elements()});
};
XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true; // Reshape args and retvals.
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->result(),
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {4}),
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::Literal::CreateR1<int32>({4, 55, 1, -3});
std::unique_ptr<xla::Literal> param1_literal =
xla::Literal::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::Literal::CreateR1<int32>({27, 67, 35, 402});
std::unique_ptr<xla::Literal> expected1 =
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
} // namespace
} // namespace tensorflow

View File

@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn)
shape_representation_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants),
variable_representation_shape_fn_(variable_representation_shape_fn) {}
is_entry_computation_(is_entry_computation),
shape_representation_fn_(shape_representation_fn) {}
string XlaContext::DebugString() { return "TLA JIT context"; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void XlaContext::AddRetval(int retval_index, DataType type,
const xla::XlaOp& handle) {
const TensorShape& shape, const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
retvals_[retval_index].set_handle(handle);
XlaExpression e;
e.set_handle(handle);
retvals_[retval_index] = Retval{type, shape, e};
}
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
if (resolve_compile_time_constants_) {
Tensor value;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
retvals_[retval_index].set_constant_value(std::move(value));
} else {
retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
}
Tensor value;
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
XlaExpression e;
e.set_constant_value(value);
retvals_[retval_index] = Retval{dtype, value.shape(), e};
return Status::OK();
}
@ -117,9 +119,9 @@ Status XlaContext::CreateResource(
return Status::OK();
}
TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
DataType type) const {
return (*variable_representation_shape_fn_)(shape, type);
TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
DataType type) const {
return (*shape_representation_fn_)(shape, type);
}
const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {

View File

@ -42,11 +42,13 @@ class XlaContext : public ResourceBase {
static XlaContext& Get(const OpKernelContext* ctx);
static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext.
// Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn);
shape_representation_fn);
// Virtual method defined by ResourceBase.
string DebugString() override;
@ -58,14 +60,26 @@ class XlaContext : public ResourceBase {
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
bool resolve_compile_time_constants() const {
return resolve_compile_time_constants_;
}
bool is_entry_computation() const { return is_entry_computation_; }
const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args);
const std::vector<XlaExpression>& retvals() { return retvals_; }
struct Retval {
DataType type;
TensorShape shape;
// An XlaExpression representing the Retval's value.
XlaExpression expression;
};
const std::vector<Retval>& retvals() { return retvals_; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle);
void AddRetval(int retval_index, DataType type, const TensorShape& shape,
const xla::XlaOp& handle);
// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
@ -86,9 +100,9 @@ class XlaContext : public ResourceBase {
}
// Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`.
TensorShape VariableRepresentationShape(const TensorShape& shape,
DataType type) const;
// and `type`, or of an argument or return value of a top-level computation.
TensorShape RepresentationShape(const TensorShape& shape,
DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@ -131,15 +145,23 @@ class XlaContext : public ResourceBase {
std::vector<XlaExpression> args_;
// Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<XlaExpression> retvals_;
std::vector<Retval> retvals_;
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
// A function that describes how variable shapes should be represented
// in XLA. Variable values will be reshaped to this shape. Must be non-null.
// Is this a top-level computation, or an inner computation (e.g., a while
// body)?
const bool is_entry_computation_;
// A function that describes how the shapes of
// a) argument and return value, for entry computations
// b) variables, for all computations,
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared shapes
// for computations. Must be non-null.
const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn_;
shape_representation_fn_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::XlaComputation>;

View File

@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
}
XlaContext& xla_context = XlaContext::Get(context_);
TensorShape representation_shape = xla_context.VariableRepresentationShape(
variable->shape(), variable->type());
TensorShape representation_shape =
xla_context.RepresentationShape(variable->shape(), variable->type());
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
XlaContext& xla_context = XlaContext::Get(context_);
TensorShape representation_shape =
xla_context.VariableRepresentationShape(shape, type);
xla_context.RepresentationShape(shape, type);
if (shape != representation_shape) {
handle = builder()->Reshape(handle, representation_shape.dim_sizes());
}