Automated g4 rollback of changelist 196691101

PiperOrigin-RevId: 196879933
Commit 41af9782f4 (parent c9e4705d62)
@@ -551,14 +551,16 @@ TEST(TFCompileTest, HloProfiling) {
   auto header = HasSubstr("Execution profile for");
   auto total_cycles_profile_line = HasSubstr("[total]");
   auto dot_profile_line = HasSubstr(
-      "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+      "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
       "%arg1.0.1)");
   auto add_profile_line = HasSubstr(
-      "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+      "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
       "%arg1.0.1)");
   auto tuple_profile_line = HasSubstr(
       "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
-      "%dot.0.2, f32[2,2]{1,0} %add.0.5)");
+      "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
+  auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
+  auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");

   EXPECT_THAT(hlo_profile_lines,
               IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
@@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   // this is more obviously correct.)
   core::ScopedUnref cache_ref(cache);

-  const XlaDevice::Metadata* metadata;
+  const XlaDevice::Metadata* metadata = nullptr;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   bool allocate_xla_tensors = s.ok();

@@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
   options.device_allocator = xla_allocator;
-  // TODO(b/77671268): We don't set variable_representation_shape_fn here. This
-  // is restricted to Variables, but we need something like this to apply to
-  // normal Tensors too.
+  if (metadata) {
+    options.shape_representation_fn = metadata->shape_representation_fn();
+  }

   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -164,9 +164,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   for (int i : constants_) {
     constant_args.insert({i, ctx->input(i)});
   }
-  OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
-                                     variables, ctx, &kernel, &executable,
-                                     /*compile_options=*/nullptr));
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = true;
+  OP_REQUIRES_OK(
+      ctx, cache->Compile(options, function_, constant_args, variables, ctx,
+                          &kernel, &executable, &compile_options));

   VLOG(1) << "Executing XLA Computation...";

@@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile(
   options.client = metadata.client();
   options.flib_def =
       new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
+  options.shape_representation_fn = metadata.shape_representation_fn();

+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = true;
+
   std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
   return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
-                                result, executable,
-                                /*compile_options=*/nullptr);
+                                result, executable, &compile_options);
 }

 void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
@@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
   (void)registrations;

   std::unique_ptr<XlaDevice> device;
-  TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
-                                       DEVICE_CPU_XLA_JIT, options, name_prefix,
-                                       registration,
-                                       /*transfer_as_literal=*/false, &device));
+  TF_RETURN_IF_ERROR(
+      XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
+                        name_prefix, registration,
+                        /*transfer_as_literal=*/false,
+                        /*shape_representation_fn=*/{}, &device));
   devices->push_back(device.release());
   return Status::OK();
 }
@@ -110,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
     const string& jit_device_name, const SessionOptions& options,
     const string& name_prefix,
     const XlaOpRegistry::DeviceRegistration& registration,
-    bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
+    bool transfer_as_literal,
+    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+    std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
           << device_ordinal;

@@ -129,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
       strings::StrCat("device: ", device_name, " device"));

-  device->reset(new XlaDevice(options, attrs, device_ordinal,
-                              DeviceType(jit_device_name),
-                              platform.ValueOrDie(), transfer_as_literal));
+  device->reset(new XlaDevice(
+      options, attrs, device_ordinal, DeviceType(jit_device_name),
+      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
   return Status::OK();
 }

-XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
-                              const DeviceType& device_type)
+XlaDevice::Metadata::Metadata(
+    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
-      platform_(platform) {}
+      platform_(platform),
+      shape_representation_fn_(std::move(shape_representation_fn)) {}

 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }

@@ -170,17 +174,20 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return Status::OK();
 }

-XlaDevice::XlaDevice(const SessionOptions& options,
-                     const DeviceAttributes& attrs, int device_ordinal,
-                     const DeviceType& jit_device_name, se::Platform* platform,
-                     bool transfer_as_literal)
+XlaDevice::XlaDevice(
+    const SessionOptions& options, const DeviceAttributes& attrs,
+    int device_ordinal, const DeviceType& jit_device_name,
+    se::Platform* platform, bool transfer_as_literal,
+    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
     : LocalDevice(options, attrs),
-      xla_metadata_(device_ordinal, platform, jit_device_name),
+      xla_metadata_(device_ordinal, platform, jit_device_name,
+                    shape_representation_fn),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
       platform_(platform),
-      transfer_as_literal_(transfer_as_literal) {
+      transfer_as_literal_(transfer_as_literal),
+      shape_representation_fn_(shape_representation_fn) {
   VLOG(1) << "Created XLA device " << jit_device_name;
 }

@@ -232,8 +239,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
   // gpu_device_info_->default_context.
   gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
   gpu_device_info_->stream = stream;
-  gpu_device_info_->default_context =
-      new XlaDeviceContext(stream, client(), transfer_as_literal_);
+  gpu_device_info_->default_context = new XlaDeviceContext(
+      stream, client(), transfer_as_literal_, shape_representation_fn_);
   set_tensorflow_gpu_device_info(gpu_device_info_.get());
 }

@@ -247,7 +254,8 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
   // Call GetAllocator for the side-effect of ensuring the allocator is created.
   GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
+  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
+                                  shape_representation_fn_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -294,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
   Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
   Notification n;
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-  XlaTransferManager manager(stream, client(), transfer_as_literal_);
+  XlaTransferManager manager(stream, client(), transfer_as_literal_,
+                             shape_representation_fn_);
   manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                 [&n, &status](const Status& s) {
                                   status = s;
@@ -17,8 +17,7 @@ limitations under the License.
 // runtime.
 //
 // Operators assigned to an XlaDevice are compiled into XLA computations.
-// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
-// is managed by XLA.
+// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
 //
 // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
 // under different names (e.g., XLA_CPU or XLA_GPU).
@@ -27,6 +26,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

 #include "tensorflow/compiler/jit/xla_tensor.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -50,7 +50,8 @@ class XlaDevice : public LocalDevice {
   class Metadata {
    public:
     Metadata(int device_ordinal, se::Platform* platform,
-             const DeviceType& device_type);
+             const DeviceType& device_type,
+             XlaCompiler::ShapeRepresentationFn shape_representation_fn);

     // The index of the device on this host.
     int device_ordinal() const;
@@ -58,11 +59,15 @@ class XlaDevice : public LocalDevice {
     se::Platform* platform() const;
     xla::LocalClient* client() const;
     const DeviceType& jit_device_type() const;
+    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
+      return shape_representation_fn_;
+    }

    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     se::Platform* platform_;  // Not owned.
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -76,16 +81,19 @@ class XlaDevice : public LocalDevice {
   // 'transfer_as_literal' is true if device<->host transfers must be done using
   // XLA's TransferLiteral{To,From}Device interface. If false, we can use
   // ThenMemcpy instead.
-  static Status Create(const string& platform_name, const string& device_name,
-                       int device_ordinal, const string& jit_device_name,
-                       const SessionOptions& options, const string& name_prefix,
-                       const XlaOpRegistry::DeviceRegistration& registration,
-                       bool transfer_as_literal,
-                       std::unique_ptr<XlaDevice>* device);
+  static Status Create(
+      const string& platform_name, const string& device_name,
+      int device_ordinal, const string& jit_device_name,
+      const SessionOptions& options, const string& name_prefix,
+      const XlaOpRegistry::DeviceRegistration& registration,
+      bool transfer_as_literal,
+      const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+      std::unique_ptr<XlaDevice>* device);

   XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
             int device_ordinal, const DeviceType& jit_device_name,
-            se::Platform* platform, bool transfer_as_literal);
+            se::Platform* platform, bool transfer_as_literal,
+            const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
   ~XlaDevice() override;

   Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -116,8 +124,8 @@ class XlaDevice : public LocalDevice {
   // The name of the device that is used to compile Ops for this XlaDevice.
   DeviceType jit_device_name_;
   // Memory allocator associated with this device.
   Allocator* xla_allocator_;  // Not owned.
   se::Platform* platform_;  // Not owned.
   // Stream associated with this device. Operations enqueued on this
   // stream are executed on the device. Operations include data
   // copying back and forth between CPU and the device, and
@@ -126,6 +134,7 @@ class XlaDevice : public LocalDevice {
   // Must we use XLA's transfer manager for correct host<->device transfers? if
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;
+  XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

   // If set, holds default device context (that we must Unref)
   // and its stream.
@@ -47,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {

 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }

-XlaTransferManager::XlaTransferManager(se::Stream* stream,
-                                       xla::LocalClient* client,
-                                       bool transfer_as_literal)
+XlaTransferManager::XlaTransferManager(
+    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
     : stream_(stream),
       client_(client),
       transfer_manager_(client->backend().transfer_manager()),
-      transfer_as_literal_(transfer_as_literal) {}
+      transfer_as_literal_(transfer_as_literal),
+      shape_representation_fn_(std::move(shape_representation_fn)) {}

 Status XlaTransferManager::TransferLiteralToDevice(
     const Tensor& host_tensor, Tensor* device_tensor) const {
@@ -76,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice(
       transfer_manager_->TransferLiteralFromDevice(
           stream_->parent(), shaped_buffer));
   VLOG(1) << "Transfer from device as literal: " << literal->ToString();
-  return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
+  Tensor tensor;
+  TF_RETURN_IF_ERROR(
+      LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
+  // Reshape the tensor back to its declared shape.
+  if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
+    return errors::Internal(
+        "Tensor::CopyFrom failed when copying from XLA device to CPU");
+  }
+  return Status::OK();
 }

 void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -96,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,

   XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
   CHECK(xla_tensor);
+
+  TensorShape shape;
+  if (shape_representation_fn_) {
+    shape = shape_representation_fn_(device_tensor->shape(),
+                                     device_tensor->dtype());
+  } else {
+    shape = device_tensor->shape();
+  }
   if (!xla_tensor->has_shaped_buffer()) {
     Status s = xla_tensor->AllocateShapedBuffer(
-        device_tensor->dtype(), device_tensor->shape(), client_,
+        device_tensor->dtype(), shape, client_,
         stream_->parent()->device_ordinal());
     if (!s.ok()) {
       done(s);
@@ -106,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
     }
   }

-  se::DeviceMemoryBase dev_dst_ptr =
-      XlaTensor::DeviceMemoryFromTensor(*device_tensor);
   Status status;
   if (transfer_as_literal_) {
-    status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
+    Tensor reshaped_cpu_tensor;
+    if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
+      done(errors::Internal(
+          "Tensor::CopyFrom failed when copying from CPU to XLA device"));
+      return;
+    }
+    status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
   } else {
+    se::DeviceMemoryBase dev_dst_ptr =
+        XlaTensor::DeviceMemoryFromTensor(*device_tensor);
     stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
     // TODO(hpucha): Make this asynchronous.
     Status block_status = stream_->BlockHostUntilDone();
@@ -171,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   done(Status::OK());
 }

-XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
-                                   bool transfer_as_literal)
-    : manager_(stream, client, transfer_as_literal) {}
+XlaDeviceContext::XlaDeviceContext(
+    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+    : manager_(stream, client, transfer_as_literal,
+               std::move(shape_representation_fn)) {}

 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                              Device* device,
@@ -19,6 +19,7 @@ limitations under the License.
 #include <memory>

 #include "tensorflow/compiler/jit/xla_tensor.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/framework/allocator.h"
@@ -45,8 +46,9 @@ class XlaDeviceAllocator : public Allocator {
 // Helper class for managing data transfers between host and XLA devices.
 class XlaTransferManager {
  public:
-  explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client,
-                              bool transfer_as_literal);
+  explicit XlaTransferManager(
+      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+      XlaCompiler::ShapeRepresentationFn shape_representation_fn);

   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor, StatusCallback done) const;
@@ -69,7 +71,8 @@ class XlaTransferManager {
   // Transfer manager, for marshalling data to and from the device.
   xla::TransferManager* transfer_manager_;
   // True if we must use XLA's TransferManager for correct device transfers.
-  bool transfer_as_literal_;
+  const bool transfer_as_literal_;
+  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
 };

 // DeviceContext for operators assigned to XlaDevice devices. The
@@ -77,8 +80,9 @@ class XlaTransferManager {
 // wraps the methods in XlaTransferManager.
 class XlaDeviceContext : public DeviceContext {
  public:
-  explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
-                            bool transfer_as_literal);
+  explicit XlaDeviceContext(
+      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
+      XlaCompiler::ShapeRepresentationFn shape_representation_fn);

   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor,
@@ -48,7 +48,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
   Status status =
       XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
                         name_prefix, registration,
-                        /*transfer_as_literal=*/false, &device);
+                        /*transfer_as_literal=*/false,
+                        /*shape_representation_fn=*/{}, &device);
   if (!status.ok()) {
     // Treat failures as non-fatal; there might not be a GPU in the machine.
     VLOG(1) << "Failed to create XLA_GPU device: " << status;
@@ -195,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs(

       OP_REQUIRES_OK(
           ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
-      if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
-        OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
-                                const_tensor.dtype(), const_tensor.shape(),
-                                client_, stream->parent()->device_ordinal()));
-      }

       Device* device = dynamic_cast<Device*>(ctx->device());
       OP_REQUIRES(ctx, device != nullptr,
@ -42,7 +42,7 @@ py_library(
|
|||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client",
|
"//tensorflow/python:client",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform",
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:random_seed",
|
"//tensorflow/python:random_seed",
|
||||||
"//tensorflow/python:session",
|
"//tensorflow/python:session",
|
||||||
@ -58,7 +58,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -72,7 +72,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -93,7 +93,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -111,7 +111,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:bitwise_ops",
|
"//tensorflow/python:bitwise_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:math_ops_gen",
|
"//tensorflow/python:math_ops_gen",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
@ -127,7 +127,7 @@ tf_xla_py_test(
|
|||||||
tags = ["optonly"],
|
tags = ["optonly"],
|
||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
@ -141,7 +141,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -156,7 +156,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -170,7 +170,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -184,7 +184,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:array_ops_gen",
|
"//tensorflow/python:array_ops_gen",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:gradient_checker",
|
"//tensorflow/python:gradient_checker",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
@ -209,7 +209,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:array_ops_gen",
|
"//tensorflow/python:array_ops_gen",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:gradient_checker",
|
"//tensorflow/python:gradient_checker",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
@ -225,7 +225,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
@ -241,7 +241,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
@ -263,7 +263,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
@ -291,7 +291,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -307,7 +307,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -326,7 +326,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:layers",
|
"//tensorflow/python:layers",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
@ -346,7 +346,7 @@ tf_xla_py_test(
|
|||||||
"//tensorflow/contrib/signal:signal_py",
|
"//tensorflow/contrib/signal:signal_py",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:extra_py_tests_deps",
|
"//tensorflow/python:extra_py_tests_deps",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:spectral_ops",
|
"//tensorflow/python:spectral_ops",
|
||||||
],
|
],
|
||||||
@ -360,7 +360,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -372,7 +372,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -388,7 +388,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -403,7 +403,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:image_ops",
|
"//tensorflow/python:image_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -431,7 +431,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
@ -446,7 +446,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -458,7 +458,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -472,7 +472,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -485,7 +485,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:control_flow_ops",
|
"//tensorflow/python:control_flow_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -498,7 +498,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
@ -513,7 +513,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
@ -530,7 +530,7 @@ tf_xla_py_test(
|
|||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
],
|
],
|
||||||
@ -545,7 +545,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -561,7 +561,7 @@ tf_xla_py_test(
|
|||||||
"//tensorflow/compiler/tf2xla/python:xla",
|
"//tensorflow/compiler/tf2xla/python:xla",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:errors",
|
"//tensorflow/python:errors",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -574,7 +574,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -586,7 +586,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -598,7 +598,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
@ -613,7 +613,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -626,7 +626,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:math_ops_gen",
|
"//tensorflow/python:math_ops_gen",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
@ -641,7 +641,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -657,7 +657,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -670,7 +670,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/contrib/stateless",
|
"//tensorflow/contrib/stateless",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -684,7 +684,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:math_ops_gen",
|
"//tensorflow/python:math_ops_gen",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
@ -703,7 +703,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
@ -716,7 +716,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
"//tensorflow/python:nn_ops_gen",
|
"//tensorflow/python:nn_ops_gen",
|
||||||
@ -730,7 +730,7 @@ tf_xla_py_test(
|
|||||||
srcs = ["fused_batchnorm_test.py"],
|
srcs = ["fused_batchnorm_test.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:math_ops_gen",
|
"//tensorflow/python:math_ops_gen",
|
||||||
"//tensorflow/python:nn",
|
"//tensorflow/python:nn",
|
||||||
@ -749,7 +749,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:math_ops_gen",
|
"//tensorflow/python:math_ops_gen",
|
||||||
"//tensorflow/python:nn_ops",
|
"//tensorflow/python:nn_ops",
|
||||||
@ -768,7 +768,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/compiler/tf2xla/python:xla",
|
"//tensorflow/compiler/tf2xla/python:xla",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
],
|
],
|
||||||
@ -783,7 +783,7 @@ tf_xla_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:data_flow_ops",
|
"//tensorflow/python:data_flow_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -795,7 +795,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -808,21 +808,34 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
+tf_xla_py_test(
+    name = "xla_device_test",
+    size = "small",
+    srcs = ["xla_device_test.py"],
+    tags = ["optonly"],
+    deps = [
+        ":xla_test",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:framework",
         "//tensorflow/python:platform_test",
     ],
 )

 cuda_py_test(
-    name = "xla_device_test",
+    name = "xla_device_gpu_test",
     size = "small",
-    srcs = ["xla_device_test.py"],
+    srcs = ["xla_device_gpu_test.py"],
     additional_deps = [
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework",
         "//tensorflow/python:math_ops",
     ],
 )
@@ -839,7 +852,6 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework",
-        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:gradients",
         "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
@ -887,7 +899,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:random_ops",
|
"//tensorflow/python:random_ops",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
@ -902,7 +914,7 @@ cuda_py_test(
|
|||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:init_ops",
|
"//tensorflow/python:init_ops",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
@ -940,7 +952,7 @@ tf_xla_py_test(
|
|||||||
srcs = ["fake_quant_ops_test.py"],
|
srcs = ["fake_quant_ops_test.py"],
|
||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -952,7 +964,7 @@ tf_xla_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":xla_test",
|
":xla_test",
|
||||||
"//tensorflow/python:array_ops",
|
"//tensorflow/python:array_ops",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
tensorflow/compiler/tests/xla_device_gpu_test.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for XLA devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class XlaDeviceGpuTest(test.TestCase):
+
+  def testCopiesToAndFromGpuWork(self):
+    """Tests that copies between GPU and XLA devices work."""
+    if not test.is_gpu_available():
+      return
+
+    with session_lib.Session() as sess:
+      x = array_ops.placeholder(dtypes.float32, [2])
+      with ops.device("GPU"):
+        y = x * 2
+      with ops.device("device:XLA_CPU:0"):
+        z = y * y
+      with ops.device("GPU"):
+        w = y + z
+      result = sess.run(w, {x: [1.5, 0.5]})
+      self.assertAllClose(result, [12., 2.], rtol=1e-3)
+
+
+if __name__ == "__main__":
+  test.main()
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,30 +18,33 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.client import session as session_lib
-from tensorflow.python.framework import dtypes
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class XlaDeviceTest(test.TestCase):
+class XlaDeviceTest(XLATestCase):
 
   def testCopies(self):
-    """Tests that copies between GPU and XLA devices work."""
-    if not test.is_gpu_available():
-      return
-
-    with session_lib.Session() as sess:
-      x = array_ops.placeholder(dtypes.float32, [2])
-      with ops.device("GPU"):
-        y = x * 2
-      with ops.device("device:XLA_CPU:0"):
-        z = y * y
-      with ops.device("GPU"):
-        w = y + z
-      result = sess.run(w, {x: [1.5, 0.5]})
-      self.assertAllClose(result, [12., 2.], rtol=1e-3)
+    """Tests that copies onto and off XLA devices work."""
+    shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
+              [16384, 1], [1, 16384], [1, 20000, 1, 1]]
+    for dtype in self.numeric_types:
+      for shape in shapes:
+        with self.test_session() as sess:
+          with ops.device("CPU"):
+            x = array_ops.placeholder(dtype, shape)
+          with self.test_scope():
+            y = x + x
+          with ops.device("CPU"):
+            z = array_ops.identity(y)
+
+          inputs = np.random.randint(-100, 100, shape).astype(dtype)
+          result = sess.run(z, {x: inputs})
+          self.assertAllCloseAccordingToType(result, inputs + inputs)
 
 
 if __name__ == "__main__":
@@ -325,6 +325,7 @@ tf_cc_test(
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:cpu_plugin",
@@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
   TF_RETURN_IF_ERROR(
       PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
 
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = false;
   XlaCompiler::CompilationResult result;
-  TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(),
-                                               func, arguments, &result));
+  TF_RETURN_IF_ERROR(
+      compiler->CompileFunction(compile_options, func, arguments, &result));
 
   TF_RET_CHECK(arguments.size() == expressions.size());
@@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel {
     }
 
     XlaContext& tc = XlaContext::Get(ctx);
-    if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) {
+    if (tc.resolve_compile_time_constants() &&
+        (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
       xla::Literal literal;
       OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
       OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
     } else {
-      // The core from which a return value is returned depends on the core
-      // assignment of the input to the retval .Since we can't change the core
-      // assignment of <input> as this point, create a tuple/get-tuple-element
-      // combination so that the core will be set on them.
-      auto tuple_elem =
-          ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
-      tc.AddRetval(index_, dtype_, tuple_elem);
+      TensorShape shape = ctx->InputShape(0);
+      TensorShape representation_shape =
+          tc.is_entry_computation()
+              ? tc.RepresentationShape(shape, ctx->input_type(0))
+              : shape;
+
+      xla::XlaOp output = input;
+      if (tc.is_entry_computation()) {
+        output =
+            ctx->builder()->Reshape(input, representation_shape.dim_sizes());
+      } else {
+        // The core from which a return value is returned depends on the
+        // device assignment of the input to the retval. Since we can't change
+        // the device assignment of "input" at this point, we must always
+        // introduce an operator here, even if the shape does not change.
+        // TODO(b/76097077): propagate device assignments onto arguments and
+        // return values of functions, and then reshape unconditionally.
+        output = ctx->builder()->GetTupleElement(
+            ctx->builder()->Tuple({output}), 0);
+      }
+      tc.AddRetval(index_, dtype_, shape, output);
     }
   }
 }
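A minimal sketch (not part of the diff) of the Tuple/GetTupleElement identity used in the non-entry branch above. It assumes only the xla::XlaBuilder calls already shown in the hunk; the helper name is hypothetical.

// Hypothetical helper: wraps `value` in a one-element tuple and immediately
// unpacks it. The result is numerically identical to `value`, but it is a
// distinct op, so a core/device assignment can be attached to it without
// touching the producer of `value`.
xla::XlaOp PinToOwnOp(xla::XlaBuilder* builder, const xla::XlaOp& value) {
  return builder->GetTupleElement(builder->Tuple({value}), 0);
}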
@@ -15,10 +15,9 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 
-#include <deque>
 #include <numeric>
+#include <vector>
 
-#include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -28,7 +27,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
@@ -40,7 +38,6 @@ limitations under the License.
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 namespace {
@@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
   flib_runtime_ = pflr_->GetFLR(device_->name());
 
-  // The default variable representation shape is the identity function.
-  if (!options_.variable_representation_shape_fn) {
-    options_.variable_representation_shape_fn =
-        [](const TensorShape& shape, DataType type) { return shape; };
+  // The default shape representation function is the identity.
+  if (!options_.shape_representation_fn) {
+    options_.shape_representation_fn = [](const TensorShape& shape,
+                                          DataType type) { return shape; };
   }
 }
 
@@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
 
 // Computes the XLA shape for argument 'arg'.
 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
+                                        bool is_entry_computation,
                                         xla::Shape* xla_shape) {
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
-      return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
-                                   xla_shape);
-    case XlaCompiler::Argument::kParameter:
-      return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+      LOG(FATAL) << "Unreachable case";
+    case XlaCompiler::Argument::kParameter: {
+      TensorShape shape =
+          is_entry_computation
+              ? options_.shape_representation_fn(arg.shape, arg.type)
+              : arg.shape;
+      return TensorShapeToXLAShape(arg.type, shape, xla_shape);
+    }
     case XlaCompiler::Argument::kResource: {
       TF_RET_CHECK(arg.initialized);
 
       switch (arg.resource_kind) {
        case XlaResource::kVariable: {
          TensorShape representation_shape =
-              options_.variable_representation_shape_fn(arg.shape, arg.type);
+              options_.shape_representation_fn(arg.shape, arg.type);
          return TensorShapeToXLAShape(arg.type, representation_shape,
                                       xla_shape);
        }
@@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
 Status BuildComputation(
     const std::vector<XlaCompiler::Argument>& args,
     const std::vector<int>& arg_cores,
-    const std::vector<XlaExpression>& retvals,
+    const std::vector<XlaContext::Retval>& retvals,
     const std::vector<std::unique_ptr<XlaResource>>& resources,
     bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
     xla::XlaComputation* computation, int* num_computation_outputs,
     int* num_nonconst_outputs,
+    std::vector<XlaCompiler::OutputDescription>* outputs,
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
   std::vector<xla::XlaOp> elems;
   elems.reserve(retvals.size());
-  for (const XlaExpression& retval : retvals) {
-    if (!retval.has_constant_value()) {
+  for (int i = 0; i < retvals.size(); ++i) {
+    XlaCompiler::OutputDescription& output = (*outputs)[i];
+    output.type = retvals[i].type;
+    output.shape = retvals[i].shape;
+    const XlaExpression& retval = retvals[i].expression;
+    if (retval.has_constant_value()) {
+      output.is_constant = true;
+      output.constant_value = retval.constant_value();
+    } else {
+      output.is_constant = false;
      elems.push_back(retval.handle());
    }
  }
@@ -490,8 +501,8 @@ Status XlaCompiler::BuildArguments(
   std::vector<xla::Shape> arg_shapes(input_mapping->size());
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    TF_RETURN_IF_ERROR(
-        XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
+    TF_RETURN_IF_ERROR(XLAShapeForArgument(
+        args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {
@@ -567,7 +578,8 @@ Status XlaCompiler::BuildArguments(
 
   builder->ClearOpMetadata();
 
-  // Fill in the handles in non-constant arguments.
+  // Fill in the handles in non-constant arguments, and reshape parameters
+  // back to their correct shapes.
   VLOG(2) << "XLA computation inputs:";
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
@@ -586,7 +598,15 @@ Status XlaCompiler::BuildArguments(
         break;
       }
      case XlaCompiler::Argument::kParameter:
-        arg_expression.set_handle(arg_handles[i]);
+        // Reshape parameters back to their correct shapes.
+        // TODO(b/76097077): propagate device assignments onto arguments and
+        // return values of functions, and then reshape unconditionally.
+        if (is_entry_computation) {
+          arg_expression.set_handle(
+              builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
+        } else {
+          arg_expression.set_handle(arg_handles[i]);
+        }
        break;
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kInvalid:
@@ -661,10 +681,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
 
   xla::XlaBuilder builder(name);
-  XlaContext* context =
-      new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
-                     options.resolve_compile_time_constants,
-                     &options_.variable_representation_shape_fn);
+  XlaContext* context = new XlaContext(
+      this, &builder, options_.allow_cpu_custom_calls,
+      options.resolve_compile_time_constants, options.is_entry_computation,
+      &options_.shape_representation_fn);
   core::ScopedUnref context_unref(context);
 
   std::vector<XlaExpression> arg_expressions;
@@ -681,35 +701,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   int num_nonconst_outputs;
   int num_computation_outputs;
   result->computation = std::make_shared<xla::XlaComputation>();
+  result->outputs.resize(context->retvals().size());
   TF_RETURN_IF_ERROR(BuildComputation(
       args, arg_cores, context->retvals(), context->resources(),
       options.return_updated_values_for_all_resources, &builder,
       result->computation.get(), &num_computation_outputs,
-      &num_nonconst_outputs, &result->resource_updates));
+      &num_nonconst_outputs, &result->outputs, &result->resource_updates));
 
   VLOG(2) << "Outputs: total: " << context->retvals().size()
           << " nonconstant: " << num_nonconst_outputs;
-  result->outputs.resize(context->retvals().size());
-  for (std::vector<XlaExpression>::size_type i = 0;
-       i < context->retvals().size(); ++i) {
-    const XlaExpression& retval = context->retvals()[i];
-    if (retval.has_constant_value()) {
-      OutputDescription& output = result->outputs[i];
-      output.shape = retval.constant_value().shape();
-      output.is_constant = true;
-      output.constant_value = retval.constant_value();
-    }
-  }
 
-  // Compute the output shapes, if there is a computation with non-constant
+  // Compute the XLA output shape, if there is a computation with non-constant
   // outputs.
-  auto computation_shape = client()->GetComputationShape(*result->computation);
-  if (!computation_shape.ok()) {
-    return computation_shape.status();
-  }
-
-  result->xla_output_shape.Swap(
-      computation_shape.ValueOrDie()->mutable_result());
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape,
+                      client()->GetComputationShape(*result->computation));
+
+  result->xla_output_shape.Swap(computation_shape->mutable_result());
   VLOG(2) << "XLA output shape: "
           << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
@@ -724,23 +731,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   // Tensorflow expects a major-to-minor order of results.
   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
 
-  // Converts the output shapes to TensorShapes.
-  int computation_output = 0;
-  for (std::vector<XlaExpression>::size_type i = 0;
-       i < context->retvals().size(); ++i) {
-    const XlaExpression& retval = context->retvals()[i];
-    if (!retval.has_constant_value()) {
-      TF_RET_CHECK(computation_output < num_computation_outputs)
-          << "Computation has more outputs than expected";
-      OutputDescription& output = result->outputs[i];
-      output.is_constant = false;
-      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
-          xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
-                                               computation_output),
-          &output.shape));
-      ++computation_output;
-    }
-  }
   return Status::OK();
 }
 
@@ -67,6 +67,15 @@ class XlaContext;
 // _Retval values are ordered by _Retval index, whereas kResource values are
 // ordered by the original _Arg position of the variable.
 //
+// If a shape representation function is provided as part of
+// XlaCompiler::CompileOptions, kParameter arguments and return values to an
+// entry computation will be reshaped in accordance to the shape function.
+// Arguments and return values to a non-entry computation are not reshaped.
+// Variable resource arguments are passed and returned in reshaped form, even
+// for non-entry computations. This feature allows TensorFlow to keep on-device
+// tensors with a different shape to their representation inside the XLA
+// computation.
+//
 // In both inputs and outputs, kResource values are placed the end. When
 // emitting While loop bodies, we must ensure that the loop body has
 // identical input and output signatures. By moving variable values
@@ -171,7 +180,7 @@ class XlaCompiler {
   };
 
   struct OutputDescription {
-    // Type and shape of the output.
+    // Type and shape of the output. The shape is the unflattened shape.
     DataType type;
     TensorShape shape;
 
@@ -206,10 +215,12 @@ class XlaCompiler {
     // original arguments, and are not necessarily in the same order.)
     std::vector<int> input_mapping;
 
-    // Input shapes of the computation.
+    // Input shapes of the computation. If we are flattening inputs, these are
+    // the flattened shapes.
     std::vector<xla::Shape> xla_input_shapes;
 
-    // Output shape in XLA format. The output shape is always a tuple.
+    // Output shape in XLA format. The output shape is always a tuple. If we
+    // are flattening outputs, these are the flattened shapes.
     xla::Shape xla_output_shape;
 
     // TensorFlow shapes of outputs, together with the values of any
@@ -230,6 +241,8 @@ class XlaCompiler {
     std::shared_ptr<xla::XlaComputation> computation;
   };
 
+  typedef std::function<TensorShape(const TensorShape&, DataType)>
+      ShapeRepresentationFn;
   struct Options {
     // Name of the compilation device to use. Needs to be live only during
     // XlaCompiler's constructor.
@@ -250,8 +263,7 @@ class XlaCompiler {
     // If set, the XLA representation of variables represented to XLA as the
     // shape given by this shape function. Variables are reshaped to this shape
    // on write, and reshaped to their original shape on read.
-    std::function<TensorShape(const TensorShape&, DataType)>
-        variable_representation_shape_fn;
+    ShapeRepresentationFn shape_representation_fn;
 
     // If not nullptr, populate_resource_manager is called with the
     // compilation device's resource manager when the compilation
@@ -300,7 +312,8 @@ class XlaCompiler {
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.
-  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
+  Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
+                             xla::Shape* xla_shape);
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
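A minimal sketch (not part of this change) of how a caller might use the restored fields. It assumes the rest of XlaCompiler::Options (compilation device, client, function library, and so on) is filled in as usual, and it reuses the flattening lambda that the tests in this commit install.

// Flatten every entry-computation argument, return value, and variable to a
// rank-1 tensor inside the XLA computation; the declared TensorFlow shapes
// are preserved outside of it.
XlaCompiler::Options options;  // device, client, flib_def, ... set as usual
options.shape_representation_fn = [](const TensorShape& shape, DataType type) {
  return TensorShape({shape.num_elements()});
};
XlaCompiler compiler(options);

XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;  // also reshape args and retvals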
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
@@ -750,10 +751,7 @@ TEST_F(XlaCompilerTest, Variables) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
-// Tests a simple graph that reads and writes a variable, with a
-// variable_representation_shape_fn passed to the compiler that flattens all
-// variable tensors to vectors.
-TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
+xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
   Scope scope = Scope::NewRootScope().ExitOnError();
   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
@@ -764,7 +762,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+  TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
+  return std::move(graph);
+}
+
+// Tests a simple graph that reads and writes a variable, with a
+// shape_representation_fn passed to the compiler that flattens all
+// variable tensors to vectors.
+TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
 
   // Builds a description of the arguments.
   std::vector<XlaCompiler::Argument> args(2);
@@ -779,15 +785,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
 
   // Compiles the graph.
   XlaCompiler::Options options = DefaultOptions();
-  options.variable_representation_shape_fn = [](const TensorShape& shape,
-                                                DataType type) {
+  options.shape_representation_fn = [](const TensorShape& shape,
+                                       DataType type) {
     return TensorShape({shape.num_elements()});
   };
   XlaCompiler compiler(options);
 
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = false;  // Only reshape variables.
+
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
-                                     std::move(graph), args, &result));
+  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
+                          client_->GetComputationShape(*result.computation));
+
+  ASSERT_EQ(program_shape->parameters_size(), 2);
+  EXPECT_TRUE(
+      xla::ShapeUtil::Compatible(program_shape->parameters(0),
+                                 xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
+  EXPECT_TRUE(xla::ShapeUtil::Compatible(
+      program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
+  EXPECT_TRUE(xla::ShapeUtil::Compatible(
+      program_shape->result(),
+      xla::ShapeUtil::MakeTupleShape(
+          {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
+           xla::ShapeUtil::MakeShape(xla::S32, {4})})));
 
   // Tests that the generated computation works.
   std::unique_ptr<xla::Literal> param0_literal =
@@ -815,5 +839,74 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
+TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2, 2});
+  args[1].kind = XlaCompiler::Argument::kResource;
+  args[1].resource_kind = XlaResource::kVariable;
+  args[1].initialized = true;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2, 2});
+
+  // Compiles the graph.
+  XlaCompiler::Options options = DefaultOptions();
+  options.shape_representation_fn = [](const TensorShape& shape,
+                                       DataType type) {
+    return TensorShape({shape.num_elements()});
+  };
+  XlaCompiler compiler(options);
+
+  XlaCompiler::CompileOptions compile_options;
+  compile_options.is_entry_computation = true;  // Reshape args and retvals.
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
+                                     args, &result));
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
+                          client_->GetComputationShape(*result.computation));
+
+  ASSERT_EQ(program_shape->parameters_size(), 2);
+  EXPECT_TRUE(xla::ShapeUtil::Compatible(
+      program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
+  EXPECT_TRUE(xla::ShapeUtil::Compatible(
+      program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
+  EXPECT_TRUE(xla::ShapeUtil::Compatible(
+      program_shape->result(),
+      xla::ShapeUtil::MakeTupleShape(
+          {xla::ShapeUtil::MakeShape(xla::S32, {4}),
+           xla::ShapeUtil::MakeShape(xla::S32, {4})})));
+
+  // Tests that the generated computation works.
+  std::unique_ptr<xla::Literal> param0_literal =
+      xla::Literal::CreateR1<int32>({4, 55, 1, -3});
+  std::unique_ptr<xla::Literal> param1_literal =
+      xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+  std::unique_ptr<xla::GlobalData> param0_data =
+      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+  std::unique_ptr<xla::GlobalData> param1_data =
+      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::GlobalData> actual =
+      client_
+          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+          .ConsumeValueOrDie();
+  std::unique_ptr<xla::Literal> actual_literal =
+      client_->Transfer(*actual).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::Literal> expected0 =
+      xla::Literal::CreateR1<int32>({27, 67, 35, 402});
+  std::unique_ptr<xla::Literal> expected1 =
+      xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+  std::unique_ptr<xla::Literal> expected_literal =
+      xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
+}
+
 }  // namespace
 }  // namespace tensorflow
@@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
 XlaContext::XlaContext(
     XlaCompiler* compiler, xla::XlaBuilder* builder,
     bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
+    bool is_entry_computation,
     const std::function<TensorShape(const TensorShape&, DataType)>*
-        variable_representation_shape_fn)
+        shape_representation_fn)
     : compiler_(compiler),
       builder_(builder),
       allow_cpu_custom_calls_(allow_cpu_custom_calls),
       resolve_compile_time_constants_(resolve_compile_time_constants),
-      variable_representation_shape_fn_(variable_representation_shape_fn) {}
+      is_entry_computation_(is_entry_computation),
+      shape_representation_fn_(shape_representation_fn) {}
 
 string XlaContext::DebugString() { return "TLA JIT context"; }
 
 // This is called by the Retval Op to associate a computed value
 // with a specific return value of the subgraph.
 void XlaContext::AddRetval(int retval_index, DataType type,
-                           const xla::XlaOp& handle) {
+                           const TensorShape& shape, const xla::XlaOp& handle) {
   VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
   // Add the return value to the list being built up.
   if (retvals_.size() <= retval_index) {
     retvals_.resize(retval_index + 1);
   }
-  retvals_[retval_index].set_handle(handle);
+  XlaExpression e;
+  e.set_handle(handle);
+  retvals_[retval_index] = Retval{type, shape, e};
 }
 
 Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
@@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
   if (retvals_.size() <= retval_index) {
     retvals_.resize(retval_index + 1);
   }
-  if (resolve_compile_time_constants_) {
-    Tensor value;
-    TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
-    retvals_[retval_index].set_constant_value(std::move(value));
-  } else {
-    retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
-  }
+  Tensor value;
+  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
+  XlaExpression e;
+  e.set_constant_value(value);
+  retvals_[retval_index] = Retval{dtype, value.shape(), e};
   return Status::OK();
 }
 
@@ -117,9 +119,9 @@ Status XlaContext::CreateResource(
   return Status::OK();
 }
 
-TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
+TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
                                             DataType type) const {
-  return (*variable_representation_shape_fn_)(shape, type);
+  return (*shape_representation_fn_)(shape, type);
 }
 
 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
@@ -42,11 +42,13 @@ class XlaContext : public ResourceBase {
   static XlaContext& Get(const OpKernelContext* ctx);
   static XlaContext& Get(const XlaOpKernelContext* ctx);
 
-  // Creates a new XlaContext.
+  // Creates a new XlaContext. See the documentation on the class data fields
+  // for descriptions of the arguments.
   XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
              bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
+             bool is_entry_computation,
              const std::function<TensorShape(const TensorShape&, DataType)>*
-                 variable_representation_shape_fn);
+                 shape_representation_fn);
 
   // Virtual method defined by ResourceBase.
   string DebugString() override;
@@ -58,14 +60,26 @@ class XlaContext : public ResourceBase {
 
   bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
 
+  bool resolve_compile_time_constants() const {
+    return resolve_compile_time_constants_;
+  }
+  bool is_entry_computation() const { return is_entry_computation_; }
+
   const std::vector<XlaExpression>& args() const { return args_; }
   void set_args(std::vector<XlaExpression> args);
 
-  const std::vector<XlaExpression>& retvals() { return retvals_; }
+  struct Retval {
+    DataType type;
+    TensorShape shape;
+    // An XlaExpression representing the Retval's value.
+    XlaExpression expression;
+  };
+  const std::vector<Retval>& retvals() { return retvals_; }
 
   // This is called by the Retval Op to associate a computed value
   // with a specific return value of the subgraph.
-  void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle);
+  void AddRetval(int retval_index, DataType type, const TensorShape& shape,
+                 const xla::XlaOp& handle);
 
   // As for Retval, but for return values that are compile-time constants.
   Status AddConstRetval(int retval_index, DataType dtype,
@@ -86,9 +100,9 @@ class XlaContext : public ResourceBase {
   }
 
   // Returns the XLA shape to be used to represent a variable of TF `shape`
-  // and `type`.
-  TensorShape VariableRepresentationShape(const TensorShape& shape,
-                                          DataType type) const;
+  // and `type`, or of an argument or return value of a top-level computation.
+  TensorShape RepresentationShape(const TensorShape& shape,
+                                  DataType type) const;
 
   // Get an XLA lambda to compute Max. This is cached in the
   // XlaContext since it may be used by multiple Ops. There is a
@@ -131,15 +145,23 @@ class XlaContext : public ResourceBase {
   std::vector<XlaExpression> args_;
 
   // Return values of the Tensorflow graph, indexed by _Retval index.
-  std::vector<XlaExpression> retvals_;
+  std::vector<Retval> retvals_;
 
   // Holds ownership of resources. The resources are not ordered.
   std::vector<std::unique_ptr<XlaResource>> resources_;
 
-  // A function that describes how variable shapes should be represented
-  // in XLA. Variable values will be reshaped to this shape. Must be non-null.
+  // Is this a top-level computation, or an inner computation (e.g., a while
+  // body)?
+  const bool is_entry_computation_;
+
+  // A function that describes how the shapes of
+  // a) argument and return value, for entry computations
+  // b) variables, for all computations,
+  // should be represented in XLA. Parameters/return values will be shaped
+  // according to this function, and reshaped back to/from their declared shapes
+  // for computations. Must be non-null.
   const std::function<TensorShape(const TensorShape&, DataType)>*
-      variable_representation_shape_fn_;
+      shape_representation_fn_;
 
   // Cache of prebuilt computations indexed by their type.
   using ComputationMap = std::map<DataType, xla::XlaComputation>;
|
|||||||
}
|
}
|
||||||
|
|
||||||
XlaContext& xla_context = XlaContext::Get(context_);
|
XlaContext& xla_context = XlaContext::Get(context_);
|
||||||
TensorShape representation_shape = xla_context.VariableRepresentationShape(
|
TensorShape representation_shape =
|
||||||
variable->shape(), variable->type());
|
xla_context.RepresentationShape(variable->shape(), variable->type());
|
||||||
if (representation_shape == variable->shape()) {
|
if (representation_shape == variable->shape()) {
|
||||||
*value = variable->value();
|
*value = variable->value();
|
||||||
} else {
|
} else {
|
||||||
@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
|
|||||||
|
|
||||||
XlaContext& xla_context = XlaContext::Get(context_);
|
XlaContext& xla_context = XlaContext::Get(context_);
|
||||||
TensorShape representation_shape =
|
TensorShape representation_shape =
|
||||||
xla_context.VariableRepresentationShape(shape, type);
|
xla_context.RepresentationShape(shape, type);
|
||||||
if (shape != representation_shape) {
|
if (shape != representation_shape) {
|
||||||
handle = builder()->Reshape(handle, representation_shape.dim_sizes());
|
handle = builder()->Reshape(handle, representation_shape.dim_sizes());
|
||||||
}
|
}
|
||||||
|