From 668b1d76b09eb0165f26e81ba300b7f8df97029b Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Tue, 17 Nov 2020 13:40:51 -0800
Subject: [PATCH] Remove the use of host shape in
 TransferManager::ReadDynamicShapes.

PiperOrigin-RevId: 342936365
Change-Id: I3c198aebf741ce3184ad30ecb0f110ac98c0f3e5
---
 tensorflow/compiler/jit/xla_launch_util.cc          |  7 +++----
 tensorflow/compiler/xla/service/transfer_manager.cc |  8 --------
 tensorflow/compiler/xla/service/transfer_manager.h  |  3 +--
 tensorflow/compiler/xrt/kernels/xrt_execute_op.cc   |  8 ++++----
 tensorflow/core/tpu/kernels/tpu_execute_op.cc       | 12 ++++--------
 5 files changed, 12 insertions(+), 26 deletions(-)

diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index b7f83301d2d..f793d991bde 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
       auto transfer_manager,
       xla::TransferManager::GetForPlatform(stream->parent()->platform()));
 
-  xla::Shape output_host_shape = output.on_host_shape();
   xla::Shape output_device_shape = output.on_device_shape();
   TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-      stream, &output, &output_host_shape, &output_device_shape));
+      stream, &output, &output_device_shape));
 
-  output.set_shapes(output_host_shape, output_device_shape);
+  output.set_shapes(output_device_shape, output_device_shape);
   for (int i = 0; i < ctx->num_outputs(); ++i) {
     const xla::Shape& subshape =
-        xla::ShapeUtil::GetSubshape(output_host_shape, {i});
+        xla::ShapeUtil::GetSubshape(output_device_shape, {i});
     TensorShape shape;
     TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
     output_tensor_shapes.push_back(shape);
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 96dee4be524..64c468944e2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -201,11 +201,9 @@ void TransferManager::TransferArrayFromDevice(
 
 Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                           ShapedBuffer* device_buffer,
-                                          Shape* host_shape,
                                           Shape* device_shape) {
   DCHECK(device_shape->is_dynamic());
   Shape original_device_shape = *device_shape;
-  Shape original_host_shape = *host_shape;
   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
 
   TF_ASSIGN_OR_RETURN(auto compiler,
@@ -217,8 +215,6 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream,
         if (buffer_shape.IsTuple()) {
           return Status::OK();
         }
-        Shape& host_sub_shape =
-            *ShapeUtil::GetMutableSubshape(host_shape, index);
         Shape& device_sub_shape =
             *ShapeUtil::GetMutableSubshape(device_shape, index);
         if (device_sub_shape.is_static()) {
@@ -245,18 +241,14 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream,
 
         // Update shape size from metadata.
         for (int64 i = 0; i < metadata.element_count(); ++i) {
-          host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
           device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
         }
         return Status::OK();
       }));
 
-  host_shape->clear_dynamic_dimensions();
   device_shape->clear_dynamic_dimensions();
 
   TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                    original_device_shape));
-  TF_RET_CHECK(
-      ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index c49d7d899e7..d7636c30c36 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -193,10 +193,9 @@ class TransferManager {
   // shapes, and returns static shapes with dynamic shapes updated.
   // The shape of the buffer also have to be compatible with the host shape and
   // device shape.
-  // TODO(b/170310047): remove host_shape.
   virtual Status ReadDynamicShapes(se::Stream* stream,
                                    ShapedBuffer* device_buffer,
-                                   Shape* host_shape, Shape* device_shape);
+                                   Shape* device_shape);
 
   // Transfers the given literal into the Infeed interface of the device,
   // using the given executor.
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
index bfd48bd1442..fb090da669f 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -272,16 +272,16 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
   if (shaped_buffer->on_device_shape().is_dynamic()) {
     // Update dynamic shapes from output buffer, and create a XRT tensor with
     // dimension sizes read from metadata.
-    xla::Shape output_host_shape = shaped_buffer->on_host_shape();
     xla::Shape output_device_shape = shaped_buffer->on_device_shape();
     TF_ASSIGN_OR_RETURN(
         auto transfer_manager,
         xla::TransferManager::GetForPlatform(stream->parent()->platform()));
     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-        stream, shaped_buffer, &output_host_shape, &output_device_shape));
+        stream, shaped_buffer, &output_device_shape));
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
-        *shaped_buffer, output_host_shape, output_device_shape, backend,
-        device_ordinal, &output_tuple));
+        *shaped_buffer,
+        xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape),
+        output_device_shape, backend, device_ordinal, &output_tuple));
   } else {
     // Fast-path: Don't copy shapes of output buffer.
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
diff --git a/tensorflow/core/tpu/kernels/tpu_execute_op.cc b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
index 88c799bc64b..bc3631b5cd0 100644
--- a/tensorflow/core/tpu/kernels/tpu_execute_op.cc
+++ b/tensorflow/core/tpu/kernels/tpu_execute_op.cc
@@ -435,16 +435,14 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
   auto output_buffers =
       absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);
 
-  xla::Shape output_host_shape = output_buffers->buffers.on_host_shape();
   xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();
 
-  if (!output_host_shape.is_static()) {
+  if (!output_device_shape.is_static()) {
     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-        stream, &output_buffers->buffers, &output_host_shape,
-        &output_device_shape));
+        stream, &output_buffers->buffers, &output_device_shape));
     for (int64 i = 0; i < sub_elements; ++i) {
       const xla::Shape& subshape =
-          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
+          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
       TensorShape shape;
       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
       output_tensor_shapes[i] = shape;
@@ -454,8 +452,6 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
   // Transfers ownership of the buffers that back XLA computation output 'i'
   // to 'output_tensor'.
   auto transfer_buffers = [&](int i, Tensor* output_tensor) {
-    const xla::Shape& host_shape =
-        xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
     const xla::Shape& device_shape =
         xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
 
@@ -464,7 +460,7 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
     // backing XlaTensor, so we let retain 'output_buffers' ownership of any
    // buffers in that case.
     if (output_tensor->NumElements() > 0) {
-      xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
+      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                             device_ordinal);
       shaped_buffer.buffers().ForEachMutableElement(
           [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
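
Note: ReadDynamicShapes no longer takes (or updates) a host shape; callers
that still need one derive it from the refreshed device shape, as the XRT
hunk above does via xla::ShapeUtil::DeviceShapeToHostShape. A minimal sketch
of the new calling convention follows (illustration only, not part of the
applied diff; `stream`, `transfer_manager`, and `result` are assumed to be
set up as in the hunks above):

    // Illustration only, not part of the applied diff.
    // `result` is a xla::ShapedBuffer whose on-device shape may be dynamic.
    xla::Shape device_shape = result.on_device_shape();
    if (!device_shape.is_static()) {
      // Refresh the dynamic dimensions in `device_shape` from the metadata
      // stored alongside each device buffer; no host shape is threaded
      // through the call anymore.
      TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
          stream, &result, &device_shape));
    }
    // When a host shape is still needed, derive it on demand:
    xla::Shape host_shape =
        xla::ShapeUtil::DeviceShapeToHostShape(device_shape);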