From 41af9782f4b0ffc19aad4bfc9652c4e910152459 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 16 May 2018 13:34:10 -0700 Subject: [PATCH] Automated g4 rollback of changelist 196691101 PiperOrigin-RevId: 196879933 --- .../compiler/aot/tests/tfcompile_test.cc | 8 +- .../compiler/jit/kernels/xla_launch_op.cc | 16 ++- .../compiler/jit/xla_compile_on_demand_op.cc | 7 +- tensorflow/compiler/jit/xla_cpu_device.cc | 9 +- tensorflow/compiler/jit/xla_device.cc | 43 +++--- tensorflow/compiler/jit/xla_device.h | 33 +++-- tensorflow/compiler/jit/xla_device_context.cc | 49 +++++-- tensorflow/compiler/jit/xla_device_context.h | 14 +- tensorflow/compiler/jit/xla_gpu_device.cc | 3 +- tensorflow/compiler/jit/xla_launch_util.cc | 5 - tensorflow/compiler/tests/BUILD | 126 ++++++++++-------- .../compiler/tests/xla_device_gpu_test.py | 48 +++++++ tensorflow/compiler/tests/xla_device_test.py | 37 ++--- tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/graph_compiler.cc | 7 +- .../compiler/tf2xla/kernels/retval_op.cc | 31 +++-- tensorflow/compiler/tf2xla/xla_compiler.cc | 110 +++++++-------- tensorflow/compiler/tf2xla/xla_compiler.h | 25 +++- .../compiler/tf2xla/xla_compiler_test.cc | 111 +++++++++++++-- tensorflow/compiler/tf2xla/xla_context.cc | 30 +++-- tensorflow/compiler/tf2xla/xla_context.h | 44 ++++-- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 6 +- 22 files changed, 507 insertions(+), 256 deletions(-) create mode 100644 tensorflow/compiler/tests/xla_device_gpu_test.py diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 868d752927b..fee46280e9a 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -551,14 +551,16 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto add_profile_line = HasSubstr( - "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto tuple_profile_line = HasSubstr( "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.2, f32[2,2]{1,0} %add.0.5)"); + "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); + auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); + auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 86a9fd3b8e1..9d856346eca 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { // this is more obviously correct.) core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata; + const XlaDevice::Metadata* metadata = nullptr; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); @@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - // TODO(b/77671268): We don't set variable_representation_shape_fn here. This - // is restricted to Variables, but we need something like this to apply to - // normal Tensors too. + if (metadata) { + options.shape_representation_fn = metadata->shape_representation_fn(); + } const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -164,9 +164,11 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } - OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, - variables, ctx, &kernel, &executable, - /*compile_options=*/nullptr)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; + OP_REQUIRES_OK( + ctx, cache->Compile(options, function_, constant_args, variables, ctx, + &kernel, &executable, &compile_options)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 6b83cf67ffc..ab644ff5a61 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -156,11 +156,14 @@ Status XlaCompileOnDemandOp::Compile( options.client = metadata.client(); options.flib_def = new FunctionLibraryDefinition(OpRegistry::Global(), FunctionDefLibrary{}); + options.shape_representation_fn = metadata.shape_representation_fn(); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; std::map variable_args = GetVariables(ctx); return cache->CompileSingleOp(options, constant_arguments, variable_args, ctx, - result, executable, - /*compile_options=*/nullptr); + result, executable, &compile_options); } void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index bc07dbd7bdf..ea9e0366043 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -50,10 +50,11 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - registration, - /*transfer_as_literal=*/false, &device)); + TF_RETURN_IF_ERROR( + XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, + name_prefix, registration, + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index cb376a787ad..f13b46c532e 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -110,7 +110,9 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, std::unique_ptr* device) { + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -129,17 +131,19 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice(options, attrs, device_ordinal, - DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal)); + device->reset(new XlaDevice( + options, attrs, device_ordinal, DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal, shape_representation_fn)); return Status::OK(); } -XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type) +XlaDevice::Metadata::Metadata( + int device_ordinal, se::Platform* platform, const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform) {} + platform_(platform), + shape_representation_fn_(std::move(shape_representation_fn)) {} int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -170,17 +174,20 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice(const SessionOptions& options, - const DeviceAttributes& attrs, int device_ordinal, - const DeviceType& jit_device_name, se::Platform* platform, - bool transfer_as_literal) +XlaDevice::XlaDevice( + const SessionOptions& options, const DeviceAttributes& attrs, + int device_ordinal, const DeviceType& jit_device_name, + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name), + xla_metadata_(device_ordinal, platform, jit_device_name, + shape_representation_fn), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), - transfer_as_literal_(transfer_as_literal) { + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(shape_representation_fn) { VLOG(1) << "Created XLA device " << jit_device_name; } @@ -232,8 +239,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() { // gpu_device_info_->default_context. gpu_device_info_ = MakeUnique(); gpu_device_info_->stream = stream; - gpu_device_info_->default_context = - new XlaDeviceContext(stream, client(), transfer_as_literal_); + gpu_device_info_->default_context = new XlaDeviceContext( + stream, client(), transfer_as_literal_, shape_representation_fn_); set_tensorflow_gpu_device_info(gpu_device_info_.get()); } @@ -247,7 +254,8 @@ Status XlaDevice::FillContextMap(const Graph* graph, TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); // Call GetAllocator for the side-effect of ensuring the allocator is created. GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, + shape_representation_fn_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -294,7 +302,8 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_); + XlaTransferManager manager(stream, client(), transfer_as_literal_, + shape_representation_fn_); manager.CopyCPUTensorToDevice(&parsed, this, ©, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 3ae87308cc7..d5d345d43b1 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,8 +17,7 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state -// is managed by XLA. +// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -27,6 +26,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -50,7 +50,8 @@ class XlaDevice : public LocalDevice { class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type); + const DeviceType& device_type, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); // The index of the device on this host. int device_ordinal() const; @@ -58,11 +59,15 @@ class XlaDevice : public LocalDevice { se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { + return shape_representation_fn_; + } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -76,16 +81,19 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. - static Status Create(const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - std::unique_ptr* device); + static Status Create( + const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal); + se::Platform* platform, bool transfer_as_literal, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -116,8 +124,8 @@ class XlaDevice : public LocalDevice { // The name of the device that is used to compile Ops for this XlaDevice. DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -126,6 +134,7 @@ class XlaDevice : public LocalDevice { // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. bool transfer_as_literal_; + XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // If set, holds default device context (that we must Unref) // and its stream. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index bf8c1886a02..ff30b62bad7 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -47,13 +47,14 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager(se::Stream* stream, - xla::LocalClient* client, - bool transfer_as_literal) +XlaTransferManager::XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) : stream_(stream), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal) {} + transfer_as_literal_(transfer_as_literal), + shape_representation_fn_(std::move(shape_representation_fn)) {} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { @@ -76,7 +77,15 @@ Status XlaTransferManager::TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); + Tensor tensor; + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); + // Reshape the tensor back to its declared shape. + if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { + return errors::Internal( + "Tensor::CopyFrom failed when copying from XLA device to CPU"); + } + return Status::OK(); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -96,9 +105,17 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); + + TensorShape shape; + if (shape_representation_fn_) { + shape = shape_representation_fn_(device_tensor->shape(), + device_tensor->dtype()); + } else { + shape = device_tensor->shape(); + } if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), device_tensor->shape(), client_, + device_tensor->dtype(), shape, client_, stream_->parent()->device_ordinal()); if (!s.ok()) { done(s); @@ -106,12 +123,18 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } } - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - status = TransferLiteralToDevice(*cpu_tensor, device_tensor); + Tensor reshaped_cpu_tensor; + if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { + done(errors::Internal( + "Tensor::CopyFrom failed when copying from CPU to XLA device")); + return; + } + status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); } else { + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. Status block_status = stream_->BlockHostUntilDone(); @@ -171,9 +194,11 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal) - : manager_(stream, client, transfer_as_literal) {} +XlaDeviceContext::XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn) + : manager_(stream, client, transfer_as_literal, + std::move(shape_representation_fn)) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index d7f5f1d2089..9af96558684 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/jit/xla_tensor.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -45,8 +46,9 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaTransferManager( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -69,7 +71,8 @@ class XlaTransferManager { // Transfer manager, for marshalling data to and from the device. xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. - bool transfer_as_literal_; + const bool transfer_as_literal_; + const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -77,8 +80,9 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, - bool transfer_as_literal); + explicit XlaDeviceContext( + se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, + XlaCompiler::ShapeRepresentationFn shape_representation_fn); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index a8afbf9dcd7..26842fbe5cc 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -48,7 +48,8 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, &device); + /*transfer_as_literal=*/false, + /*shape_representation_fn=*/{}, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 6a0f557627d..d0c7a936512 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -195,11 +195,6 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES_OK( ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); - if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { - OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( - const_tensor.dtype(), const_tensor.shape(), - client_, stream->parent()->device_ordinal())); - } Device* device = dynamic_cast(ctx->device()); OP_REQUIRES(ctx, device != nullptr, diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 96dfc8d8f1c..213ab95a12f 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -93,7 +93,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -111,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -127,7 +127,7 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -141,7 +141,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -156,7 +156,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -170,7 +170,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -184,7 +184,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -209,7 +209,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -225,7 +225,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -241,7 +241,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -263,7 +263,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -291,7 +291,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -307,7 +307,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -326,7 +326,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn", @@ -346,7 +346,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -360,7 +360,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -372,7 +372,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -388,7 +388,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -403,7 +403,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], @@ -431,7 +431,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -446,7 +446,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -458,7 +458,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -472,7 +472,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -485,7 +485,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -498,7 +498,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -513,7 +513,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -530,7 +530,7 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -545,7 +545,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -561,7 +561,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -574,7 +574,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", ], ) @@ -586,7 +586,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -598,7 +598,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -613,7 +613,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -626,7 +626,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -641,7 +641,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -657,7 +657,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -670,7 +670,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -684,7 +684,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -703,7 +703,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -716,7 +716,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -730,7 +730,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -749,7 +749,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -768,7 +768,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -783,7 +783,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -795,7 +795,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -808,21 +808,34 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "xla_device_test", + size = "small", + srcs = ["xla_device_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) cuda_py_test( - name = "xla_device_test", + name = "xla_device_gpu_test", size = "small", - srcs = ["xla_device_test.py"], + srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", ], ) @@ -839,7 +852,6 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:layers", "//tensorflow/python:math_ops", @@ -887,7 +899,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -902,7 +914,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -940,7 +952,7 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) @@ -952,7 +964,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py new file mode 100644 index 00000000000..1e30ebd55d0 --- /dev/null +++ b/tensorflow/compiler/tests/xla_device_gpu_test.py @@ -0,0 +1,48 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test cases for XLA devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class XlaDeviceGpuTest(test.TestCase): + + def testCopiesToAndFromGpuWork(self): + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index f5c228f8305..b707bd0963d 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,30 +18,33 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class XlaDeviceTest(test.TestCase): +class XlaDeviceTest(XLATestCase): def testCopies(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return + """Tests that copies onto and off XLA devices work.""" + shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], + [16384, 1], [1, 16384], [1, 20000, 1, 1]] + for dtype in self.numeric_types: + for shape in shapes: + with self.test_session() as sess: + with ops.device("CPU"): + x = array_ops.placeholder(dtype, shape) + with self.test_scope(): + y = x + x + with ops.device("CPU"): + z = array_ops.identity(y) - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) + inputs = np.random.randint(-100, 100, shape).astype(dtype) + result = sess.run(z, {x: inputs}) + self.assertAllCloseAccordingToType(result, inputs + inputs) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 4fca51f54d3..cd57452302f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -325,6 +325,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 8115a26210a..b1cb76aeaab 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -208,10 +208,11 @@ Status GraphCompiler::CompileFunctionalNode(Node* n, TF_RETURN_IF_ERROR( PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; XlaCompiler::CompilationResult result; - - TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), - func, arguments, &result)); + TF_RETURN_IF_ERROR( + compiler->CompileFunction(compile_options, func, arguments, &result)); TF_RET_CHECK(arguments.size() == expressions.size()); diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 70547290eae..a7112786384 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -55,18 +55,33 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { + if (tc.resolve_compile_time_constants() && + (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - // The core from which a return value is returned depends on the core - // assignment of the input to the retval .Since we can't change the core - // assignment of as this point, create a tuple/get-tuple-element - // combination so that the core will be set on them. - auto tuple_elem = - ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); - tc.AddRetval(index_, dtype_, tuple_elem); + TensorShape shape = ctx->InputShape(0); + TensorShape representation_shape = + tc.is_entry_computation() + ? tc.RepresentationShape(shape, ctx->input_type(0)) + : shape; + + xla::XlaOp output = input; + if (tc.is_entry_computation()) { + output = + ctx->builder()->Reshape(input, representation_shape.dim_sizes()); + } else { + // The core from which a return value is returned depends on the + // device assignment of the input to the retval. Since we can't change + // the device assignment of "input" at this point, we must always + // introduce an operator here, even if the shape does not change. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + output = ctx->builder()->GetTupleElement( + ctx->builder()->Tuple({output}), 0); + } + tc.AddRetval(index_, dtype_, shape, output); } } } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d1946c332b..5a6db7736e5 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,10 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include #include +#include -#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -28,7 +27,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -40,7 +38,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -110,10 +107,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); - // The default variable representation shape is the identity function. - if (!options_.variable_representation_shape_fn) { - options_.variable_representation_shape_fn = - [](const TensorShape& shape, DataType type) { return shape; }; + // The default shape representation function is the identity. + if (!options_.shape_representation_fn) { + options_.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return shape; }; } } @@ -230,20 +227,25 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, + bool is_entry_computation, xla::Shape* xla_shape) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), - xla_shape); - case XlaCompiler::Argument::kParameter: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + LOG(FATAL) << "Unreachable case"; + case XlaCompiler::Argument::kParameter: { + TensorShape shape = + is_entry_computation + ? options_.shape_representation_fn(arg.shape, arg.type) + : arg.shape; + return TensorShapeToXLAShape(arg.type, shape, xla_shape); + } case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.variable_representation_shape_fn(arg.shape, arg.type); + options_.shape_representation_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -337,16 +339,25 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, Status BuildComputation( const std::vector& args, const std::vector& arg_cores, - const std::vector& retvals, + const std::vector& retvals, const std::vector>& resources, bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, xla::XlaComputation* computation, int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, std::vector* resource_updates) { std::vector elems; elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + output.type = retvals[i].type; + output.shape = retvals[i].shape; + const XlaExpression& retval = retvals[i].expression; + if (retval.has_constant_value()) { + output.is_constant = true; + output.constant_value = retval.constant_value(); + } else { + output.is_constant = false; elems.push_back(retval.handle()); } } @@ -490,8 +501,8 @@ Status XlaCompiler::BuildArguments( std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR( - XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR(XLAShapeForArgument( + args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); } if (use_tuple_arg) { @@ -567,7 +578,8 @@ Status XlaCompiler::BuildArguments( builder->ClearOpMetadata(); - // Fill in the handles in non-constant arguments. + // Fill in the handles in non-constant arguments, and reshape parameters + // back to their correct shapes. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -586,7 +598,15 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - arg_expression.set_handle(arg_handles[i]); + // Reshape parameters back to their correct shapes. + // TODO(b/76097077): propagate device assignments onto arguments and + // return values of functions, and then reshape unconditionally. + if (is_entry_computation) { + arg_expression.set_handle( + builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + } else { + arg_expression.set_handle(arg_handles[i]); + } break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -661,10 +681,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); xla::XlaBuilder builder(name); - XlaContext* context = - new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, - &options_.variable_representation_shape_fn); + XlaContext* context = new XlaContext( + this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, options.is_entry_computation, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; @@ -681,35 +701,22 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared(); + result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; - result->outputs.resize(context->retvals().size()); - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (retval.has_constant_value()) { - OutputDescription& output = result->outputs[i]; - output.shape = retval.constant_value().shape(); - output.is_constant = true; - output.constant_value = retval.constant_value(); - } - } - // Compute the output shapes, if there is a computation with non-constant + // Compute the XLA output shape, if there is a computation with non-constant // outputs. - auto computation_shape = client()->GetComputationShape(*result->computation); - if (!computation_shape.ok()) { - return computation_shape.status(); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr computation_shape, + client()->GetComputationShape(*result->computation)); - result->xla_output_shape.Swap( - computation_shape.ValueOrDie()->mutable_result()); + result->xla_output_shape.Swap(computation_shape->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); @@ -724,23 +731,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); - // Converts the output shapes to TensorShapes. - int computation_output = 0; - for (std::vector::size_type i = 0; - i < context->retvals().size(); ++i) { - const XlaExpression& retval = context->retvals()[i]; - if (!retval.has_constant_value()) { - TF_RET_CHECK(computation_output < num_computation_outputs) - << "Computation has more outputs than expected"; - OutputDescription& output = result->outputs[i]; - output.is_constant = false; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape( - xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, - computation_output), - &output.shape)); - ++computation_output; - } - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index ca6cd822ef4..621fbc149a6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -67,6 +67,15 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // +// If a shape representation function is provided as part of +// XlaCompiler::CompileOptions, kParameter arguments and return values to an +// entry computation will be reshaped in accordance to the shape function. +// Arguments and return values to a non-entry computation are not reshaped. +// Variable resource arguments are passed and returned in reshaped form, even +// for non-entry computations. This feature allows TensorFlow to keep on-device +// tensors with a different shape to their representation inside the XLA +// computation. +// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values @@ -171,7 +180,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. + // Type and shape of the output. The shape is the unflattened shape. DataType type; TensorShape shape; @@ -206,10 +215,12 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Input shapes of the computation. + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. std::vector xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -230,6 +241,8 @@ class XlaCompiler { std::shared_ptr computation; }; + typedef std::function + ShapeRepresentationFn; struct Options { // Name of the compilation device to use. Needs to be live only during // XlaCompiler's constructor. @@ -250,8 +263,7 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - std::function - variable_representation_shape_fn; + ShapeRepresentationFn shape_representation_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -300,7 +312,8 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, + xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 4382ffe6ba3..5670545f9d4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" @@ -750,10 +751,7 @@ TEST_F(XlaCompilerTest, Variables) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -// Tests a simple graph that reads and writes a variable, with a -// variable_representation_shape_fn passed to the compiler that flattens all -// variable tensors to vectors. -TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { +xla::StatusOr> BuildTestGraph() { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -764,7 +762,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_ASSERT_OK(scope.ToGraph(graph.get())); + TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); + return std::move(graph); +} + +// Tests a simple graph that reads and writes a variable, with a +// shape_representation_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); // Builds a description of the arguments. std::vector args(2); @@ -779,15 +785,33 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. XlaCompiler::Options options = DefaultOptions(); - options.variable_representation_shape_fn = [](const TensorShape& shape, - DataType type) { + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = false; // Only reshape variables. + XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", - std::move(graph), args, &result)); + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE( + xla::ShapeUtil::Compatible(program_shape->parameters(0), + xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); // Tests that the generated computation works. std::unique_ptr param0_literal = @@ -815,5 +839,74 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } +TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr graph, BuildTestGraph()); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.shape_representation_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompileOptions compile_options; + compile_options.is_entry_computation = true; // Reshape args and retvals. + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), + args, &result)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr program_shape, + client_->GetComputationShape(*result.computation)); + + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); + EXPECT_TRUE(xla::ShapeUtil::Compatible( + program_shape->result(), + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {4}), + xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR1({4, 55, 1, -3}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({27, 67, 35, 402}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 3dd2d183f3a..098072d33cd 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -65,26 +65,30 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn) + shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - variable_representation_shape_fn_(variable_representation_shape_fn) {} + is_entry_computation_(is_entry_computation), + shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const xla::XlaOp& handle) { + const TensorShape& shape, const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - retvals_[retval_index].set_handle(handle); + XlaExpression e; + e.set_handle(handle); + retvals_[retval_index] = Retval{type, shape, e}; } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -94,13 +98,11 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - if (resolve_compile_time_constants_) { - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - retvals_[retval_index].set_constant_value(std::move(value)); - } else { - retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); - } + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + XlaExpression e; + e.set_constant_value(value); + retvals_[retval_index] = Retval{dtype, value.shape(), e}; return Status::OK(); } @@ -117,9 +119,9 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, - DataType type) const { - return (*variable_representation_shape_fn_)(shape, type); +TensorShape XlaContext::RepresentationShape(const TensorShape& shape, + DataType type) const { + return (*shape_representation_fn_)(shape, type); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 1136ffe5073..341bf6ff1f3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -42,11 +42,13 @@ class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. + // Creates a new XlaContext. See the documentation on the class data fields + // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + bool is_entry_computation, const std::function* - variable_representation_shape_fn); + shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -58,14 +60,26 @@ class XlaContext : public ResourceBase { bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } + bool resolve_compile_time_constants() const { + return resolve_compile_time_constants_; + } + bool is_entry_computation() const { return is_entry_computation_; } + const std::vector& args() const { return args_; } void set_args(std::vector args); - const std::vector& retvals() { return retvals_; } + struct Retval { + DataType type; + TensorShape shape; + // An XlaExpression representing the Retval's value. + XlaExpression expression; + }; + const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle); + void AddRetval(int retval_index, DataType type, const TensorShape& shape, + const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. Status AddConstRetval(int retval_index, DataType dtype, @@ -86,9 +100,9 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`. - TensorShape VariableRepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`, or of an argument or return value of a top-level computation. + TensorShape RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -131,15 +145,23 @@ class XlaContext : public ResourceBase { std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // A function that describes how variable shapes should be represented - // in XLA. Variable values will be reshaped to this shape. Must be non-null. + // Is this a top-level computation, or an inner computation (e.g., a while + // body)? + const bool is_entry_computation_; + + // A function that describes how the shapes of + // a) argument and return value, for entry computations + // b) variables, for all computations, + // should be represented in XLA. Parameters/return values will be shaped + // according to this function, and reshaped back to/from their declared shapes + // for computations. Must be non-null. const std::function* - variable_representation_shape_fn_; + shape_representation_fn_; // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2b65f4d5d59..76c68d81af4 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = xla_context.VariableRepresentationShape( - variable->shape(), variable->type()); + TensorShape representation_shape = + xla_context.RepresentationShape(variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.VariableRepresentationShape(shape, type); + xla_context.RepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); }