Remove the use of host shape in TransferManager::ReadDynamicShapes.

PiperOrigin-RevId: 342936365
Change-Id: I3c198aebf741ce3184ad30ecb0f110ac98c0f3e5
This commit is contained in:
Yunxing Dai 2020-11-17 13:40:51 -08:00 committed by TensorFlower Gardener
parent 9cec6e4e5b
commit 668b1d76b0
5 changed files with 12 additions and 26 deletions

View File

@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
xla::Shape output_host_shape = output.on_host_shape();
xla::Shape output_device_shape = output.on_device_shape();
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output, &output_host_shape, &output_device_shape));
stream, &output, &output_device_shape));
output.set_shapes(output_host_shape, output_device_shape);
output.set_shapes(output_device_shape, output_device_shape);
for (int i = 0; i < ctx->num_outputs(); ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
xla::ShapeUtil::GetSubshape(output_device_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes.push_back(shape);

View File

@ -201,11 +201,9 @@ void TransferManager::TransferArrayFromDevice(
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
ShapedBuffer* device_buffer,
Shape* host_shape,
Shape* device_shape) {
DCHECK(device_shape->is_dynamic());
Shape original_device_shape = *device_shape;
Shape original_host_shape = *host_shape;
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_ASSIGN_OR_RETURN(auto compiler,
@ -217,8 +215,6 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream,
if (buffer_shape.IsTuple()) {
return Status::OK();
}
Shape& host_sub_shape =
*ShapeUtil::GetMutableSubshape(host_shape, index);
Shape& device_sub_shape =
*ShapeUtil::GetMutableSubshape(device_shape, index);
if (device_sub_shape.is_static()) {
@ -245,18 +241,14 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream,
// Update shape size from metadata.
for (int64 i = 0; i < metadata.element_count(); ++i) {
host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
}
return Status::OK();
}));
host_shape->clear_dynamic_dimensions();
device_shape->clear_dynamic_dimensions();
TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
original_device_shape));
TF_RET_CHECK(
ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
return Status::OK();
}

View File

@ -193,10 +193,9 @@ class TransferManager {
// shapes, and returns static shapes with dynamic shapes updated.
// The shape of the buffer also have to be compatible with the host shape and
// device shape.
// TODO(b/170310047): remove host_shape.
virtual Status ReadDynamicShapes(se::Stream* stream,
ShapedBuffer* device_buffer,
Shape* host_shape, Shape* device_shape);
Shape* device_shape);
// Transfers the given literal into the Infeed interface of the device,
// using the given executor.

View File

@ -272,16 +272,16 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
if (shaped_buffer->on_device_shape().is_dynamic()) {
// Update dynamic shapes from output buffer, and create a XRT tensor with
// dimension sizes read from metadata.
xla::Shape output_host_shape = shaped_buffer->on_host_shape();
xla::Shape output_device_shape = shaped_buffer->on_device_shape();
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, shaped_buffer, &output_host_shape, &output_device_shape));
stream, shaped_buffer, &output_device_shape));
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
*shaped_buffer, output_host_shape, output_device_shape, backend,
device_ordinal, &output_tuple));
*shaped_buffer,
xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape),
output_device_shape, backend, device_ordinal, &output_tuple));
} else {
// Fast-path: Don't copy shapes of output buffer.
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(

View File

@ -435,16 +435,14 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
auto output_buffers =
absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);
xla::Shape output_host_shape = output_buffers->buffers.on_host_shape();
xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();
if (!output_host_shape.is_static()) {
if (!output_device_shape.is_static()) {
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
stream, &output_buffers->buffers, &output_host_shape,
&output_device_shape));
stream, &output_buffers->buffers, &output_device_shape));
for (int64 i = 0; i < sub_elements; ++i) {
const xla::Shape& subshape =
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
xla::ShapeUtil::GetSubshape(output_device_shape, {i});
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
output_tensor_shapes[i] = shape;
@ -454,8 +452,6 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
// Transfers ownership of the buffers that back XLA computation output 'i'
// to 'output_tensor'.
auto transfer_buffers = [&](int i, Tensor* output_tensor) {
const xla::Shape& host_shape =
xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
const xla::Shape& device_shape =
xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
@ -464,7 +460,7 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
// backing XlaTensor, so we let retain 'output_buffers' ownership of any
// buffers in that case.
if (output_tensor->NumElements() > 0) {
xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
device_ordinal);
shaped_buffer.buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {