Automated g4 rollback of changelist 196691101

PiperOrigin-RevId: 196879933
This commit is contained in:
Peter Hawkins 2018-05-16 13:34:10 -07:00 committed by TensorFlower Gardener
parent c9e4705d62
commit 41af9782f4
22 changed files with 507 additions and 256 deletions

View File

@ -551,14 +551,16 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for"); auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]"); auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr( auto dot_profile_line = HasSubstr(
"%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)"); "%arg1.0.1)");
auto add_profile_line = HasSubstr( auto add_profile_line = HasSubstr(
"%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)"); "%arg1.0.1)");
auto tuple_profile_line = HasSubstr( auto tuple_profile_line = HasSubstr(
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
"%dot.0.2, f32[2,2]{1,0} %add.0.5)"); "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines, EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,

View File

@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// this is more obviously correct.) // this is more obviously correct.)
core::ScopedUnref cache_ref(cache); core::ScopedUnref cache_ref(cache);
const XlaDevice::Metadata* metadata; const XlaDevice::Metadata* metadata = nullptr;
Status s = XlaDevice::GetMetadata(ctx, &metadata); Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok(); bool allocate_xla_tensors = s.ok();
@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version(); options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator; options.device_allocator = xla_allocator;
// TODO(b/77671268): We don't set variable_representation_shape_fn here. This if (metadata) {
// is restricted to Variables, but we need something like this to apply to options.shape_representation_fn = metadata->shape_representation_fn();
// normal Tensors too. }
const XlaCompiler::CompilationResult* kernel; const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable; xla::LocalExecutable* executable;
@ -164,9 +164,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
for (int i : constants_) { for (int i : constants_) {
constant_args.insert({i, ctx->input(i)}); constant_args.insert({i, ctx->input(i)});
} }
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, XlaCompiler::CompileOptions compile_options;
variables, ctx, &kernel, &executable, compile_options.is_entry_computation = true;
/*compile_options=*/nullptr)); OP_REQUIRES_OK(
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
&kernel, &executable, &compile_options));
VLOG(1) << "Executing XLA Computation..."; VLOG(1) << "Executing XLA Computation...";

View File

@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile(
options.client = metadata.client(); options.client = metadata.client();
options.flib_def = options.flib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
options.shape_representation_fn = metadata.shape_representation_fn();
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
std::map<int, OptionalTensor> variable_args = GetVariables(ctx); std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
result, executable, result, executable, &compile_options);
/*compile_options=*/nullptr);
} }
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {

View File

@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
(void)registrations; (void)registrations;
std::unique_ptr<XlaDevice> device; std::unique_ptr<XlaDevice> device;
TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, TF_RETURN_IF_ERROR(
DEVICE_CPU_XLA_JIT, options, name_prefix, XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
registration, name_prefix, registration,
/*transfer_as_literal=*/false, &device)); /*transfer_as_literal=*/false,
/*shape_representation_fn=*/{}, &device));
devices->push_back(device.release()); devices->push_back(device.release());
return Status::OK(); return Status::OK();
} }

View File

@ -110,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const string& jit_device_name, const SessionOptions& options, const string& jit_device_name, const SessionOptions& options,
const string& name_prefix, const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration, const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) { bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal; << device_ordinal;
@ -129,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device")); strings::StrCat("device: ", device_name, " device"));
device->reset(new XlaDevice(options, attrs, device_ordinal, device->reset(new XlaDevice(
DeviceType(jit_device_name), options, attrs, device_ordinal, DeviceType(jit_device_name),
platform.ValueOrDie(), transfer_as_literal)); platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
return Status::OK(); return Status::OK();
} }
XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, XlaDevice::Metadata::Metadata(
const DeviceType& device_type) int device_ordinal, se::Platform* platform, const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: device_ordinal_(device_ordinal), : device_ordinal_(device_ordinal),
device_type_(device_type), device_type_(device_type),
platform_(platform) {} platform_(platform),
shape_representation_fn_(std::move(shape_representation_fn)) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@ -170,17 +174,20 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK(); return Status::OK();
} }
XlaDevice::XlaDevice(const SessionOptions& options, XlaDevice::XlaDevice(
const DeviceAttributes& attrs, int device_ordinal, const SessionOptions& options, const DeviceAttributes& attrs,
const DeviceType& jit_device_name, se::Platform* platform, int device_ordinal, const DeviceType& jit_device_name,
bool transfer_as_literal) se::Platform* platform, bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
: LocalDevice(options, attrs), : LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name), xla_metadata_(device_ordinal, platform, jit_device_name,
shape_representation_fn),
device_ordinal_(device_ordinal), device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name), jit_device_name_(jit_device_name),
xla_allocator_(nullptr), xla_allocator_(nullptr),
platform_(platform), platform_(platform),
transfer_as_literal_(transfer_as_literal) { transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(shape_representation_fn) {
VLOG(1) << "Created XLA device " << jit_device_name; VLOG(1) << "Created XLA device " << jit_device_name;
} }
@ -232,8 +239,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context. // gpu_device_info_->default_context.
gpu_device_info_ = MakeUnique<GpuDeviceInfo>(); gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
gpu_device_info_->stream = stream; gpu_device_info_->stream = stream;
gpu_device_info_->default_context = gpu_device_info_->default_context = new XlaDeviceContext(
new XlaDeviceContext(stream, client(), transfer_as_literal_); stream, client(), transfer_as_literal_, shape_representation_fn_);
set_tensorflow_gpu_device_info(gpu_device_info_.get()); set_tensorflow_gpu_device_info(gpu_device_info_.get());
} }
@ -247,7 +254,8 @@ Status XlaDevice::FillContextMap(const Graph* graph,
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator is created. // Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({}); GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
shape_representation_fn_);
for (Node* n : graph->nodes()) { for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref(); ctx->Ref();
@ -294,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n; Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
XlaTransferManager manager(stream, client(), transfer_as_literal_); XlaTransferManager manager(stream, client(), transfer_as_literal_,
shape_representation_fn_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy, manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) { [&n, &status](const Status& s) {
status = s; status = s;

View File

@ -17,8 +17,7 @@ limitations under the License.
// runtime. // runtime.
// //
// Operators assigned to an XlaDevice are compiled into XLA computations. // Operators assigned to an XlaDevice are compiled into XLA computations.
// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state // Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
// is managed by XLA.
// //
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU). // under different names (e.g., XLA_CPU or XLA_GPU).
@ -27,6 +26,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_factory.h"
@ -50,7 +50,8 @@ class XlaDevice : public LocalDevice {
class Metadata { class Metadata {
public: public:
Metadata(int device_ordinal, se::Platform* platform, Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type); const DeviceType& device_type,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
// The index of the device on this host. // The index of the device on this host.
int device_ordinal() const; int device_ordinal() const;
@ -58,11 +59,15 @@ class XlaDevice : public LocalDevice {
se::Platform* platform() const; se::Platform* platform() const;
xla::LocalClient* client() const; xla::LocalClient* client() const;
const DeviceType& jit_device_type() const; const DeviceType& jit_device_type() const;
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
return shape_representation_fn_;
}
private: private:
const int device_ordinal_; const int device_ordinal_;
const DeviceType device_type_; const DeviceType device_type_;
se::Platform* platform_; // Not owned. se::Platform* platform_; // Not owned.
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata); TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
}; };
@ -76,16 +81,19 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using // 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use // XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead. // ThenMemcpy instead.
static Status Create(const string& platform_name, const string& device_name, static Status Create(
int device_ordinal, const string& jit_device_name, const string& platform_name, const string& device_name,
const SessionOptions& options, const string& name_prefix, int device_ordinal, const string& jit_device_name,
const XlaOpRegistry::DeviceRegistration& registration, const SessionOptions& options, const string& name_prefix,
bool transfer_as_literal, const XlaOpRegistry::DeviceRegistration& registration,
std::unique_ptr<XlaDevice>* device); bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
std::unique_ptr<XlaDevice>* device);
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name, int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal); se::Platform* platform, bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
~XlaDevice() override; ~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override; Allocator* GetAllocator(AllocatorAttributes attr) override;
@ -116,8 +124,8 @@ class XlaDevice : public LocalDevice {
// The name of the device that is used to compile Ops for this XlaDevice. // The name of the device that is used to compile Ops for this XlaDevice.
DeviceType jit_device_name_; DeviceType jit_device_name_;
// Memory allocator associated with this device. // Memory allocator associated with this device.
Allocator* xla_allocator_; // Not owned. Allocator* xla_allocator_; // Not owned.
se::Platform* platform_; // Not owned. se::Platform* platform_; // Not owned.
// Stream associated with this device. Operations enqueued on this // Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data // stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and // copying back and forth between CPU and the device, and
@ -126,6 +134,7 @@ class XlaDevice : public LocalDevice {
// Must we use XLA's transfer manager for correct host<->device transfers? if // Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead. // false, we can use ThenMemcpy() instead.
bool transfer_as_literal_; bool transfer_as_literal_;
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// If set, holds default device context (that we must Unref) // If set, holds default device context (that we must Unref)
// and its stream. // and its stream.

View File

@ -47,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
XlaTransferManager::XlaTransferManager(se::Stream* stream, XlaTransferManager::XlaTransferManager(
xla::LocalClient* client, se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
bool transfer_as_literal) XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: stream_(stream), : stream_(stream),
client_(client), client_(client),
transfer_manager_(client->backend().transfer_manager()), transfer_manager_(client->backend().transfer_manager()),
transfer_as_literal_(transfer_as_literal) {} transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {}
Status XlaTransferManager::TransferLiteralToDevice( Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const { const Tensor& host_tensor, Tensor* device_tensor) const {
@ -76,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice(
transfer_manager_->TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice(
stream_->parent(), shaped_buffer)); stream_->parent(), shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString(); VLOG(1) << "Transfer from device as literal: " << literal->ToString();
return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); Tensor tensor;
TF_RETURN_IF_ERROR(
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
// Reshape the tensor back to its declared shape.
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
return errors::Internal(
"Tensor::CopyFrom failed when copying from XLA device to CPU");
}
return Status::OK();
} }
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@ -96,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor); CHECK(xla_tensor);
TensorShape shape;
if (shape_representation_fn_) {
shape = shape_representation_fn_(device_tensor->shape(),
device_tensor->dtype());
} else {
shape = device_tensor->shape();
}
if (!xla_tensor->has_shaped_buffer()) { if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer( Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), device_tensor->shape(), client_, device_tensor->dtype(), shape, client_,
stream_->parent()->device_ordinal()); stream_->parent()->device_ordinal());
if (!s.ok()) { if (!s.ok()) {
done(s); done(s);
@ -106,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
} }
} }
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
Status status; Status status;
if (transfer_as_literal_) { if (transfer_as_literal_) {
status = TransferLiteralToDevice(*cpu_tensor, device_tensor); Tensor reshaped_cpu_tensor;
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
done(errors::Internal(
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else { } else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous. // TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone(); Status block_status = stream_->BlockHostUntilDone();
@ -171,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done(Status::OK()); done(Status::OK());
} }
XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, XlaDeviceContext::XlaDeviceContext(
bool transfer_as_literal) se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
: manager_(stream, client, transfer_as_literal) {} XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: manager_(stream, client, transfer_as_literal,
std::move(shape_representation_fn)) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device, Device* device,

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <memory> #include <memory>
#include "tensorflow/compiler/jit/xla_tensor.h" #include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
@ -45,8 +46,9 @@ class XlaDeviceAllocator : public Allocator {
// Helper class for managing data transfers between host and XLA devices. // Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager { class XlaTransferManager {
public: public:
explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client, explicit XlaTransferManager(
bool transfer_as_literal); se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const; Tensor* device_tensor, StatusCallback done) const;
@ -69,7 +71,8 @@ class XlaTransferManager {
// Transfer manager, for marshalling data to and from the device. // Transfer manager, for marshalling data to and from the device.
xla::TransferManager* transfer_manager_; xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers. // True if we must use XLA's TransferManager for correct device transfers.
bool transfer_as_literal_; const bool transfer_as_literal_;
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
}; };
// DeviceContext for operators assigned to XlaDevice devices. The // DeviceContext for operators assigned to XlaDevice devices. The
@ -77,8 +80,9 @@ class XlaTransferManager {
// wraps the methods in XlaTransferManager. // wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext { class XlaDeviceContext : public DeviceContext {
public: public:
explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, explicit XlaDeviceContext(
bool transfer_as_literal); se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, Tensor* device_tensor,

View File

@ -48,7 +48,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
Status status = Status status =
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration, name_prefix, registration,
/*transfer_as_literal=*/false, &device); /*transfer_as_literal=*/false,
/*shape_representation_fn=*/{}, &device);
if (!status.ok()) { if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine. // Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status; VLOG(1) << "Failed to create XLA_GPU device: " << status;

View File

@ -195,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs(
OP_REQUIRES_OK( OP_REQUIRES_OK(
ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
const_tensor.dtype(), const_tensor.shape(),
client_, stream->parent()->device_ordinal()));
}
Device* device = dynamic_cast<Device*>(ctx->device()); Device* device = dynamic_cast<Device*>(ctx->device());
OP_REQUIRES(ctx, device != nullptr, OP_REQUIRES(ctx, device != nullptr,

View File

@ -42,7 +42,7 @@ py_library(
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client", "//tensorflow/python:client",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform", "//tensorflow/python:platform",
"//tensorflow/python:random_seed", "//tensorflow/python:random_seed",
"//tensorflow/python:session", "//tensorflow/python:session",
@ -58,7 +58,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -72,7 +72,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -93,7 +93,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -111,7 +111,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops", "//tensorflow/python:bitwise_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen", "//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
@ -127,7 +127,7 @@ tf_xla_py_test(
tags = ["optonly"], tags = ["optonly"],
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:random_ops", "//tensorflow/python:random_ops",
], ],
@ -141,7 +141,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -156,7 +156,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -170,7 +170,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -184,7 +184,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen", "//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:gradient_checker", "//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
@ -209,7 +209,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen", "//tensorflow/python:array_ops_gen",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:gradient_checker", "//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
@ -225,7 +225,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:nn", "//tensorflow/python:nn",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
@ -241,7 +241,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:nn", "//tensorflow/python:nn",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
@ -263,7 +263,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:nn", "//tensorflow/python:nn",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
@ -291,7 +291,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops", "//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -307,7 +307,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -326,7 +326,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:layers", "//tensorflow/python:layers",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:nn", "//tensorflow/python:nn",
@ -346,7 +346,7 @@ tf_xla_py_test(
"//tensorflow/contrib/signal:signal_py", "//tensorflow/contrib/signal:signal_py",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:extra_py_tests_deps", "//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops", "//tensorflow/python:spectral_ops",
], ],
@ -360,7 +360,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops", "//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -372,7 +372,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -388,7 +388,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -403,7 +403,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:image_ops", "//tensorflow/python:image_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -431,7 +431,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:nn", "//tensorflow/python:nn",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
@ -446,7 +446,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -458,7 +458,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -472,7 +472,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -485,7 +485,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -498,7 +498,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
@ -513,7 +513,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
@ -530,7 +530,7 @@ tf_xla_py_test(
], ],
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:random_ops", "//tensorflow/python:random_ops",
], ],
@ -545,7 +545,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -561,7 +561,7 @@ tf_xla_py_test(
"//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:errors", "//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -574,7 +574,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
], ],
) )
@ -586,7 +586,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -598,7 +598,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
@ -613,7 +613,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -626,7 +626,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen", "//tensorflow/python:math_ops_gen",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
@ -641,7 +641,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -657,7 +657,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops", "//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -670,7 +670,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/contrib/stateless", "//tensorflow/contrib/stateless",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -684,7 +684,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen", "//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
@ -703,7 +703,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
@ -716,7 +716,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen", "//tensorflow/python:nn_ops_gen",
@ -730,7 +730,7 @@ tf_xla_py_test(
srcs = ["fused_batchnorm_test.py"], srcs = ["fused_batchnorm_test.py"],
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen", "//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn", "//tensorflow/python:nn",
@ -749,7 +749,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen", "//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops",
@ -768,7 +768,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
"//tensorflow/python:training", "//tensorflow/python:training",
], ],
@ -783,7 +783,7 @@ tf_xla_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops", "//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -795,7 +795,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -808,21 +808,34 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
cuda_py_test( cuda_py_test(
name = "xla_device_test", name = "xla_device_gpu_test",
size = "small", size = "small",
srcs = ["xla_device_test.py"], srcs = ["xla_device_gpu_test.py"],
additional_deps = [ additional_deps = [
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client", "//tensorflow/python:client",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
], ],
) )
@ -839,7 +852,6 @@ cuda_py_test(
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:layers", "//tensorflow/python:layers",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
@ -887,7 +899,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:random_ops", "//tensorflow/python:random_ops",
"//tensorflow/python:variables", "//tensorflow/python:variables",
@ -902,7 +914,7 @@ cuda_py_test(
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib", "//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:gradients", "//tensorflow/python:gradients",
"//tensorflow/python:init_ops", "//tensorflow/python:init_ops",
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
@ -940,7 +952,7 @@ tf_xla_py_test(
srcs = ["fake_quant_ops_test.py"], srcs = ["fake_quant_ops_test.py"],
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )
@ -952,7 +964,7 @@ tf_xla_py_test(
deps = [ deps = [
":xla_test", ":xla_test",
"//tensorflow/python:array_ops", "//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework",
"//tensorflow/python:platform_test", "//tensorflow/python:platform_test",
], ],
) )

View File

@ -0,0 +1,48 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test cases for XLA devices."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class XlaDeviceGpuTest(test.TestCase):
def testCopiesToAndFromGpuWork(self):
"""Tests that copies between GPU and XLA devices work."""
if not test.is_gpu_available():
return
with session_lib.Session() as sess:
x = array_ops.placeholder(dtypes.float32, [2])
with ops.device("GPU"):
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
if __name__ == "__main__":
test.main()

View File

@ -1,4 +1,4 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,30 +18,33 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.client import session as session_lib import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
class XlaDeviceTest(test.TestCase): class XlaDeviceTest(XLATestCase):
def testCopies(self): def testCopies(self):
"""Tests that copies between GPU and XLA devices work.""" """Tests that copies onto and off XLA devices work."""
if not test.is_gpu_available(): shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
return [16384, 1], [1, 16384], [1, 20000, 1, 1]]
for dtype in self.numeric_types:
for shape in shapes:
with self.test_session() as sess:
with ops.device("CPU"):
x = array_ops.placeholder(dtype, shape)
with self.test_scope():
y = x + x
with ops.device("CPU"):
z = array_ops.identity(y)
with session_lib.Session() as sess: inputs = np.random.randint(-100, 100, shape).astype(dtype)
x = array_ops.placeholder(dtypes.float32, [2]) result = sess.run(z, {x: inputs})
with ops.device("GPU"): self.assertAllCloseAccordingToType(result, inputs + inputs)
y = x * 2
with ops.device("device:XLA_CPU:0"):
z = y * y
with ops.device("GPU"):
w = y + z
result = sess.run(w, {x: [1.5, 0.5]})
self.assertAllClose(result, [12., 2.], rtol=1e-3)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -325,6 +325,7 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin", "//tensorflow/compiler/xla/service:cpu_plugin",

View File

@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false;
XlaCompiler::CompilationResult result; XlaCompiler::CompilationResult result;
TF_RETURN_IF_ERROR(
TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), compiler->CompileFunction(compile_options, func, arguments, &result));
func, arguments, &result));
TF_RET_CHECK(arguments.size() == expressions.size()); TF_RET_CHECK(arguments.size() == expressions.size());

View File

@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel {
} }
XlaContext& tc = XlaContext::Get(ctx); XlaContext& tc = XlaContext::Get(ctx);
if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { if (tc.resolve_compile_time_constants() &&
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
xla::Literal literal; xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else { } else {
// The core from which a return value is returned depends on the core TensorShape shape = ctx->InputShape(0);
// assignment of the input to the retval .Since we can't change the core TensorShape representation_shape =
// assignment of <input> as this point, create a tuple/get-tuple-element tc.is_entry_computation()
// combination so that the core will be set on them. ? tc.RepresentationShape(shape, ctx->input_type(0))
auto tuple_elem = : shape;
ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
tc.AddRetval(index_, dtype_, tuple_elem); xla::XlaOp output = input;
if (tc.is_entry_computation()) {
output =
ctx->builder()->Reshape(input, representation_shape.dim_sizes());
} else {
// The core from which a return value is returned depends on the
// device assignment of the input to the retval. Since we can't change
// the device assignment of "input" at this point, we must always
// introduce an operator here, even if the shape does not change.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
output = ctx->builder()->GetTupleElement(
ctx->builder()->Tuple({output}), 0);
}
tc.AddRetval(index_, dtype_, shape, output);
} }
} }
} }

View File

@ -15,10 +15,9 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include <deque>
#include <numeric> #include <numeric>
#include <vector>
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h"
@ -28,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/executor.h"
@ -40,7 +38,6 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
flib_runtime_ = pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name());
// The default variable representation shape is the identity function. // The default shape representation function is the identity.
if (!options_.variable_representation_shape_fn) { if (!options_.shape_representation_fn) {
options_.variable_representation_shape_fn = options_.shape_representation_fn = [](const TensorShape& shape,
[](const TensorShape& shape, DataType type) { return shape; }; DataType type) { return shape; };
} }
} }
@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// Computes the XLA shape for argument 'arg'. // Computes the XLA shape for argument 'arg'.
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
bool is_entry_computation,
xla::Shape* xla_shape) { xla::Shape* xla_shape) {
switch (arg.kind) { switch (arg.kind) {
case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kConstant:
return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), LOG(FATAL) << "Unreachable case";
xla_shape); case XlaCompiler::Argument::kParameter: {
case XlaCompiler::Argument::kParameter: TensorShape shape =
return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); is_entry_computation
? options_.shape_representation_fn(arg.shape, arg.type)
: arg.shape;
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}
case XlaCompiler::Argument::kResource: { case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized); TF_RET_CHECK(arg.initialized);
switch (arg.resource_kind) { switch (arg.resource_kind) {
case XlaResource::kVariable: { case XlaResource::kVariable: {
TensorShape representation_shape = TensorShape representation_shape =
options_.variable_representation_shape_fn(arg.shape, arg.type); options_.shape_representation_fn(arg.shape, arg.type);
return TensorShapeToXLAShape(arg.type, representation_shape, return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape); xla_shape);
} }
@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
Status BuildComputation( Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args, const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores, const std::vector<int>& arg_cores,
const std::vector<XlaExpression>& retvals, const std::vector<XlaContext::Retval>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources, const std::vector<std::unique_ptr<XlaResource>>& resources,
bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
xla::XlaComputation* computation, int* num_computation_outputs, xla::XlaComputation* computation, int* num_computation_outputs,
int* num_nonconst_outputs, int* num_nonconst_outputs,
std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) { std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems; std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size()); elems.reserve(retvals.size());
for (const XlaExpression& retval : retvals) { for (int i = 0; i < retvals.size(); ++i) {
if (!retval.has_constant_value()) { XlaCompiler::OutputDescription& output = (*outputs)[i];
output.type = retvals[i].type;
output.shape = retvals[i].shape;
const XlaExpression& retval = retvals[i].expression;
if (retval.has_constant_value()) {
output.is_constant = true;
output.constant_value = retval.constant_value();
} else {
output.is_constant = false;
elems.push_back(retval.handle()); elems.push_back(retval.handle());
} }
} }
@ -490,8 +501,8 @@ Status XlaCompiler::BuildArguments(
std::vector<xla::Shape> arg_shapes(input_mapping->size()); std::vector<xla::Shape> arg_shapes(input_mapping->size());
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments. // Computes the shapes of non-constant arguments.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(XLAShapeForArgument(
XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i]));
} }
if (use_tuple_arg) { if (use_tuple_arg) {
@ -567,7 +578,8 @@ Status XlaCompiler::BuildArguments(
builder->ClearOpMetadata(); builder->ClearOpMetadata();
// Fill in the handles in non-constant arguments. // Fill in the handles in non-constant arguments, and reshape parameters
// back to their correct shapes.
VLOG(2) << "XLA computation inputs:"; VLOG(2) << "XLA computation inputs:";
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
@ -586,7 +598,15 @@ Status XlaCompiler::BuildArguments(
break; break;
} }
case XlaCompiler::Argument::kParameter: case XlaCompiler::Argument::kParameter:
arg_expression.set_handle(arg_handles[i]); // Reshape parameters back to their correct shapes.
// TODO(b/76097077): propagate device assignments onto arguments and
// return values of functions, and then reshape unconditionally.
if (is_entry_computation) {
arg_expression.set_handle(
builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
} else {
arg_expression.set_handle(arg_handles[i]);
}
break; break;
case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid: case XlaCompiler::Argument::kInvalid:
@ -661,10 +681,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
xla::XlaBuilder builder(name); xla::XlaBuilder builder(name);
XlaContext* context = XlaContext* context = new XlaContext(
new XlaContext(this, &builder, options_.allow_cpu_custom_calls, this, &builder, options_.allow_cpu_custom_calls,
options.resolve_compile_time_constants, options.resolve_compile_time_constants, options.is_entry_computation,
&options_.variable_representation_shape_fn); &options_.shape_representation_fn);
core::ScopedUnref context_unref(context); core::ScopedUnref context_unref(context);
std::vector<XlaExpression> arg_expressions; std::vector<XlaExpression> arg_expressions;
@ -681,35 +701,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_nonconst_outputs; int num_nonconst_outputs;
int num_computation_outputs; int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>(); result->computation = std::make_shared<xla::XlaComputation>();
result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation( TF_RETURN_IF_ERROR(BuildComputation(
args, arg_cores, context->retvals(), context->resources(), args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources, &builder, options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_computation_outputs, result->computation.get(), &num_computation_outputs,
&num_nonconst_outputs, &result->resource_updates)); &num_nonconst_outputs, &result->outputs, &result->resource_updates));
VLOG(2) << "Outputs: total: " << context->retvals().size() VLOG(2) << "Outputs: total: " << context->retvals().size()
<< " nonconstant: " << num_nonconst_outputs; << " nonconstant: " << num_nonconst_outputs;
result->outputs.resize(context->retvals().size());
for (std::vector<XlaExpression>::size_type i = 0;
i < context->retvals().size(); ++i) {
const XlaExpression& retval = context->retvals()[i];
if (retval.has_constant_value()) {
OutputDescription& output = result->outputs[i];
output.shape = retval.constant_value().shape();
output.is_constant = true;
output.constant_value = retval.constant_value();
}
}
// Compute the output shapes, if there is a computation with non-constant // Compute the XLA output shape, if there is a computation with non-constant
// outputs. // outputs.
auto computation_shape = client()->GetComputationShape(*result->computation); TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape,
if (!computation_shape.ok()) { client()->GetComputationShape(*result->computation));
return computation_shape.status();
}
result->xla_output_shape.Swap( result->xla_output_shape.Swap(computation_shape->mutable_result());
computation_shape.ValueOrDie()->mutable_result());
VLOG(2) << "XLA output shape: " VLOG(2) << "XLA output shape: "
<< xla::ShapeUtil::HumanString(result->xla_output_shape); << xla::ShapeUtil::HumanString(result->xla_output_shape);
@ -724,23 +731,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
// Tensorflow expects a major-to-minor order of results. // Tensorflow expects a major-to-minor order of results.
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
// Converts the output shapes to TensorShapes.
int computation_output = 0;
for (std::vector<XlaExpression>::size_type i = 0;
i < context->retvals().size(); ++i) {
const XlaExpression& retval = context->retvals()[i];
if (!retval.has_constant_value()) {
TF_RET_CHECK(computation_output < num_computation_outputs)
<< "Computation has more outputs than expected";
OutputDescription& output = result->outputs[i];
output.is_constant = false;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
computation_output),
&output.shape));
++computation_output;
}
}
return Status::OK(); return Status::OK();
} }

View File

@ -67,6 +67,15 @@ class XlaContext;
// _Retval values are ordered by _Retval index, whereas kResource values are // _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable. // ordered by the original _Arg position of the variable.
// //
// If a shape representation function is provided as part of
// XlaCompiler::CompileOptions, kParameter arguments and return values to an
// entry computation will be reshaped in accordance to the shape function.
// Arguments and return values to a non-entry computation are not reshaped.
// Variable resource arguments are passed and returned in reshaped form, even
// for non-entry computations. This feature allows TensorFlow to keep on-device
// tensors with a different shape to their representation inside the XLA
// computation.
//
// In both inputs and outputs, kResource values are placed the end. When // In both inputs and outputs, kResource values are placed the end. When
// emitting While loop bodies, we must ensure that the loop body has // emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By moving variable values // identical input and output signatures. By moving variable values
@ -171,7 +180,7 @@ class XlaCompiler {
}; };
struct OutputDescription { struct OutputDescription {
// Type and shape of the output. // Type and shape of the output. The shape is the unflattened shape.
DataType type; DataType type;
TensorShape shape; TensorShape shape;
@ -206,10 +215,12 @@ class XlaCompiler {
// original arguments, and are not necessarily in the same order.) // original arguments, and are not necessarily in the same order.)
std::vector<int> input_mapping; std::vector<int> input_mapping;
// Input shapes of the computation. // Input shapes of the computation. If we are flattening inputs, these are
// the flattened shapes.
std::vector<xla::Shape> xla_input_shapes; std::vector<xla::Shape> xla_input_shapes;
// Output shape in XLA format. The output shape is always a tuple. // Output shape in XLA format. The output shape is always a tuple. If we
// are flattening outputs, these are the flattened shapes.
xla::Shape xla_output_shape; xla::Shape xla_output_shape;
// TensorFlow shapes of outputs, together with the values of any // TensorFlow shapes of outputs, together with the values of any
@ -230,6 +241,8 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation; std::shared_ptr<xla::XlaComputation> computation;
}; };
typedef std::function<TensorShape(const TensorShape&, DataType)>
ShapeRepresentationFn;
struct Options { struct Options {
// Name of the compilation device to use. Needs to be live only during // Name of the compilation device to use. Needs to be live only during
// XlaCompiler's constructor. // XlaCompiler's constructor.
@ -250,8 +263,7 @@ class XlaCompiler {
// If set, the XLA representation of variables represented to XLA as the // If set, the XLA representation of variables represented to XLA as the
// shape given by this shape function. Variables are reshaped to this shape // shape given by this shape function. Variables are reshaped to this shape
// on write, and reshaped to their original shape on read. // on write, and reshaped to their original shape on read.
std::function<TensorShape(const TensorShape&, DataType)> ShapeRepresentationFn shape_representation_fn;
variable_representation_shape_fn;
// If not nullptr, populate_resource_manager is called with the // If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation // compilation device's resource manager when the compilation
@ -300,7 +312,8 @@ class XlaCompiler {
// Returns the shape of the XLA parameter for an argument 'arg'. // Returns the shape of the XLA parameter for an argument 'arg'.
// See the class comment for more details about the argument passing // See the class comment for more details about the argument passing
// convention. // convention.
Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates // Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists. // a new channel handle if none exists.

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/common_shape_fns.h"
@ -750,10 +751,7 @@ TEST_F(XlaCompilerTest, Variables) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
// Tests a simple graph that reads and writes a variable, with a xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
// variable_representation_shape_fn passed to the compiler that flattens all
// variable tensors to vectors.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
Scope scope = Scope::NewRootScope().ExitOnError(); Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
@ -764,7 +762,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1)); auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_ASSERT_OK(scope.ToGraph(graph.get())); TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
return std::move(graph);
}
// Tests a simple graph that reads and writes a variable, with a
// shape_representation_fn passed to the compiler that flattens all
// variable tensors to vectors.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
// Builds a description of the arguments. // Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2); std::vector<XlaCompiler::Argument> args(2);
@ -779,15 +785,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Compiles the graph. // Compiles the graph.
XlaCompiler::Options options = DefaultOptions(); XlaCompiler::Options options = DefaultOptions();
options.variable_representation_shape_fn = [](const TensorShape& shape, options.shape_representation_fn = [](const TensorShape& shape,
DataType type) { DataType type) {
return TensorShape({shape.num_elements()}); return TensorShape({shape.num_elements()});
}; };
XlaCompiler compiler(options); XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = false; // Only reshape variables.
XlaCompiler::CompilationResult result; XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
std::move(graph), args, &result)); args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(
xla::ShapeUtil::Compatible(program_shape->parameters(0),
xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->result(),
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works. // Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal = std::unique_ptr<xla::Literal> param0_literal =
@ -815,5 +839,74 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
} }
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
args[0].shape = TensorShape({2, 2});
args[1].kind = XlaCompiler::Argument::kResource;
args[1].resource_kind = XlaResource::kVariable;
args[1].initialized = true;
args[1].type = DT_INT32;
args[1].shape = TensorShape({2, 2});
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
options.shape_representation_fn = [](const TensorShape& shape,
DataType type) {
return TensorShape({shape.num_elements()});
};
XlaCompiler compiler(options);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true; // Reshape args and retvals.
XlaCompiler::CompilationResult result;
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
args, &result));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
client_->GetComputationShape(*result.computation));
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
EXPECT_TRUE(xla::ShapeUtil::Compatible(
program_shape->result(),
xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::S32, {4}),
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
xla::Literal::CreateR1<int32>({4, 55, 1, -3});
std::unique_ptr<xla::Literal> param1_literal =
xla::Literal::CreateR1<int32>({22, 11, 33, 404});
std::unique_ptr<xla::GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> param1_data =
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
std::unique_ptr<xla::GlobalData> actual =
client_
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
.ConsumeValueOrDie();
std::unique_ptr<xla::Literal> actual_literal =
client_->Transfer(*actual).ConsumeValueOrDie();
std::unique_ptr<xla::Literal> expected0 =
xla::Literal::CreateR1<int32>({27, 67, 35, 402});
std::unique_ptr<xla::Literal> expected1 =
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
XlaContext::XlaContext( XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder, XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
const std::function<TensorShape(const TensorShape&, DataType)>* const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn) shape_representation_fn)
: compiler_(compiler), : compiler_(compiler),
builder_(builder), builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls), allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants), resolve_compile_time_constants_(resolve_compile_time_constants),
variable_representation_shape_fn_(variable_representation_shape_fn) {} is_entry_computation_(is_entry_computation),
shape_representation_fn_(shape_representation_fn) {}
string XlaContext::DebugString() { return "TLA JIT context"; } string XlaContext::DebugString() { return "TLA JIT context"; }
// This is called by the Retval Op to associate a computed value // This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph. // with a specific return value of the subgraph.
void XlaContext::AddRetval(int retval_index, DataType type, void XlaContext::AddRetval(int retval_index, DataType type,
const xla::XlaOp& handle) { const TensorShape& shape, const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up. // Add the return value to the list being built up.
if (retvals_.size() <= retval_index) { if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1); retvals_.resize(retval_index + 1);
} }
retvals_[retval_index].set_handle(handle); XlaExpression e;
e.set_handle(handle);
retvals_[retval_index] = Retval{type, shape, e};
} }
Status XlaContext::AddConstRetval(int retval_index, DataType dtype, Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
if (retvals_.size() <= retval_index) { if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1); retvals_.resize(retval_index + 1);
} }
if (resolve_compile_time_constants_) { Tensor value;
Tensor value; TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); XlaExpression e;
retvals_[retval_index].set_constant_value(std::move(value)); e.set_constant_value(value);
} else { retvals_[retval_index] = Retval{dtype, value.shape(), e};
retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
}
return Status::OK(); return Status::OK();
} }
@ -117,9 +119,9 @@ Status XlaContext::CreateResource(
return Status::OK(); return Status::OK();
} }
TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
DataType type) const { DataType type) const {
return (*variable_representation_shape_fn_)(shape, type); return (*shape_representation_fn_)(shape, type);
} }
const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {

View File

@ -42,11 +42,13 @@ class XlaContext : public ResourceBase {
static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const OpKernelContext* ctx);
static XlaContext& Get(const XlaOpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx);
// Creates a new XlaContext. // Creates a new XlaContext. See the documentation on the class data fields
// for descriptions of the arguments.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
const std::function<TensorShape(const TensorShape&, DataType)>* const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn); shape_representation_fn);
// Virtual method defined by ResourceBase. // Virtual method defined by ResourceBase.
string DebugString() override; string DebugString() override;
@ -58,14 +60,26 @@ class XlaContext : public ResourceBase {
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
bool resolve_compile_time_constants() const {
return resolve_compile_time_constants_;
}
bool is_entry_computation() const { return is_entry_computation_; }
const std::vector<XlaExpression>& args() const { return args_; } const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args); void set_args(std::vector<XlaExpression> args);
const std::vector<XlaExpression>& retvals() { return retvals_; } struct Retval {
DataType type;
TensorShape shape;
// An XlaExpression representing the Retval's value.
XlaExpression expression;
};
const std::vector<Retval>& retvals() { return retvals_; }
// This is called by the Retval Op to associate a computed value // This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph. // with a specific return value of the subgraph.
void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle); void AddRetval(int retval_index, DataType type, const TensorShape& shape,
const xla::XlaOp& handle);
// As for Retval, but for return values that are compile-time constants. // As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype, Status AddConstRetval(int retval_index, DataType dtype,
@ -86,9 +100,9 @@ class XlaContext : public ResourceBase {
} }
// Returns the XLA shape to be used to represent a variable of TF `shape` // Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`. // and `type`, or of an argument or return value of a top-level computation.
TensorShape VariableRepresentationShape(const TensorShape& shape, TensorShape RepresentationShape(const TensorShape& shape,
DataType type) const; DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the // Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a // XlaContext since it may be used by multiple Ops. There is a
@ -131,15 +145,23 @@ class XlaContext : public ResourceBase {
std::vector<XlaExpression> args_; std::vector<XlaExpression> args_;
// Return values of the Tensorflow graph, indexed by _Retval index. // Return values of the Tensorflow graph, indexed by _Retval index.
std::vector<XlaExpression> retvals_; std::vector<Retval> retvals_;
// Holds ownership of resources. The resources are not ordered. // Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_; std::vector<std::unique_ptr<XlaResource>> resources_;
// A function that describes how variable shapes should be represented // Is this a top-level computation, or an inner computation (e.g., a while
// in XLA. Variable values will be reshaped to this shape. Must be non-null. // body)?
const bool is_entry_computation_;
// A function that describes how the shapes of
// a) argument and return value, for entry computations
// b) variables, for all computations,
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared shapes
// for computations. Must be non-null.
const std::function<TensorShape(const TensorShape&, DataType)>* const std::function<TensorShape(const TensorShape&, DataType)>*
variable_representation_shape_fn_; shape_representation_fn_;
// Cache of prebuilt computations indexed by their type. // Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::XlaComputation>; using ComputationMap = std::map<DataType, xla::XlaComputation>;

View File

@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
} }
XlaContext& xla_context = XlaContext::Get(context_); XlaContext& xla_context = XlaContext::Get(context_);
TensorShape representation_shape = xla_context.VariableRepresentationShape( TensorShape representation_shape =
variable->shape(), variable->type()); xla_context.RepresentationShape(variable->shape(), variable->type());
if (representation_shape == variable->shape()) { if (representation_shape == variable->shape()) {
*value = variable->value(); *value = variable->value();
} else { } else {
@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
XlaContext& xla_context = XlaContext::Get(context_); XlaContext& xla_context = XlaContext::Get(context_);
TensorShape representation_shape = TensorShape representation_shape =
xla_context.VariableRepresentationShape(shape, type); xla_context.RepresentationShape(shape, type);
if (shape != representation_shape) { if (shape != representation_shape) {
handle = builder()->Reshape(handle, representation_shape.dim_sizes()); handle = builder()->Reshape(handle, representation_shape.dim_sizes());
} }