From 02ad000479047c2ef2a7e3d3d0ee2c5e01ff736a Mon Sep 17 00:00:00 2001 From: Ruoxin Sang <rxsang@google.com> Date: Tue, 23 Jun 2020 11:10:49 -0700 Subject: [PATCH] Support dynamic outputs for XLA on demand ops. PiperOrigin-RevId: 317902879 Change-Id: I6b6dfa54855d5996ac15d4b5c48a5db5dc230025 --- tensorflow/compiler/jit/xla_launch_util.cc | 28 ++++++- tensorflow/compiler/xla/service/BUILD | 5 +- .../compiler/xla/service/transfer_manager.cc | 63 +++++++++++++++ .../compiler/xla/service/transfer_manager.h | 9 +++ .../compiler/xrt/kernels/xrt_execute_op.cc | 78 +++---------------- .../python/distribute/tpu_strategy_test.py | 9 +++ 6 files changed, 121 insertions(+), 71 deletions(-) diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index fc0ff8d9445..eb31b23c991 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -476,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs( stream->ThenRecordEvent(definition_event.get()); } + std::vector<TensorShape> output_tensor_shapes; + output_tensor_shapes.reserve(ctx->num_outputs()); + if (output.on_host_shape().is_dynamic()) { + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + + xla::Shape output_host_shape = output.on_host_shape(); + xla::Shape output_device_shape = output.on_device_shape(); + TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes( + stream, &output, &output_host_shape, &output_device_shape)); + + output.set_shapes(output_host_shape, output_device_shape); + for (int i = 0; i < ctx->num_outputs(); ++i) { + const xla::Shape& subshape = + xla::ShapeUtil::GetSubshape(output_host_shape, {i}); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); + output_tensor_shapes.push_back(shape); + } + } else { + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_tensor_shapes.push_back(compilation_result->outputs[i].shape); + } + } + // Copy 
XLA results to the OpOutputList. int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { - const TensorShape& shape = compilation_result->outputs[i].shape; + const TensorShape& shape = output_tensor_shapes[i]; const DataType& type = compilation_result->outputs[i].type; VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " << DataTypeString(type); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2fd457e8e47..10e2d7e65d1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1202,6 +1202,9 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + ":compiler", + ":executable", + ":maybe_owning_device_memory", ":shaped_buffer", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1210,8 +1213,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory", diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index ebb0226476f..0fd64209152 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -33,6 +34,7 @@ limitations under the License. 
using absl::StrCat; namespace xla { + /* static */ tensorflow::mutex TransferManager::platform_transfer_manager_mutex_( tensorflow::LINKER_INITIALIZED); @@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice( std::move(done), transfer_metadata); } +Status TransferManager::ReadDynamicShapes(se::Stream* stream, + ShapedBuffer* device_buffer, + Shape* host_shape, + Shape* device_shape) { + DCHECK(device_shape->is_dynamic()); + Shape original_device_shape = *device_shape; + Shape original_host_shape = *host_shape; + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + + TF_ASSIGN_OR_RETURN(auto compiler, + Compiler::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const Shape& buffer_shape = + ShapeUtil::GetSubshape(*device_shape, index); + if (buffer_shape.IsTuple()) { + return Status::OK(); + } + Shape& host_sub_shape = + *ShapeUtil::GetMutableSubshape(host_shape, index); + Shape& device_sub_shape = + *ShapeUtil::GetMutableSubshape(device_shape, index); + if (device_sub_shape.is_static()) { + return Status::OK(); + } + + // Read the dynamic shape metadata from the device stream. + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape); + const int64 offset = shape_size_fn(buffer_shape_static); + int64 metadata_size = shape_size_fn(buffer_shape) - offset; + if (metadata_size == 0) { + return InvalidArgument("Dynamic shape metadata size should not be 0"); + } + auto buffer_8 = se::DeviceMemory<uint8>(*buffer); + auto metadata_buffer = + stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + TF_ASSIGN_OR_RETURN( + auto metadata, + TransferArrayFromDevice( + stream, + ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}), + metadata_buffer)); + + // Update shape size from metadata. 
+ for (int64 i = 0; i < metadata.element_count(); ++i) { + host_sub_shape.mutable_dimensions()[i] = metadata.Get({i}); + device_sub_shape.mutable_dimensions()[i] = metadata.Get({i}); + } + return Status::OK(); + })); + host_shape->clear_dynamic_dimensions(); + device_shape->clear_dynamic_dimensions(); + + TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape, + original_device_shape)); + TF_RET_CHECK( + ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape)); + return Status::OK(); +} + /* static */ void TransferManager::RegisterTransferManager( se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index e3f8ceacc42..c0670d26eee 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -184,6 +184,15 @@ class TransferManager { const se::DeviceMemoryBase& source, const TransferMetadata* transfer_metadata = nullptr); + // Read from a device buffer and update the dynamic dimension sizes of + // `host_shape` and `device_shape`. The function takes in bounded dynamic + // shapes, and returns static shapes with dynamic shapes updated. + // The shape of the buffer also has to be compatible with the host shape and + // device shape. + virtual Status ReadDynamicShapes(se::Stream* stream, + ShapedBuffer* device_buffer, + Shape* host_shape, Shape* device_shape); + // Transfers the given literal into the Infeed interface of the device, // using the given executor. 
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 3bd8af577c8..bfd48bd1442 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -264,86 +264,28 @@ Status UpdateDynamicInputs( return Status::OK(); } -xla::StatusOr<xla::Literal> ReadMetadataLiteral( - se::Stream* stream, se::DeviceMemoryBase buffer, - const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { - TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( - stream->parent()->platform())); - auto shape_size_fn = compiler->ShapeSizeBytesFunction(); - xla::Shape buffer_shape_static = - xla::ShapeUtil::MakeStaticShape(buffer_shape); - const int64 offset = shape_size_fn(buffer_shape_static); - int64 metadata_size = shape_size_fn(buffer_shape) - offset; - TF_RET_CHECK(metadata_size != 0); - auto buffer_8 = se::DeviceMemory<uint8>(buffer); - auto metadata_buffer = - stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); - return transfer_manager->TransferArrayFromDevice( - stream, - xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}), - metadata_buffer); -} - -// For each subshape in the result buffer that's dynamic, read the dynamic -// dimension sizes from the metadata, and update output shapes. The result shape -// is a static and concrete shape. 
-xla::Status UpdateDynamicOutputs(se::Stream* stream, - const xla::ShapedBuffer& shaped_buffer, - xla::Shape* output_host_shape, - xla::Shape* output_device_shape) { - DCHECK(output_device_shape->is_dynamic()); - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - xla::TransferManager::GetForPlatform(stream->parent()->platform())); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus( - [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) { - const xla::Shape& buffer_shape = - xla::ShapeUtil::GetSubshape(*output_device_shape, index); - if (buffer_shape.IsTuple()) { - return Status::OK(); - } - xla::Shape& host_shape = - *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index); - xla::Shape& device_shape = - *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index); - if (device_shape.is_static()) { - return Status::OK(); - } - TF_ASSIGN_OR_RETURN(auto metadata, - ReadMetadataLiteral(stream, buffer, buffer_shape, - transfer_manager)); - // Update shape size from metadata. - for (int64 i = 0; i < metadata.element_count(); ++i) { - host_shape.mutable_dimensions()[i] = metadata.Get({i}); - device_shape.mutable_dimensions()[i] = metadata.Get({i}); - } - return Status::OK(); - })); - output_host_shape->clear_dynamic_dimensions(); - output_device_shape->clear_dynamic_dimensions(); - return Status::OK(); -} - xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple( se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend, int device_ordinal) { XRTTupleAllocation* output_tuple; - const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result(); + xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult(); + if (shaped_buffer->on_device_shape().is_dynamic()) { // Update dynamic shapes from output buffer, and create a XRT tensor with // dimension sizes read from metadata. 
- xla::Shape output_host_shape = shaped_buffer.on_host_shape(); - xla::Shape output_device_shape = shaped_buffer.on_device_shape(); - TF_RETURN_IF_ERROR(UpdateDynamicOutputs( + xla::Shape output_host_shape = shaped_buffer->on_host_shape(); + xla::Shape output_device_shape = shaped_buffer->on_device_shape(); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes( stream, shaped_buffer, &output_host_shape, &output_device_shape)); TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, output_host_shape, output_device_shape, backend, + *shaped_buffer, output_host_shape, output_device_shape, backend, device_ordinal, &output_tuple)); } else { // Fast-path: Don't copy shapes of output buffer. TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, backend, device_ordinal, &output_tuple)); + *shaped_buffer, backend, device_ordinal, &output_tuple)); } // After the output tuple is created, we can release the output result // buffers, to make sure they won't be cleared by its destructor. diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 142743a6ec2..850981e073e 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -123,6 +123,15 @@ class TPUTest(test.TestCase): result = bar() + 1 self.assertAllEqual(result, 2) + def test_on_demand_op_with_dynamic_output(self): + with ops.device("/device:TPU:0"): + where_output = array_ops.where([True, False, True]) + self.assertAllEqual(where_output, [[0], [2]]) + + with ops.device("/device:TPU:0"): + repeat_output = array_ops.repeat(math_ops.range(2), [1, 4]) + self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1]) + @parameterized.named_parameters([("PackedVar", True), ("", False)]) class TPUStrategyTest(test.TestCase, parameterized.TestCase):