Support dynamic outputs for XLA on demand ops.
PiperOrigin-RevId: 317902879 Change-Id: I6b6dfa54855d5996ac15d4b5c48a5db5dc230025
This commit is contained in:
parent
99fea8da0d
commit
02ad000479
@ -476,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
stream->ThenRecordEvent(definition_event.get());
|
||||
}
|
||||
|
||||
std::vector<TensorShape> output_tensor_shapes;
|
||||
output_tensor_shapes.reserve(ctx->num_outputs());
|
||||
if (output.on_host_shape().is_dynamic()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
|
||||
xla::Shape output_host_shape = output.on_host_shape();
|
||||
xla::Shape output_device_shape = output.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, &output, &output_host_shape, &output_device_shape));
|
||||
|
||||
output.set_shapes(output_host_shape, output_device_shape);
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const xla::Shape& subshape =
|
||||
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
|
||||
output_tensor_shapes.push_back(shape);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
||||
const TensorShape& shape = output_tensor_shapes[i];
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
|
@ -1202,6 +1202,9 @@ cc_library(
|
||||
srcs = ["transfer_manager.cc"],
|
||||
hdrs = ["transfer_manager.h"],
|
||||
deps = [
|
||||
":compiler",
|
||||
":executable",
|
||||
":maybe_owning_device_memory",
|
||||
":shaped_buffer",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1210,8 +1213,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -33,6 +34,7 @@ limitations under the License.
|
||||
using absl::StrCat;
|
||||
|
||||
namespace xla {
|
||||
|
||||
/* static */ tensorflow::mutex
|
||||
TransferManager::platform_transfer_manager_mutex_(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice(
|
||||
std::move(done), transfer_metadata);
|
||||
}
|
||||
|
||||
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
|
||||
ShapedBuffer* device_buffer,
|
||||
Shape* host_shape,
|
||||
Shape* device_shape) {
|
||||
DCHECK(device_shape->is_dynamic());
|
||||
Shape original_device_shape = *device_shape;
|
||||
Shape original_host_shape = *host_shape;
|
||||
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto compiler,
|
||||
Compiler::GetForPlatform(stream->parent()->platform()));
|
||||
TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
|
||||
[&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
|
||||
const Shape& buffer_shape =
|
||||
ShapeUtil::GetSubshape(*device_shape, index);
|
||||
if (buffer_shape.IsTuple()) {
|
||||
return Status::OK();
|
||||
}
|
||||
Shape& host_sub_shape =
|
||||
*ShapeUtil::GetMutableSubshape(host_shape, index);
|
||||
Shape& device_sub_shape =
|
||||
*ShapeUtil::GetMutableSubshape(device_shape, index);
|
||||
if (device_sub_shape.is_static()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Read the dynamic shape metadata from the device stream.
|
||||
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
|
||||
Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
|
||||
const int64 offset = shape_size_fn(buffer_shape_static);
|
||||
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
|
||||
if (metadata_size == 0) {
|
||||
return InvalidArgument("Dynamic shape metadata size should not be 0");
|
||||
}
|
||||
auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
|
||||
auto metadata_buffer =
|
||||
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto metadata,
|
||||
TransferArrayFromDevice(
|
||||
stream,
|
||||
ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
|
||||
metadata_buffer));
|
||||
|
||||
// Update shape size from metadata.
|
||||
for (int64 i = 0; i < metadata.element_count(); ++i) {
|
||||
host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
host_shape->clear_dynamic_dimensions();
|
||||
device_shape->clear_dynamic_dimensions();
|
||||
|
||||
TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
|
||||
original_device_shape));
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */ void TransferManager::RegisterTransferManager(
|
||||
se::Platform::Id platform_id,
|
||||
TransferManagerCreationFunction creation_function) {
|
||||
|
@ -184,6 +184,15 @@ class TransferManager {
|
||||
const se::DeviceMemoryBase& source,
|
||||
const TransferMetadata* transfer_metadata = nullptr);
|
||||
|
||||
// Read from a device buffer and update the dynamic dimension sizes of
|
||||
// `host_shape` and `device_shape`. The function takes in bounded dynamic
|
||||
// shapes, and returns static shapes with dynamic shapes updated.
|
||||
// The shape of the buffer also have to be compatible with the host shape and
|
||||
// device shape.
|
||||
virtual Status ReadDynamicShapes(se::Stream* stream,
|
||||
ShapedBuffer* device_buffer,
|
||||
Shape* host_shape, Shape* device_shape);
|
||||
|
||||
// Transfers the given literal into the Infeed interface of the device,
|
||||
// using the given executor.
|
||||
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||
|
@ -264,86 +264,28 @@ Status UpdateDynamicInputs(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::Literal> ReadMetadataLiteral(
|
||||
se::Stream* stream, se::DeviceMemoryBase buffer,
|
||||
const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
|
||||
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
|
||||
stream->parent()->platform()));
|
||||
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
|
||||
xla::Shape buffer_shape_static =
|
||||
xla::ShapeUtil::MakeStaticShape(buffer_shape);
|
||||
const int64 offset = shape_size_fn(buffer_shape_static);
|
||||
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
|
||||
TF_RET_CHECK(metadata_size != 0);
|
||||
auto buffer_8 = se::DeviceMemory<uint8>(buffer);
|
||||
auto metadata_buffer =
|
||||
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
|
||||
return transfer_manager->TransferArrayFromDevice(
|
||||
stream,
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
|
||||
metadata_buffer);
|
||||
}
|
||||
|
||||
// For each subshape in the result buffer that's dynamic, read the dynamic
|
||||
// dimension sizes from the metadata, and update output shapes. The result shape
|
||||
// is a static and concrete shape.
|
||||
xla::Status UpdateDynamicOutputs(se::Stream* stream,
|
||||
const xla::ShapedBuffer& shaped_buffer,
|
||||
xla::Shape* output_host_shape,
|
||||
xla::Shape* output_device_shape) {
|
||||
DCHECK(output_device_shape->is_dynamic());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
|
||||
TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
|
||||
[&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
|
||||
const xla::Shape& buffer_shape =
|
||||
xla::ShapeUtil::GetSubshape(*output_device_shape, index);
|
||||
if (buffer_shape.IsTuple()) {
|
||||
return Status::OK();
|
||||
}
|
||||
xla::Shape& host_shape =
|
||||
*xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
|
||||
xla::Shape& device_shape =
|
||||
*xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
|
||||
if (device_shape.is_static()) {
|
||||
return Status::OK();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto metadata,
|
||||
ReadMetadataLiteral(stream, buffer, buffer_shape,
|
||||
transfer_manager));
|
||||
// Update shape size from metadata.
|
||||
for (int64 i = 0; i < metadata.element_count(); ++i) {
|
||||
host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
output_host_shape->clear_dynamic_dimensions();
|
||||
output_device_shape->clear_dynamic_dimensions();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
|
||||
se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
|
||||
int device_ordinal) {
|
||||
XRTTupleAllocation* output_tuple;
|
||||
const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
|
||||
if (shaped_buffer.on_device_shape().is_dynamic()) {
|
||||
xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
|
||||
if (shaped_buffer->on_device_shape().is_dynamic()) {
|
||||
// Update dynamic shapes from output buffer, and create a XRT tensor with
|
||||
// dimension sizes read from metadata.
|
||||
xla::Shape output_host_shape = shaped_buffer.on_host_shape();
|
||||
xla::Shape output_device_shape = shaped_buffer.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
|
||||
xla::Shape output_host_shape = shaped_buffer->on_host_shape();
|
||||
xla::Shape output_device_shape = shaped_buffer->on_device_shape();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, shaped_buffer, &output_host_shape, &output_device_shape));
|
||||
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
||||
shaped_buffer, output_host_shape, output_device_shape, backend,
|
||||
*shaped_buffer, output_host_shape, output_device_shape, backend,
|
||||
device_ordinal, &output_tuple));
|
||||
} else {
|
||||
// Fast-path: Don't copy shapes of output buffer.
|
||||
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
||||
shaped_buffer, backend, device_ordinal, &output_tuple));
|
||||
*shaped_buffer, backend, device_ordinal, &output_tuple));
|
||||
}
|
||||
// After the output tuple is created, we can release the output result
|
||||
// buffers, to make sure they won't be cleared by its destructor.
|
||||
|
@ -123,6 +123,15 @@ class TPUTest(test.TestCase):
|
||||
result = bar() + 1
|
||||
self.assertAllEqual(result, 2)
|
||||
|
||||
def test_on_demand_op_with_dynamic_output(self):
|
||||
with ops.device("/device:TPU:0"):
|
||||
where_output = array_ops.where([True, False, True])
|
||||
self.assertAllEqual(where_output, [[0], [2]])
|
||||
|
||||
with ops.device("/device:TPU:0"):
|
||||
repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
|
||||
self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
|
||||
|
||||
|
||||
@parameterized.named_parameters([("PackedVar", True), ("", False)])
|
||||
class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user