Support dynamic outputs for XLA on-demand ops.

PiperOrigin-RevId: 317902879
Change-Id: I6b6dfa54855d5996ac15d4b5c48a5db5dc230025
Ruoxin Sang 2020-06-23 11:10:49 -07:00 committed by Geeta Chavan
parent 99fea8da0d
commit 02ad000479
6 changed files with 121 additions and 71 deletions

tensorflow/compiler/jit/xla_launch_util.cc

@@ -476,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     stream->ThenRecordEvent(definition_event.get());
   }
 
+  std::vector<TensorShape> output_tensor_shapes;
+  output_tensor_shapes.reserve(ctx->num_outputs());
+  if (output.on_host_shape().is_dynamic()) {
+    TF_ASSIGN_OR_RETURN(
+        auto transfer_manager,
+        xla::TransferManager::GetForPlatform(stream->parent()->platform()));
+
+    xla::Shape output_host_shape = output.on_host_shape();
+    xla::Shape output_device_shape = output.on_device_shape();
+    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
+        stream, &output, &output_host_shape, &output_device_shape));
+    output.set_shapes(output_host_shape, output_device_shape);
+
+    for (int i = 0; i < ctx->num_outputs(); ++i) {
+      const xla::Shape& subshape =
+          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
+      TensorShape shape;
+      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
+      output_tensor_shapes.push_back(shape);
+    }
+  } else {
+    for (int i = 0; i < ctx->num_outputs(); ++i) {
+      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
+    }
+  }
+
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
-    const TensorShape& shape = compilation_result->outputs[i].shape;
+    const TensorShape& shape = output_tensor_shapes[i];
     const DataType& type = compilation_result->outputs[i].type;
     VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
             << DataTypeString(type);
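
To make the new branch concrete: for the `array_ops.where` case exercised in the test at the end of this commit, and assuming the compiler assigns the output the bounded shape `s64[<=3,1]` (the bound coming from the 3-element input), the dynamic path would resolve shapes roughly as follows (a walk-through, not code from this commit):

// Hypothetical walk-through of the dynamic branch above:
//   compiled output shape:          s64[<=3,1]   (3 is only an upper bound)
//   int32 metadata read off device: {2, 1}       (actual dimension sizes)
//   shape after ReadDynamicShapes:  s64[2,1]     (static and concrete)
//   XLAShapeToTensorShape(...):     TensorShape([2, 1])
// The else branch keeps the old behavior: static output shapes are taken
// directly from compilation_result->outputs[i].shape, with no device
// round-trip.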

tensorflow/compiler/xla/service/BUILD

@@ -1202,6 +1202,9 @@ cc_library(
     srcs = ["transfer_manager.cc"],
     hdrs = ["transfer_manager.h"],
     deps = [
+        ":compiler",
+        ":executable",
+        ":maybe_owning_device_memory",
         ":shaped_buffer",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -1210,8 +1213,6 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
-        "//tensorflow/compiler/xla/service:executable",
-        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor:device_memory",

tensorflow/compiler/xla/service/transfer_manager.cc

@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -33,6 +34,7 @@ limitations under the License.
 using absl::StrCat;
 
 namespace xla {
 
 /* static */ tensorflow::mutex
     TransferManager::platform_transfer_manager_mutex_(
         tensorflow::LINKER_INITIALIZED);
@@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice(
                                    std::move(done), transfer_metadata);
 }
 
+Status TransferManager::ReadDynamicShapes(se::Stream* stream,
+                                          ShapedBuffer* device_buffer,
+                                          Shape* host_shape,
+                                          Shape* device_shape) {
+  DCHECK(device_shape->is_dynamic());
+  Shape original_device_shape = *device_shape;
+  Shape original_host_shape = *host_shape;
+  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+
+  TF_ASSIGN_OR_RETURN(auto compiler,
+                      Compiler::GetForPlatform(stream->parent()->platform()));
+  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
+      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+        const Shape& buffer_shape =
+            ShapeUtil::GetSubshape(*device_shape, index);
+        if (buffer_shape.IsTuple()) {
+          return Status::OK();
+        }
+        Shape& host_sub_shape =
+            *ShapeUtil::GetMutableSubshape(host_shape, index);
+        Shape& device_sub_shape =
+            *ShapeUtil::GetMutableSubshape(device_shape, index);
+        if (device_sub_shape.is_static()) {
+          return Status::OK();
+        }
+
+        // Read the dynamic shape metadata from the device stream.
+        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
+        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
+        const int64 offset = shape_size_fn(buffer_shape_static);
+        int64 metadata_size = shape_size_fn(buffer_shape) - offset;
+        if (metadata_size == 0) {
+          return InvalidArgument("Dynamic shape metadata size should not be 0");
+        }
+        auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
+        auto metadata_buffer =
+            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
+        TF_ASSIGN_OR_RETURN(
+            auto metadata,
+            TransferArrayFromDevice(
+                stream,
+                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
+                metadata_buffer));
+
+        // Update shape size from metadata.
+        for (int64 i = 0; i < metadata.element_count(); ++i) {
+          host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
+          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
+        }
+        return Status::OK();
+      }));
+  host_shape->clear_dynamic_dimensions();
+  device_shape->clear_dynamic_dimensions();
+  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
+                                                   original_device_shape));
+  TF_RET_CHECK(
+      ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
+  return Status::OK();
+}
+
 /* static */ void TransferManager::RegisterTransferManager(
     se::Platform::Id platform_id,
     TransferManagerCreationFunction creation_function) {
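
ReadDynamicShapes relies on the layout XLA uses for bounded dynamic buffers: the payload is allocated at its static bound, and the actual dimension sizes trail it as one int32 per dimension. A byte-level example, assuming a size function that counts that metadata (which `Compiler::ShapeSizeBytesFunction()` must do for this scheme to work; a size function that does not triggers the `InvalidArgument` above):

// Hypothetical layout for a dynamic buffer of shape f32[<=4,3] whose actual
// first dimension turns out to be 2:
//   bytes [0, 48):  payload, allocated at the bound: 4 * 3 * sizeof(float)
//   bytes [48, 56): metadata, two int32 values {2, 3}
// offset        = shape_size_fn(f32[4,3])            = 48
// metadata_size = shape_size_fn(f32[<=4,3]) - offset = 8
// The resolved shapes become f32[2,3]; the payload itself stays padded out
// to the bound.

Note that `BlockHostUntilDone` plus the metadata transfer make this a synchronous read; callers pay the device round-trip only when an output is actually dynamic.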

tensorflow/compiler/xla/service/transfer_manager.h

@@ -184,6 +184,15 @@ class TransferManager {
                               const se::DeviceMemoryBase& source,
                               const TransferMetadata* transfer_metadata = nullptr);
 
+  // Reads from a device buffer and updates the dynamic dimension sizes of
+  // `host_shape` and `device_shape` in place. The function takes in bounded
+  // dynamic shapes and returns static shapes whose dimensions hold the actual
+  // sizes read from the device. The shape of the buffer also has to be
+  // compatible with the host shape and device shape.
+  virtual Status ReadDynamicShapes(se::Stream* stream,
+                                   ShapedBuffer* device_buffer,
+                                   Shape* host_shape, Shape* device_shape);
+
   // Transfers the given literal into the Infeed interface of the device,
   // using the given executor.
   virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
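
Because the method is virtual, backends are free to override the mechanism. A minimal caller sketch, mirroring the call sites in `PopulateOutputs` and `CreateOutputTuple` elsewhere in this commit (the wrapper function itself is hypothetical):

Status ResolveDynamicResult(se::Stream* stream, ShapedBuffer* result,
                            Shape* host_shape, Shape* device_shape) {
  *host_shape = result->on_host_shape();
  *device_shape = result->on_device_shape();
  // Static results need no device round-trip.
  if (!device_shape->is_dynamic()) return Status::OK();
  TF_ASSIGN_OR_RETURN(
      auto transfer_manager,
      TransferManager::GetForPlatform(stream->parent()->platform()));
  // Blocks until the stream is done, then rewrites both shapes in place to
  // static shapes carrying the true dimension sizes.
  return transfer_manager->ReadDynamicShapes(stream, result, host_shape,
                                             device_shape);
}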

tensorflow/compiler/xrt/kernels/xrt_execute_op.cc

@@ -264,86 +264,28 @@ Status UpdateDynamicInputs(
   return Status::OK();
 }
 
-xla::StatusOr<xla::Literal> ReadMetadataLiteral(
-    se::Stream* stream, se::DeviceMemoryBase buffer,
-    const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
-  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
-                                         stream->parent()->platform()));
-  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
-  xla::Shape buffer_shape_static =
-      xla::ShapeUtil::MakeStaticShape(buffer_shape);
-  const int64 offset = shape_size_fn(buffer_shape_static);
-  int64 metadata_size = shape_size_fn(buffer_shape) - offset;
-  TF_RET_CHECK(metadata_size != 0);
-  auto buffer_8 = se::DeviceMemory<uint8>(buffer);
-  auto metadata_buffer =
-      stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
-  return transfer_manager->TransferArrayFromDevice(
-      stream,
-      xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
-      metadata_buffer);
-}
-
-// For each subshape in the result buffer that's dynamic, read the dynamic
-// dimension sizes from the metadata, and update output shapes. The result shape
-// is a static and concrete shape.
-xla::Status UpdateDynamicOutputs(se::Stream* stream,
-                                 const xla::ShapedBuffer& shaped_buffer,
-                                 xla::Shape* output_host_shape,
-                                 xla::Shape* output_device_shape) {
-  DCHECK(output_device_shape->is_dynamic());
-  TF_ASSIGN_OR_RETURN(
-      auto transfer_manager,
-      xla::TransferManager::GetForPlatform(stream->parent()->platform()));
-  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
-  TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
-      [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
-        const xla::Shape& buffer_shape =
-            xla::ShapeUtil::GetSubshape(*output_device_shape, index);
-        if (buffer_shape.IsTuple()) {
-          return Status::OK();
-        }
-        xla::Shape& host_shape =
-            *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
-        xla::Shape& device_shape =
-            *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
-        if (device_shape.is_static()) {
-          return Status::OK();
-        }
-        TF_ASSIGN_OR_RETURN(auto metadata,
-                            ReadMetadataLiteral(stream, buffer, buffer_shape,
-                                                transfer_manager));
-        // Update shape size from metadata.
-        for (int64 i = 0; i < metadata.element_count(); ++i) {
-          host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
-          device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
-        }
-        return Status::OK();
-      }));
-  output_host_shape->clear_dynamic_dimensions();
-  output_device_shape->clear_dynamic_dimensions();
-  return Status::OK();
-}
-
 xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
     se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
     int device_ordinal) {
   XRTTupleAllocation* output_tuple;
-  const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
-  if (shaped_buffer.on_device_shape().is_dynamic()) {
+  xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
+  if (shaped_buffer->on_device_shape().is_dynamic()) {
     // Update dynamic shapes from output buffer, and create an XRT tensor with
     // dimension sizes read from metadata.
-    xla::Shape output_host_shape = shaped_buffer.on_host_shape();
-    xla::Shape output_device_shape = shaped_buffer.on_device_shape();
-    TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
+    xla::Shape output_host_shape = shaped_buffer->on_host_shape();
+    xla::Shape output_device_shape = shaped_buffer->on_device_shape();
+    TF_ASSIGN_OR_RETURN(
+        auto transfer_manager,
+        xla::TransferManager::GetForPlatform(stream->parent()->platform()));
+    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
         stream, shaped_buffer, &output_host_shape, &output_device_shape));
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
-        shaped_buffer, output_host_shape, output_device_shape, backend,
+        *shaped_buffer, output_host_shape, output_device_shape, backend,
         device_ordinal, &output_tuple));
   } else {
     // Fast-path: Don't copy shapes of output buffer.
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
-        shaped_buffer, backend, device_ordinal, &output_tuple));
+        *shaped_buffer, backend, device_ordinal, &output_tuple));
   }
   // After the output tuple is created, we can release the output result
   // buffers, to make sure they won't be cleared by the ExecutionOutput's
   // destructor.
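
With the shared routine in place, the XRT-local `ReadMetadataLiteral` and `UpdateDynamicOutputs` helpers above become dead code and are deleted. `CreateOutputTuple` switches from `run_result.Result()` to `run_result.MutableResult()` because `ReadDynamicShapes` takes a mutable `ShapedBuffer*`, and the remaining call sites dereference the pointer (`*shaped_buffer`) where a const reference is still expected.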

tensorflow/python/distribute/tpu_strategy_test.py

@@ -123,6 +123,15 @@ class TPUTest(test.TestCase):
       result = bar() + 1
       self.assertAllEqual(result, 2)
 
+  def test_on_demand_op_with_dynamic_output(self):
+    with ops.device("/device:TPU:0"):
+      where_output = array_ops.where([True, False, True])
+    self.assertAllEqual(where_output, [[0], [2]])
+
+    with ops.device("/device:TPU:0"):
+      repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
+    self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
+
 
 @parameterized.named_parameters([("PackedVar", True), ("", False)])
 class TPUStrategyTest(test.TestCase, parameterized.TestCase):