Automated g4 rollback of changelist 196691101
PiperOrigin-RevId: 196879933
This commit is contained in:
parent
c9e4705d62
commit
41af9782f4
@ -551,14 +551,16 @@ TEST(TFCompileTest, HloProfiling) {
|
||||
auto header = HasSubstr("Execution profile for");
|
||||
auto total_cycles_profile_line = HasSubstr("[total]");
|
||||
auto dot_profile_line = HasSubstr(
|
||||
"%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
||||
"%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
||||
"%arg1.0.1)");
|
||||
auto add_profile_line = HasSubstr(
|
||||
"%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
||||
"%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
||||
"%arg1.0.1)");
|
||||
auto tuple_profile_line = HasSubstr(
|
||||
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
|
||||
"%dot.0.2, f32[2,2]{1,0} %add.0.5)");
|
||||
"%dot.0.4, f32[2,2]{1,0} %add.0.6)");
|
||||
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
|
||||
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
|
||||
|
||||
EXPECT_THAT(hlo_profile_lines,
|
||||
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
||||
|
@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
// this is more obviously correct.)
|
||||
core::ScopedUnref cache_ref(cache);
|
||||
|
||||
const XlaDevice::Metadata* metadata;
|
||||
const XlaDevice::Metadata* metadata = nullptr;
|
||||
Status s = XlaDevice::GetMetadata(ctx, &metadata);
|
||||
bool allocate_xla_tensors = s.ok();
|
||||
|
||||
@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
options.graph_def_version = ctx->function_library()->graph_def_version();
|
||||
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
|
||||
options.device_allocator = xla_allocator;
|
||||
// TODO(b/77671268): We don't set variable_representation_shape_fn here. This
|
||||
// is restricted to Variables, but we need something like this to apply to
|
||||
// normal Tensors too.
|
||||
if (metadata) {
|
||||
options.shape_representation_fn = metadata->shape_representation_fn();
|
||||
}
|
||||
|
||||
const XlaCompiler::CompilationResult* kernel;
|
||||
xla::LocalExecutable* executable;
|
||||
@ -164,9 +164,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
|
||||
for (int i : constants_) {
|
||||
constant_args.insert({i, ctx->input(i)});
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
|
||||
variables, ctx, &kernel, &executable,
|
||||
/*compile_options=*/nullptr));
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, cache->Compile(options, function_, constant_args, variables, ctx,
|
||||
&kernel, &executable, &compile_options));
|
||||
|
||||
VLOG(1) << "Executing XLA Computation...";
|
||||
|
||||
|
@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile(
|
||||
options.client = metadata.client();
|
||||
options.flib_def =
|
||||
new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{});
|
||||
options.shape_representation_fn = metadata.shape_representation_fn();
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true;
|
||||
|
||||
std::map<int, OptionalTensor> variable_args = GetVariables(ctx);
|
||||
return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx,
|
||||
result, executable,
|
||||
/*compile_options=*/nullptr);
|
||||
result, executable, &compile_options);
|
||||
}
|
||||
|
||||
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
|
||||
|
@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
(void)registrations;
|
||||
|
||||
std::unique_ptr<XlaDevice> device;
|
||||
TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
|
||||
DEVICE_CPU_XLA_JIT, options, name_prefix,
|
||||
registration,
|
||||
/*transfer_as_literal=*/false, &device));
|
||||
TF_RETURN_IF_ERROR(
|
||||
XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
|
||||
name_prefix, registration,
|
||||
/*transfer_as_literal=*/false,
|
||||
/*shape_representation_fn=*/{}, &device));
|
||||
devices->push_back(device.release());
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -110,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
|
||||
const string& jit_device_name, const SessionOptions& options,
|
||||
const string& name_prefix,
|
||||
const XlaOpRegistry::DeviceRegistration& registration,
|
||||
bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
|
||||
bool transfer_as_literal,
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
|
||||
std::unique_ptr<XlaDevice>* device) {
|
||||
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
|
||||
<< device_ordinal;
|
||||
|
||||
@ -129,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
|
||||
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
|
||||
strings::StrCat("device: ", device_name, " device"));
|
||||
|
||||
device->reset(new XlaDevice(options, attrs, device_ordinal,
|
||||
DeviceType(jit_device_name),
|
||||
platform.ValueOrDie(), transfer_as_literal));
|
||||
device->reset(new XlaDevice(
|
||||
options, attrs, device_ordinal, DeviceType(jit_device_name),
|
||||
platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
|
||||
const DeviceType& device_type)
|
||||
XlaDevice::Metadata::Metadata(
|
||||
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
|
||||
: device_ordinal_(device_ordinal),
|
||||
device_type_(device_type),
|
||||
platform_(platform) {}
|
||||
platform_(platform),
|
||||
shape_representation_fn_(std::move(shape_representation_fn)) {}
|
||||
|
||||
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
|
||||
|
||||
@ -170,17 +174,20 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
XlaDevice::XlaDevice(const SessionOptions& options,
|
||||
const DeviceAttributes& attrs, int device_ordinal,
|
||||
const DeviceType& jit_device_name, se::Platform* platform,
|
||||
bool transfer_as_literal)
|
||||
XlaDevice::XlaDevice(
|
||||
const SessionOptions& options, const DeviceAttributes& attrs,
|
||||
int device_ordinal, const DeviceType& jit_device_name,
|
||||
se::Platform* platform, bool transfer_as_literal,
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
|
||||
: LocalDevice(options, attrs),
|
||||
xla_metadata_(device_ordinal, platform, jit_device_name),
|
||||
xla_metadata_(device_ordinal, platform, jit_device_name,
|
||||
shape_representation_fn),
|
||||
device_ordinal_(device_ordinal),
|
||||
jit_device_name_(jit_device_name),
|
||||
xla_allocator_(nullptr),
|
||||
platform_(platform),
|
||||
transfer_as_literal_(transfer_as_literal) {
|
||||
transfer_as_literal_(transfer_as_literal),
|
||||
shape_representation_fn_(shape_representation_fn) {
|
||||
VLOG(1) << "Created XLA device " << jit_device_name;
|
||||
}
|
||||
|
||||
@ -232,8 +239,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
|
||||
// gpu_device_info_->default_context.
|
||||
gpu_device_info_ = MakeUnique<GpuDeviceInfo>();
|
||||
gpu_device_info_->stream = stream;
|
||||
gpu_device_info_->default_context =
|
||||
new XlaDeviceContext(stream, client(), transfer_as_literal_);
|
||||
gpu_device_info_->default_context = new XlaDeviceContext(
|
||||
stream, client(), transfer_as_literal_, shape_representation_fn_);
|
||||
set_tensorflow_gpu_device_info(gpu_device_info_.get());
|
||||
}
|
||||
|
||||
@ -247,7 +254,8 @@ Status XlaDevice::FillContextMap(const Graph* graph,
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
|
||||
// Call GetAllocator for the side-effect of ensuring the allocator is created.
|
||||
GetAllocator({});
|
||||
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
|
||||
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
|
||||
shape_representation_fn_);
|
||||
for (Node* n : graph->nodes()) {
|
||||
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
|
||||
ctx->Ref();
|
||||
@ -294,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
|
||||
Notification n;
|
||||
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
|
||||
XlaTransferManager manager(stream, client(), transfer_as_literal_);
|
||||
XlaTransferManager manager(stream, client(), transfer_as_literal_,
|
||||
shape_representation_fn_);
|
||||
manager.CopyCPUTensorToDevice(&parsed, this, ©,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
|
@ -17,8 +17,7 @@ limitations under the License.
|
||||
// runtime.
|
||||
//
|
||||
// Operators assigned to an XlaDevice are compiled into XLA computations.
|
||||
// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
|
||||
// is managed by XLA.
|
||||
// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
|
||||
//
|
||||
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
|
||||
// under different names (e.g., XLA_CPU or XLA_GPU).
|
||||
@ -27,6 +26,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@ -50,7 +50,8 @@ class XlaDevice : public LocalDevice {
|
||||
class Metadata {
|
||||
public:
|
||||
Metadata(int device_ordinal, se::Platform* platform,
|
||||
const DeviceType& device_type);
|
||||
const DeviceType& device_type,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
|
||||
|
||||
// The index of the device on this host.
|
||||
int device_ordinal() const;
|
||||
@ -58,11 +59,15 @@ class XlaDevice : public LocalDevice {
|
||||
se::Platform* platform() const;
|
||||
xla::LocalClient* client() const;
|
||||
const DeviceType& jit_device_type() const;
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
|
||||
return shape_representation_fn_;
|
||||
}
|
||||
|
||||
private:
|
||||
const int device_ordinal_;
|
||||
const DeviceType device_type_;
|
||||
se::Platform* platform_; // Not owned.
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
|
||||
};
|
||||
@ -76,16 +81,19 @@ class XlaDevice : public LocalDevice {
|
||||
// 'transfer_as_literal' is true if device<->host transfers must be done using
|
||||
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
|
||||
// ThenMemcpy instead.
|
||||
static Status Create(const string& platform_name, const string& device_name,
|
||||
int device_ordinal, const string& jit_device_name,
|
||||
const SessionOptions& options, const string& name_prefix,
|
||||
const XlaOpRegistry::DeviceRegistration& registration,
|
||||
bool transfer_as_literal,
|
||||
std::unique_ptr<XlaDevice>* device);
|
||||
static Status Create(
|
||||
const string& platform_name, const string& device_name,
|
||||
int device_ordinal, const string& jit_device_name,
|
||||
const SessionOptions& options, const string& name_prefix,
|
||||
const XlaOpRegistry::DeviceRegistration& registration,
|
||||
bool transfer_as_literal,
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
|
||||
std::unique_ptr<XlaDevice>* device);
|
||||
|
||||
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
|
||||
int device_ordinal, const DeviceType& jit_device_name,
|
||||
se::Platform* platform, bool transfer_as_literal);
|
||||
se::Platform* platform, bool transfer_as_literal,
|
||||
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
|
||||
~XlaDevice() override;
|
||||
|
||||
Allocator* GetAllocator(AllocatorAttributes attr) override;
|
||||
@ -116,8 +124,8 @@ class XlaDevice : public LocalDevice {
|
||||
// The name of the device that is used to compile Ops for this XlaDevice.
|
||||
DeviceType jit_device_name_;
|
||||
// Memory allocator associated with this device.
|
||||
Allocator* xla_allocator_; // Not owned.
|
||||
se::Platform* platform_; // Not owned.
|
||||
Allocator* xla_allocator_; // Not owned.
|
||||
se::Platform* platform_; // Not owned.
|
||||
// Stream associated with this device. Operations enqueued on this
|
||||
// stream are executed on the device. Operations include data
|
||||
// copying back and forth between CPU and the device, and
|
||||
@ -126,6 +134,7 @@ class XlaDevice : public LocalDevice {
|
||||
// Must we use XLA's transfer manager for correct host<->device transfers? if
|
||||
// false, we can use ThenMemcpy() instead.
|
||||
bool transfer_as_literal_;
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// If set, holds default device context (that we must Unref)
|
||||
// and its stream.
|
||||
|
@ -47,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
|
||||
|
||||
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
|
||||
|
||||
XlaTransferManager::XlaTransferManager(se::Stream* stream,
|
||||
xla::LocalClient* client,
|
||||
bool transfer_as_literal)
|
||||
XlaTransferManager::XlaTransferManager(
|
||||
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
|
||||
: stream_(stream),
|
||||
client_(client),
|
||||
transfer_manager_(client->backend().transfer_manager()),
|
||||
transfer_as_literal_(transfer_as_literal) {}
|
||||
transfer_as_literal_(transfer_as_literal),
|
||||
shape_representation_fn_(std::move(shape_representation_fn)) {}
|
||||
|
||||
Status XlaTransferManager::TransferLiteralToDevice(
|
||||
const Tensor& host_tensor, Tensor* device_tensor) const {
|
||||
@ -76,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice(
|
||||
transfer_manager_->TransferLiteralFromDevice(
|
||||
stream_->parent(), shaped_buffer));
|
||||
VLOG(1) << "Transfer from device as literal: " << literal->ToString();
|
||||
return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
|
||||
Tensor tensor;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
|
||||
// Reshape the tensor back to its declared shape.
|
||||
if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
|
||||
return errors::Internal(
|
||||
"Tensor::CopyFrom failed when copying from XLA device to CPU");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
@ -96,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
|
||||
CHECK(xla_tensor);
|
||||
|
||||
TensorShape shape;
|
||||
if (shape_representation_fn_) {
|
||||
shape = shape_representation_fn_(device_tensor->shape(),
|
||||
device_tensor->dtype());
|
||||
} else {
|
||||
shape = device_tensor->shape();
|
||||
}
|
||||
if (!xla_tensor->has_shaped_buffer()) {
|
||||
Status s = xla_tensor->AllocateShapedBuffer(
|
||||
device_tensor->dtype(), device_tensor->shape(), client_,
|
||||
device_tensor->dtype(), shape, client_,
|
||||
stream_->parent()->device_ordinal());
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
@ -106,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
}
|
||||
}
|
||||
|
||||
se::DeviceMemoryBase dev_dst_ptr =
|
||||
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
|
||||
Status status;
|
||||
if (transfer_as_literal_) {
|
||||
status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
|
||||
Tensor reshaped_cpu_tensor;
|
||||
if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
|
||||
done(errors::Internal(
|
||||
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
|
||||
return;
|
||||
}
|
||||
status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
|
||||
} else {
|
||||
se::DeviceMemoryBase dev_dst_ptr =
|
||||
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
|
||||
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
|
||||
// TODO(hpucha): Make this asynchronous.
|
||||
Status block_status = stream_->BlockHostUntilDone();
|
||||
@ -171,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
|
||||
done(Status::OK());
|
||||
}
|
||||
|
||||
XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
|
||||
bool transfer_as_literal)
|
||||
: manager_(stream, client, transfer_as_literal) {}
|
||||
XlaDeviceContext::XlaDeviceContext(
|
||||
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn)
|
||||
: manager_(stream, client, transfer_as_literal,
|
||||
std::move(shape_representation_fn)) {}
|
||||
|
||||
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
|
||||
Device* device,
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/jit/xla_tensor.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
@ -45,8 +46,9 @@ class XlaDeviceAllocator : public Allocator {
|
||||
// Helper class for managing data transfers between host and XLA devices.
|
||||
class XlaTransferManager {
|
||||
public:
|
||||
explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client,
|
||||
bool transfer_as_literal);
|
||||
explicit XlaTransferManager(
|
||||
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor, StatusCallback done) const;
|
||||
@ -69,7 +71,8 @@ class XlaTransferManager {
|
||||
// Transfer manager, for marshalling data to and from the device.
|
||||
xla::TransferManager* transfer_manager_;
|
||||
// True if we must use XLA's TransferManager for correct device transfers.
|
||||
bool transfer_as_literal_;
|
||||
const bool transfer_as_literal_;
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
};
|
||||
|
||||
// DeviceContext for operators assigned to XlaDevice devices. The
|
||||
@ -77,8 +80,9 @@ class XlaTransferManager {
|
||||
// wraps the methods in XlaTransferManager.
|
||||
class XlaDeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
|
||||
bool transfer_as_literal);
|
||||
explicit XlaDeviceContext(
|
||||
se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn);
|
||||
|
||||
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
|
||||
Tensor* device_tensor,
|
||||
|
@ -48,7 +48,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
|
||||
Status status =
|
||||
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
|
||||
name_prefix, registration,
|
||||
/*transfer_as_literal=*/false, &device);
|
||||
/*transfer_as_literal=*/false,
|
||||
/*shape_representation_fn=*/{}, &device);
|
||||
if (!status.ok()) {
|
||||
// Treat failures as non-fatal; there might not be a GPU in the machine.
|
||||
VLOG(1) << "Failed to create XLA_GPU device: " << status;
|
||||
|
@ -195,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
|
||||
const_tensor.dtype(), const_tensor.shape(),
|
||||
client_, stream->parent()->device_ordinal()));
|
||||
}
|
||||
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
OP_REQUIRES(ctx, device != nullptr,
|
||||
|
@ -42,7 +42,7 @@ py_library(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:random_seed",
|
||||
"//tensorflow/python:session",
|
||||
@ -58,7 +58,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -72,7 +72,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -93,7 +93,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -111,7 +111,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:bitwise_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:nn_ops",
|
||||
@ -127,7 +127,7 @@ tf_xla_py_test(
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_ops",
|
||||
],
|
||||
@ -141,7 +141,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -156,7 +156,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -170,7 +170,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -184,7 +184,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:array_ops_gen",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:gradient_checker",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -209,7 +209,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:array_ops_gen",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:gradient_checker",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -225,7 +225,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
@ -241,7 +241,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
@ -263,7 +263,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
@ -291,7 +291,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -307,7 +307,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -326,7 +326,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:layers",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn",
|
||||
@ -346,7 +346,7 @@ tf_xla_py_test(
|
||||
"//tensorflow/contrib/signal:signal_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:spectral_ops",
|
||||
],
|
||||
@ -360,7 +360,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -372,7 +372,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -388,7 +388,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -403,7 +403,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:image_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -431,7 +431,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:nn",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -446,7 +446,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -458,7 +458,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -472,7 +472,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -485,7 +485,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -498,7 +498,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -513,7 +513,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -530,7 +530,7 @@ tf_xla_py_test(
|
||||
],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:random_ops",
|
||||
],
|
||||
@ -545,7 +545,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -561,7 +561,7 @@ tf_xla_py_test(
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -574,7 +574,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
],
|
||||
)
|
||||
|
||||
@ -586,7 +586,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -598,7 +598,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
@ -613,7 +613,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -626,7 +626,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -641,7 +641,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -657,7 +657,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -670,7 +670,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/contrib/stateless",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -684,7 +684,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:nn_ops",
|
||||
@ -703,7 +703,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -716,7 +716,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:nn_ops_gen",
|
||||
@ -730,7 +730,7 @@ tf_xla_py_test(
|
||||
srcs = ["fused_batchnorm_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:nn",
|
||||
@ -749,7 +749,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:math_ops_gen",
|
||||
"//tensorflow/python:nn_ops",
|
||||
@ -768,7 +768,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/compiler/tf2xla/python:xla",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:training",
|
||||
],
|
||||
@ -783,7 +783,7 @@ tf_xla_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:data_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -795,7 +795,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -808,21 +808,34 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
tf_xla_py_test(
|
||||
name = "xla_device_test",
|
||||
size = "small",
|
||||
srcs = ["xla_device_test.py"],
|
||||
tags = ["optonly"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "xla_device_test",
|
||||
name = "xla_device_gpu_test",
|
||||
size = "small",
|
||||
srcs = ["xla_device_test.py"],
|
||||
srcs = ["xla_device_gpu_test.py"],
|
||||
additional_deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
@ -839,7 +852,6 @@ cuda_py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:layers",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -887,7 +899,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:random_ops",
|
||||
"//tensorflow/python:variables",
|
||||
@ -902,7 +914,7 @@ cuda_py_test(
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
@ -940,7 +952,7 @@ tf_xla_py_test(
|
||||
srcs = ["fake_quant_ops_test.py"],
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
@ -952,7 +964,7 @@ tf_xla_py_test(
|
||||
deps = [
|
||||
":xla_test",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
48
tensorflow/compiler/tests/xla_device_gpu_test.py
Normal file
48
tensorflow/compiler/tests/xla_device_gpu_test.py
Normal file
@ -0,0 +1,48 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Test cases for XLA devices."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class XlaDeviceGpuTest(test.TestCase):
|
||||
|
||||
def testCopiesToAndFromGpuWork(self):
|
||||
"""Tests that copies between GPU and XLA devices work."""
|
||||
if not test.is_gpu_available():
|
||||
return
|
||||
|
||||
with session_lib.Session() as sess:
|
||||
x = array_ops.placeholder(dtypes.float32, [2])
|
||||
with ops.device("GPU"):
|
||||
y = x * 2
|
||||
with ops.device("device:XLA_CPU:0"):
|
||||
z = y * y
|
||||
with ops.device("GPU"):
|
||||
w = y + z
|
||||
result = sess.run(w, {x: [1.5, 0.5]})
|
||||
self.assertAllClose(result, [12., 2.], rtol=1e-3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -18,30 +18,33 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.framework import dtypes
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests.xla_test import XLATestCase
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class XlaDeviceTest(test.TestCase):
|
||||
class XlaDeviceTest(XLATestCase):
|
||||
|
||||
def testCopies(self):
|
||||
"""Tests that copies between GPU and XLA devices work."""
|
||||
if not test.is_gpu_available():
|
||||
return
|
||||
"""Tests that copies onto and off XLA devices work."""
|
||||
shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
|
||||
[16384, 1], [1, 16384], [1, 20000, 1, 1]]
|
||||
for dtype in self.numeric_types:
|
||||
for shape in shapes:
|
||||
with self.test_session() as sess:
|
||||
with ops.device("CPU"):
|
||||
x = array_ops.placeholder(dtype, shape)
|
||||
with self.test_scope():
|
||||
y = x + x
|
||||
with ops.device("CPU"):
|
||||
z = array_ops.identity(y)
|
||||
|
||||
with session_lib.Session() as sess:
|
||||
x = array_ops.placeholder(dtypes.float32, [2])
|
||||
with ops.device("GPU"):
|
||||
y = x * 2
|
||||
with ops.device("device:XLA_CPU:0"):
|
||||
z = y * y
|
||||
with ops.device("GPU"):
|
||||
w = y + z
|
||||
result = sess.run(w, {x: [1.5, 0.5]})
|
||||
self.assertAllClose(result, [12., 2.], rtol=1e-3)
|
||||
inputs = np.random.randint(-100, 100, shape).astype(dtype)
|
||||
result = sess.run(z, {x: inputs})
|
||||
self.assertAllCloseAccordingToType(result, inputs + inputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -325,6 +325,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||
|
@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n,
|
||||
TF_RETURN_IF_ERROR(
|
||||
PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments));
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = false;
|
||||
XlaCompiler::CompilationResult result;
|
||||
|
||||
TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(),
|
||||
func, arguments, &result));
|
||||
TF_RETURN_IF_ERROR(
|
||||
compiler->CompileFunction(compile_options, func, arguments, &result));
|
||||
|
||||
TF_RET_CHECK(arguments.size() == expressions.size());
|
||||
|
||||
|
@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel {
|
||||
}
|
||||
|
||||
XlaContext& tc = XlaContext::Get(ctx);
|
||||
if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) {
|
||||
if (tc.resolve_compile_time_constants() &&
|
||||
(input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
|
||||
xla::Literal literal;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
|
||||
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
|
||||
} else {
|
||||
// The core from which a return value is returned depends on the core
|
||||
// assignment of the input to the retval .Since we can't change the core
|
||||
// assignment of <input> as this point, create a tuple/get-tuple-element
|
||||
// combination so that the core will be set on them.
|
||||
auto tuple_elem =
|
||||
ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
|
||||
tc.AddRetval(index_, dtype_, tuple_elem);
|
||||
TensorShape shape = ctx->InputShape(0);
|
||||
TensorShape representation_shape =
|
||||
tc.is_entry_computation()
|
||||
? tc.RepresentationShape(shape, ctx->input_type(0))
|
||||
: shape;
|
||||
|
||||
xla::XlaOp output = input;
|
||||
if (tc.is_entry_computation()) {
|
||||
output =
|
||||
ctx->builder()->Reshape(input, representation_shape.dim_sizes());
|
||||
} else {
|
||||
// The core from which a return value is returned depends on the
|
||||
// device assignment of the input to the retval. Since we can't change
|
||||
// the device assignment of "input" at this point, we must always
|
||||
// introduce an operator here, even if the shape does not change.
|
||||
// TODO(b/76097077): propagate device assignments onto arguments and
|
||||
// return values of functions, and then reshape unconditionally.
|
||||
output = ctx->builder()->GetTupleElement(
|
||||
ctx->builder()->Tuple({output}), 0);
|
||||
}
|
||||
tc.AddRetval(index_, dtype_, shape, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,10 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
|
||||
#include <deque>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/const_analysis.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
|
||||
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
|
||||
@ -28,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/executor.h"
|
||||
@ -40,7 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||
local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
|
||||
flib_runtime_ = pflr_->GetFLR(device_->name());
|
||||
|
||||
// The default variable representation shape is the identity function.
|
||||
if (!options_.variable_representation_shape_fn) {
|
||||
options_.variable_representation_shape_fn =
|
||||
[](const TensorShape& shape, DataType type) { return shape; };
|
||||
// The default shape representation function is the identity.
|
||||
if (!options_.shape_representation_fn) {
|
||||
options_.shape_representation_fn = [](const TensorShape& shape,
|
||||
DataType type) { return shape; };
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
|
||||
|
||||
// Computes the XLA shape for argument 'arg'.
|
||||
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
|
||||
bool is_entry_computation,
|
||||
xla::Shape* xla_shape) {
|
||||
switch (arg.kind) {
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
|
||||
xla_shape);
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
|
||||
LOG(FATAL) << "Unreachable case";
|
||||
case XlaCompiler::Argument::kParameter: {
|
||||
TensorShape shape =
|
||||
is_entry_computation
|
||||
? options_.shape_representation_fn(arg.shape, arg.type)
|
||||
: arg.shape;
|
||||
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
|
||||
}
|
||||
case XlaCompiler::Argument::kResource: {
|
||||
TF_RET_CHECK(arg.initialized);
|
||||
|
||||
switch (arg.resource_kind) {
|
||||
case XlaResource::kVariable: {
|
||||
TensorShape representation_shape =
|
||||
options_.variable_representation_shape_fn(arg.shape, arg.type);
|
||||
options_.shape_representation_fn(arg.shape, arg.type);
|
||||
return TensorShapeToXLAShape(arg.type, representation_shape,
|
||||
xla_shape);
|
||||
}
|
||||
@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||
Status BuildComputation(
|
||||
const std::vector<XlaCompiler::Argument>& args,
|
||||
const std::vector<int>& arg_cores,
|
||||
const std::vector<XlaExpression>& retvals,
|
||||
const std::vector<XlaContext::Retval>& retvals,
|
||||
const std::vector<std::unique_ptr<XlaResource>>& resources,
|
||||
bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
|
||||
xla::XlaComputation* computation, int* num_computation_outputs,
|
||||
int* num_nonconst_outputs,
|
||||
std::vector<XlaCompiler::OutputDescription>* outputs,
|
||||
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
|
||||
std::vector<xla::XlaOp> elems;
|
||||
elems.reserve(retvals.size());
|
||||
for (const XlaExpression& retval : retvals) {
|
||||
if (!retval.has_constant_value()) {
|
||||
for (int i = 0; i < retvals.size(); ++i) {
|
||||
XlaCompiler::OutputDescription& output = (*outputs)[i];
|
||||
output.type = retvals[i].type;
|
||||
output.shape = retvals[i].shape;
|
||||
const XlaExpression& retval = retvals[i].expression;
|
||||
if (retval.has_constant_value()) {
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value();
|
||||
} else {
|
||||
output.is_constant = false;
|
||||
elems.push_back(retval.handle());
|
||||
}
|
||||
}
|
||||
@ -490,8 +501,8 @@ Status XlaCompiler::BuildArguments(
|
||||
std::vector<xla::Shape> arg_shapes(input_mapping->size());
|
||||
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
|
||||
// Computes the shapes of non-constant arguments.
|
||||
TF_RETURN_IF_ERROR(
|
||||
XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
|
||||
TF_RETURN_IF_ERROR(XLAShapeForArgument(
|
||||
args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i]));
|
||||
}
|
||||
|
||||
if (use_tuple_arg) {
|
||||
@ -567,7 +578,8 @@ Status XlaCompiler::BuildArguments(
|
||||
|
||||
builder->ClearOpMetadata();
|
||||
|
||||
// Fill in the handles in non-constant arguments.
|
||||
// Fill in the handles in non-constant arguments, and reshape parameters
|
||||
// back to their correct shapes.
|
||||
VLOG(2) << "XLA computation inputs:";
|
||||
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
|
||||
const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
|
||||
@ -586,7 +598,15 @@ Status XlaCompiler::BuildArguments(
|
||||
break;
|
||||
}
|
||||
case XlaCompiler::Argument::kParameter:
|
||||
arg_expression.set_handle(arg_handles[i]);
|
||||
// Reshape parameters back to their correct shapes.
|
||||
// TODO(b/76097077): propagate device assignments onto arguments and
|
||||
// return values of functions, and then reshape unconditionally.
|
||||
if (is_entry_computation) {
|
||||
arg_expression.set_handle(
|
||||
builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
|
||||
} else {
|
||||
arg_expression.set_handle(arg_handles[i]);
|
||||
}
|
||||
break;
|
||||
case XlaCompiler::Argument::kConstant:
|
||||
case XlaCompiler::Argument::kInvalid:
|
||||
@ -661,10 +681,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
|
||||
|
||||
xla::XlaBuilder builder(name);
|
||||
XlaContext* context =
|
||||
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
|
||||
options.resolve_compile_time_constants,
|
||||
&options_.variable_representation_shape_fn);
|
||||
XlaContext* context = new XlaContext(
|
||||
this, &builder, options_.allow_cpu_custom_calls,
|
||||
options.resolve_compile_time_constants, options.is_entry_computation,
|
||||
&options_.shape_representation_fn);
|
||||
core::ScopedUnref context_unref(context);
|
||||
|
||||
std::vector<XlaExpression> arg_expressions;
|
||||
@ -681,35 +701,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
int num_nonconst_outputs;
|
||||
int num_computation_outputs;
|
||||
result->computation = std::make_shared<xla::XlaComputation>();
|
||||
result->outputs.resize(context->retvals().size());
|
||||
TF_RETURN_IF_ERROR(BuildComputation(
|
||||
args, arg_cores, context->retvals(), context->resources(),
|
||||
options.return_updated_values_for_all_resources, &builder,
|
||||
result->computation.get(), &num_computation_outputs,
|
||||
&num_nonconst_outputs, &result->resource_updates));
|
||||
&num_nonconst_outputs, &result->outputs, &result->resource_updates));
|
||||
|
||||
VLOG(2) << "Outputs: total: " << context->retvals().size()
|
||||
<< " nonconstant: " << num_nonconst_outputs;
|
||||
result->outputs.resize(context->retvals().size());
|
||||
for (std::vector<XlaExpression>::size_type i = 0;
|
||||
i < context->retvals().size(); ++i) {
|
||||
const XlaExpression& retval = context->retvals()[i];
|
||||
if (retval.has_constant_value()) {
|
||||
OutputDescription& output = result->outputs[i];
|
||||
output.shape = retval.constant_value().shape();
|
||||
output.is_constant = true;
|
||||
output.constant_value = retval.constant_value();
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the output shapes, if there is a computation with non-constant
|
||||
// Compute the XLA output shape, if there is a computation with non-constant
|
||||
// outputs.
|
||||
auto computation_shape = client()->GetComputationShape(*result->computation);
|
||||
if (!computation_shape.ok()) {
|
||||
return computation_shape.status();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape,
|
||||
client()->GetComputationShape(*result->computation));
|
||||
|
||||
result->xla_output_shape.Swap(
|
||||
computation_shape.ValueOrDie()->mutable_result());
|
||||
result->xla_output_shape.Swap(computation_shape->mutable_result());
|
||||
VLOG(2) << "XLA output shape: "
|
||||
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
|
||||
|
||||
@ -724,23 +731,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
|
||||
// Tensorflow expects a major-to-minor order of results.
|
||||
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
|
||||
|
||||
// Converts the output shapes to TensorShapes.
|
||||
int computation_output = 0;
|
||||
for (std::vector<XlaExpression>::size_type i = 0;
|
||||
i < context->retvals().size(); ++i) {
|
||||
const XlaExpression& retval = context->retvals()[i];
|
||||
if (!retval.has_constant_value()) {
|
||||
TF_RET_CHECK(computation_output < num_computation_outputs)
|
||||
<< "Computation has more outputs than expected";
|
||||
OutputDescription& output = result->outputs[i];
|
||||
output.is_constant = false;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
|
||||
xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
|
||||
computation_output),
|
||||
&output.shape));
|
||||
++computation_output;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -67,6 +67,15 @@ class XlaContext;
|
||||
// _Retval values are ordered by _Retval index, whereas kResource values are
|
||||
// ordered by the original _Arg position of the variable.
|
||||
//
|
||||
// If a shape representation function is provided as part of
|
||||
// XlaCompiler::CompileOptions, kParameter arguments and return values to an
|
||||
// entry computation will be reshaped in accordance to the shape function.
|
||||
// Arguments and return values to a non-entry computation are not reshaped.
|
||||
// Variable resource arguments are passed and returned in reshaped form, even
|
||||
// for non-entry computations. This feature allows TensorFlow to keep on-device
|
||||
// tensors with a different shape to their representation inside the XLA
|
||||
// computation.
|
||||
//
|
||||
// In both inputs and outputs, kResource values are placed the end. When
|
||||
// emitting While loop bodies, we must ensure that the loop body has
|
||||
// identical input and output signatures. By moving variable values
|
||||
@ -171,7 +180,7 @@ class XlaCompiler {
|
||||
};
|
||||
|
||||
struct OutputDescription {
|
||||
// Type and shape of the output.
|
||||
// Type and shape of the output. The shape is the unflattened shape.
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
|
||||
@ -206,10 +215,12 @@ class XlaCompiler {
|
||||
// original arguments, and are not necessarily in the same order.)
|
||||
std::vector<int> input_mapping;
|
||||
|
||||
// Input shapes of the computation.
|
||||
// Input shapes of the computation. If we are flattening inputs, these are
|
||||
// the flattened shapes.
|
||||
std::vector<xla::Shape> xla_input_shapes;
|
||||
|
||||
// Output shape in XLA format. The output shape is always a tuple.
|
||||
// Output shape in XLA format. The output shape is always a tuple. If we
|
||||
// are flattening outputs, these are the flattened shapes.
|
||||
xla::Shape xla_output_shape;
|
||||
|
||||
// TensorFlow shapes of outputs, together with the values of any
|
||||
@ -230,6 +241,8 @@ class XlaCompiler {
|
||||
std::shared_ptr<xla::XlaComputation> computation;
|
||||
};
|
||||
|
||||
typedef std::function<TensorShape(const TensorShape&, DataType)>
|
||||
ShapeRepresentationFn;
|
||||
struct Options {
|
||||
// Name of the compilation device to use. Needs to be live only during
|
||||
// XlaCompiler's constructor.
|
||||
@ -250,8 +263,7 @@ class XlaCompiler {
|
||||
// If set, the XLA representation of variables represented to XLA as the
|
||||
// shape given by this shape function. Variables are reshaped to this shape
|
||||
// on write, and reshaped to their original shape on read.
|
||||
std::function<TensorShape(const TensorShape&, DataType)>
|
||||
variable_representation_shape_fn;
|
||||
ShapeRepresentationFn shape_representation_fn;
|
||||
|
||||
// If not nullptr, populate_resource_manager is called with the
|
||||
// compilation device's resource manager when the compilation
|
||||
@ -300,7 +312,8 @@ class XlaCompiler {
|
||||
// Returns the shape of the XLA parameter for an argument 'arg'.
|
||||
// See the class comment for more details about the argument passing
|
||||
// convention.
|
||||
Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
|
||||
Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
|
||||
xla::Shape* xla_shape);
|
||||
|
||||
// Retrieves the channel handle associated with `key`. Allocates
|
||||
// a new channel handle if none exists.
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
@ -750,10 +751,7 @@ TEST_F(XlaCompilerTest, Variables) {
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
// Tests a simple graph that reads and writes a variable, with a
|
||||
// variable_representation_shape_fn passed to the compiler that flattens all
|
||||
// variable tensors to vectors.
|
||||
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
|
||||
@ -764,7 +762,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
|
||||
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
|
||||
return std::move(graph);
|
||||
}
|
||||
|
||||
// Tests a simple graph that reads and writes a variable, with a
|
||||
// shape_representation_fn passed to the compiler that flattens all
|
||||
// variable tensors to vectors.
|
||||
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(2);
|
||||
@ -779,15 +785,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler::Options options = DefaultOptions();
|
||||
options.variable_representation_shape_fn = [](const TensorShape& shape,
|
||||
DataType type) {
|
||||
options.shape_representation_fn = [](const TensorShape& shape,
|
||||
DataType type) {
|
||||
return TensorShape({shape.num_elements()});
|
||||
};
|
||||
XlaCompiler compiler(options);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = false; // Only reshape variables.
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
|
||||
std::move(graph), args, &result));
|
||||
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
|
||||
args, &result));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
|
||||
client_->GetComputationShape(*result.computation));
|
||||
|
||||
ASSERT_EQ(program_shape->parameters_size(), 2);
|
||||
EXPECT_TRUE(
|
||||
xla::ShapeUtil::Compatible(program_shape->parameters(0),
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
|
||||
EXPECT_TRUE(xla::ShapeUtil::Compatible(
|
||||
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
|
||||
EXPECT_TRUE(xla::ShapeUtil::Compatible(
|
||||
program_shape->result(),
|
||||
xla::ShapeUtil::MakeTupleShape(
|
||||
{xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
@ -815,5 +839,74 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(2);
|
||||
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2, 2});
|
||||
args[1].kind = XlaCompiler::Argument::kResource;
|
||||
args[1].resource_kind = XlaResource::kVariable;
|
||||
args[1].initialized = true;
|
||||
args[1].type = DT_INT32;
|
||||
args[1].shape = TensorShape({2, 2});
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler::Options options = DefaultOptions();
|
||||
options.shape_representation_fn = [](const TensorShape& shape,
|
||||
DataType type) {
|
||||
return TensorShape({shape.num_elements()});
|
||||
};
|
||||
XlaCompiler compiler(options);
|
||||
|
||||
XlaCompiler::CompileOptions compile_options;
|
||||
compile_options.is_entry_computation = true; // Reshape args and retvals.
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
|
||||
args, &result));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
|
||||
client_->GetComputationShape(*result.computation));
|
||||
|
||||
ASSERT_EQ(program_shape->parameters_size(), 2);
|
||||
EXPECT_TRUE(xla::ShapeUtil::Compatible(
|
||||
program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
|
||||
EXPECT_TRUE(xla::ShapeUtil::Compatible(
|
||||
program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
|
||||
EXPECT_TRUE(xla::ShapeUtil::Compatible(
|
||||
program_shape->result(),
|
||||
xla::ShapeUtil::MakeTupleShape(
|
||||
{xla::ShapeUtil::MakeShape(xla::S32, {4}),
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {4})})));
|
||||
|
||||
// Tests that the generated computation works.
|
||||
std::unique_ptr<xla::Literal> param0_literal =
|
||||
xla::Literal::CreateR1<int32>({4, 55, 1, -3});
|
||||
std::unique_ptr<xla::Literal> param1_literal =
|
||||
xla::Literal::CreateR1<int32>({22, 11, 33, 404});
|
||||
std::unique_ptr<xla::GlobalData> param0_data =
|
||||
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> param1_data =
|
||||
client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::GlobalData> actual =
|
||||
client_
|
||||
->Execute(*result.computation, {param0_data.get(), param1_data.get()})
|
||||
.ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::Literal> actual_literal =
|
||||
client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
|
||||
std::unique_ptr<xla::Literal> expected0 =
|
||||
xla::Literal::CreateR1<int32>({27, 67, 35, 402});
|
||||
std::unique_ptr<xla::Literal> expected1 =
|
||||
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
|
||||
std::unique_ptr<xla::Literal> expected_literal =
|
||||
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
|
||||
XlaContext::XlaContext(
|
||||
XlaCompiler* compiler, xla::XlaBuilder* builder,
|
||||
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
|
||||
bool is_entry_computation,
|
||||
const std::function<TensorShape(const TensorShape&, DataType)>*
|
||||
variable_representation_shape_fn)
|
||||
shape_representation_fn)
|
||||
: compiler_(compiler),
|
||||
builder_(builder),
|
||||
allow_cpu_custom_calls_(allow_cpu_custom_calls),
|
||||
resolve_compile_time_constants_(resolve_compile_time_constants),
|
||||
variable_representation_shape_fn_(variable_representation_shape_fn) {}
|
||||
is_entry_computation_(is_entry_computation),
|
||||
shape_representation_fn_(shape_representation_fn) {}
|
||||
|
||||
string XlaContext::DebugString() { return "TLA JIT context"; }
|
||||
|
||||
// This is called by the Retval Op to associate a computed value
|
||||
// with a specific return value of the subgraph.
|
||||
void XlaContext::AddRetval(int retval_index, DataType type,
|
||||
const xla::XlaOp& handle) {
|
||||
const TensorShape& shape, const xla::XlaOp& handle) {
|
||||
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
|
||||
// Add the return value to the list being built up.
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
retvals_[retval_index].set_handle(handle);
|
||||
XlaExpression e;
|
||||
e.set_handle(handle);
|
||||
retvals_[retval_index] = Retval{type, shape, e};
|
||||
}
|
||||
|
||||
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
|
||||
if (retvals_.size() <= retval_index) {
|
||||
retvals_.resize(retval_index + 1);
|
||||
}
|
||||
if (resolve_compile_time_constants_) {
|
||||
Tensor value;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
|
||||
retvals_[retval_index].set_constant_value(std::move(value));
|
||||
} else {
|
||||
retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
|
||||
}
|
||||
Tensor value;
|
||||
TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
|
||||
XlaExpression e;
|
||||
e.set_constant_value(value);
|
||||
retvals_[retval_index] = Retval{dtype, value.shape(), e};
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -117,9 +119,9 @@ Status XlaContext::CreateResource(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
|
||||
DataType type) const {
|
||||
return (*variable_representation_shape_fn_)(shape, type);
|
||||
TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
|
||||
DataType type) const {
|
||||
return (*shape_representation_fn_)(shape, type);
|
||||
}
|
||||
|
||||
const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
|
||||
|
@ -42,11 +42,13 @@ class XlaContext : public ResourceBase {
|
||||
static XlaContext& Get(const OpKernelContext* ctx);
|
||||
static XlaContext& Get(const XlaOpKernelContext* ctx);
|
||||
|
||||
// Creates a new XlaContext.
|
||||
// Creates a new XlaContext. See the documentation on the class data fields
|
||||
// for descriptions of the arguments.
|
||||
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
|
||||
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
|
||||
bool is_entry_computation,
|
||||
const std::function<TensorShape(const TensorShape&, DataType)>*
|
||||
variable_representation_shape_fn);
|
||||
shape_representation_fn);
|
||||
|
||||
// Virtual method defined by ResourceBase.
|
||||
string DebugString() override;
|
||||
@ -58,14 +60,26 @@ class XlaContext : public ResourceBase {
|
||||
|
||||
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
|
||||
|
||||
bool resolve_compile_time_constants() const {
|
||||
return resolve_compile_time_constants_;
|
||||
}
|
||||
bool is_entry_computation() const { return is_entry_computation_; }
|
||||
|
||||
const std::vector<XlaExpression>& args() const { return args_; }
|
||||
void set_args(std::vector<XlaExpression> args);
|
||||
|
||||
const std::vector<XlaExpression>& retvals() { return retvals_; }
|
||||
struct Retval {
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
// An XlaExpression representing the Retval's value.
|
||||
XlaExpression expression;
|
||||
};
|
||||
const std::vector<Retval>& retvals() { return retvals_; }
|
||||
|
||||
// This is called by the Retval Op to associate a computed value
|
||||
// with a specific return value of the subgraph.
|
||||
void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle);
|
||||
void AddRetval(int retval_index, DataType type, const TensorShape& shape,
|
||||
const xla::XlaOp& handle);
|
||||
|
||||
// As for Retval, but for return values that are compile-time constants.
|
||||
Status AddConstRetval(int retval_index, DataType dtype,
|
||||
@ -86,9 +100,9 @@ class XlaContext : public ResourceBase {
|
||||
}
|
||||
|
||||
// Returns the XLA shape to be used to represent a variable of TF `shape`
|
||||
// and `type`.
|
||||
TensorShape VariableRepresentationShape(const TensorShape& shape,
|
||||
DataType type) const;
|
||||
// and `type`, or of an argument or return value of a top-level computation.
|
||||
TensorShape RepresentationShape(const TensorShape& shape,
|
||||
DataType type) const;
|
||||
|
||||
// Get an XLA lambda to compute Max. This is cached in the
|
||||
// XlaContext since it may be used by multiple Ops. There is a
|
||||
@ -131,15 +145,23 @@ class XlaContext : public ResourceBase {
|
||||
std::vector<XlaExpression> args_;
|
||||
|
||||
// Return values of the Tensorflow graph, indexed by _Retval index.
|
||||
std::vector<XlaExpression> retvals_;
|
||||
std::vector<Retval> retvals_;
|
||||
|
||||
// Holds ownership of resources. The resources are not ordered.
|
||||
std::vector<std::unique_ptr<XlaResource>> resources_;
|
||||
|
||||
// A function that describes how variable shapes should be represented
|
||||
// in XLA. Variable values will be reshaped to this shape. Must be non-null.
|
||||
// Is this a top-level computation, or an inner computation (e.g., a while
|
||||
// body)?
|
||||
const bool is_entry_computation_;
|
||||
|
||||
// A function that describes how the shapes of
|
||||
// a) argument and return value, for entry computations
|
||||
// b) variables, for all computations,
|
||||
// should be represented in XLA. Parameters/return values will be shaped
|
||||
// according to this function, and reshaped back to/from their declared shapes
|
||||
// for computations. Must be non-null.
|
||||
const std::function<TensorShape(const TensorShape&, DataType)>*
|
||||
variable_representation_shape_fn_;
|
||||
shape_representation_fn_;
|
||||
|
||||
// Cache of prebuilt computations indexed by their type.
|
||||
using ComputationMap = std::map<DataType, xla::XlaComputation>;
|
||||
|
@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
|
||||
}
|
||||
|
||||
XlaContext& xla_context = XlaContext::Get(context_);
|
||||
TensorShape representation_shape = xla_context.VariableRepresentationShape(
|
||||
variable->shape(), variable->type());
|
||||
TensorShape representation_shape =
|
||||
xla_context.RepresentationShape(variable->shape(), variable->type());
|
||||
if (representation_shape == variable->shape()) {
|
||||
*value = variable->value();
|
||||
} else {
|
||||
@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
|
||||
|
||||
XlaContext& xla_context = XlaContext::Get(context_);
|
||||
TensorShape representation_shape =
|
||||
xla_context.VariableRepresentationShape(shape, type);
|
||||
xla_context.RepresentationShape(shape, type);
|
||||
if (shape != representation_shape) {
|
||||
handle = builder()->Reshape(handle, representation_shape.dim_sizes());
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user