Remove the use of host shape in TransferManager::ReadDynamicShapes.
PiperOrigin-RevId: 342936365
Change-Id: I3c198aebf741ce3184ad30ecb0f110ac98c0f3e5
Commit 668b1d76b0 (parent 9cec6e4e5b)
Changed paths: tensorflow/compiler/…, tensorflow/core/tpu/kernels/…
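Taken together, the hunks below make one mechanical change: ReadDynamicShapes loses its Shape* host_shape out-parameter, and call sites that still need a host shape derive it from the updated device shape. A minimal caller sketch of the new pattern (the wrapper ReadBackShapes and its parameter list are hypothetical, not part of this commit):

    // Hedged sketch: assumes the TF-internal types used in the hunks below.
    // ReadBackShapes is an invented helper name, shown only to illustrate
    // the post-change calling convention.
    xla::Status ReadBackShapes(se::Stream* stream, xla::ShapedBuffer* buffer,
                               xla::TransferManager* transfer_manager,
                               xla::Shape* host_shape_out) {
      xla::Shape device_shape = buffer->on_device_shape();
      if (!device_shape.is_static()) {
        // New signature: no host shape threaded through the call.
        TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
            stream, buffer, &device_shape));
      }
      // A host view is recomputed on demand instead of being tracked.
      *host_shape_out = xla::ShapeUtil::DeviceShapeToHostShape(device_shape);
      return xla::Status::OK();
    }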
@@ -449,15 +449,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
         auto transfer_manager,
         xla::TransferManager::GetForPlatform(stream->parent()->platform()));
 
-    xla::Shape output_host_shape = output.on_host_shape();
     xla::Shape output_device_shape = output.on_device_shape();
     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-        stream, &output, &output_host_shape, &output_device_shape));
+        stream, &output, &output_device_shape));
 
-    output.set_shapes(output_host_shape, output_device_shape);
+    output.set_shapes(output_device_shape, output_device_shape);
     for (int i = 0; i < ctx->num_outputs(); ++i) {
       const xla::Shape& subshape =
-          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
+          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
       TensorShape shape;
       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
       output_tensor_shapes.push_back(shape);
@@ -201,11 +201,9 @@ void TransferManager::TransferArrayFromDevice(
 
 Status TransferManager::ReadDynamicShapes(se::Stream* stream,
                                           ShapedBuffer* device_buffer,
-                                          Shape* host_shape,
                                           Shape* device_shape) {
   DCHECK(device_shape->is_dynamic());
   Shape original_device_shape = *device_shape;
-  Shape original_host_shape = *host_shape;
   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
 
   TF_ASSIGN_OR_RETURN(auto compiler,
@@ -217,8 +215,6 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream,
         if (buffer_shape.IsTuple()) {
           return Status::OK();
         }
-        Shape& host_sub_shape =
-            *ShapeUtil::GetMutableSubshape(host_shape, index);
         Shape& device_sub_shape =
             *ShapeUtil::GetMutableSubshape(device_shape, index);
         if (device_sub_shape.is_static()) {
@@ -245,18 +241,14 @@ Status TransferManager::ReadDynamicShapes(se::Stream* stream,
 
         // Update shape size from metadata.
         for (int64 i = 0; i < metadata.element_count(); ++i) {
-          host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
           device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
         }
         return Status::OK();
       }));
-  host_shape->clear_dynamic_dimensions();
   device_shape->clear_dynamic_dimensions();
 
   TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                    original_device_shape));
-  TF_RET_CHECK(
-      ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
   return Status::OK();
 }
 
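The two implementation hunks above are where the redundancy is clearest: the host and device sub-shapes were updated with identical values from the device metadata, so the host-side copy carried no independent information. A standalone sketch of that finalization step, with an invented dimension value standing in for the metadata read back from the device:

    // Illustrative only: the value 5 is made up; the calls mirror those in
    // the diff above.
    xla::Shape shape =
        xla::ShapeUtil::MakeShape(xla::F32, /*dimensions=*/{8, 128});
    shape.set_dynamic_dimension(0, true);  // bound 8, runtime size unknown
    shape.mutable_dimensions()[0] = 5;     // size read from device metadata
    shape.clear_dynamic_dimensions();      // shape is now the static {5, 128}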
@@ -193,10 +193,9 @@ class TransferManager {
   // shapes, and returns static shapes with dynamic shapes updated.
   // The shape of the buffer also have to be compatible with the host shape and
   // device shape.
-  // TODO(b/170310047): remove host_shape.
   virtual Status ReadDynamicShapes(se::Stream* stream,
                                    ShapedBuffer* device_buffer,
-                                   Shape* host_shape, Shape* device_shape);
+                                   Shape* device_shape);
 
   // Transfers the given literal into the Infeed interface of the device,
   // using the given executor.
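Because ReadDynamicShapes is virtual, any backend-specific TransferManager subclass that overrides it has to adopt the shortened parameter list in the same change; an override left with the old three-pointer signature would no longer match the base declaration. A hypothetical override, for illustration only:

    // Hypothetical subclass declaration updated to the new base signature.
    Status ReadDynamicShapes(se::Stream* stream, ShapedBuffer* device_buffer,
                             Shape* device_shape) override;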
@@ -272,16 +272,16 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
   if (shaped_buffer->on_device_shape().is_dynamic()) {
     // Update dynamic shapes from output buffer, and create a XRT tensor with
     // dimension sizes read from metadata.
-    xla::Shape output_host_shape = shaped_buffer->on_host_shape();
     xla::Shape output_device_shape = shaped_buffer->on_device_shape();
     TF_ASSIGN_OR_RETURN(
         auto transfer_manager,
         xla::TransferManager::GetForPlatform(stream->parent()->platform()));
     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-        stream, shaped_buffer, &output_host_shape, &output_device_shape));
+        stream, shaped_buffer, &output_device_shape));
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
-        *shaped_buffer, output_host_shape, output_device_shape, backend,
-        device_ordinal, &output_tuple));
+        *shaped_buffer,
+        xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape),
+        output_device_shape, backend, device_ordinal, &output_tuple));
   } else {
     // Fast-path: Don't copy shapes of output buffer.
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
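The XRT hunk above shows the replacement idiom: instead of reading back a separate host shape, the call site reconstructs one with xla::ShapeUtil::DeviceShapeToHostShape once the device shape carries its final dimension sizes. Since both shapes share those sizes and presumably differ only in device-specific layout details, the dropped parameter was derivable all along.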
@@ -435,16 +435,14 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
   auto output_buffers =
       absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);
 
-  xla::Shape output_host_shape = output_buffers->buffers.on_host_shape();
   xla::Shape output_device_shape = output_buffers->buffers.on_device_shape();
 
-  if (!output_host_shape.is_static()) {
+  if (!output_device_shape.is_static()) {
     TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
-        stream, &output_buffers->buffers, &output_host_shape,
-        &output_device_shape));
+        stream, &output_buffers->buffers, &output_device_shape));
     for (int64 i = 0; i < sub_elements; ++i) {
       const xla::Shape& subshape =
-          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
+          xla::ShapeUtil::GetSubshape(output_device_shape, {i});
       TensorShape shape;
       TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
       output_tensor_shapes[i] = shape;
@@ -454,8 +452,6 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
   // Transfers ownership of the buffers that back XLA computation output 'i'
   // to 'output_tensor'.
   auto transfer_buffers = [&](int i, Tensor* output_tensor) {
-    const xla::Shape& host_shape =
-        xla::ShapeUtil::GetTupleElementShape(output_host_shape, i);
     const xla::Shape& device_shape =
         xla::ShapeUtil::GetTupleElementShape(output_device_shape, i);
 
@@ -464,7 +460,7 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
     // backing XlaTensor, so we let retain 'output_buffers' ownership of any
     // buffers in that case.
     if (output_tensor->NumElements() > 0) {
-      xla::ScopedShapedBuffer shaped_buffer(host_shape, device_shape, allocator,
+      xla::ScopedShapedBuffer shaped_buffer(device_shape, allocator,
                                             device_ordinal);
       shaped_buffer.buffers().ForEachMutableElement(
           [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
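The final hunk leans on ScopedShapedBuffer accepting a single on-device shape (as the `+` line shows), so the host view no longer needs to be threaded into buffer construction either; it follows the same derive-on-demand rule as the other call sites.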