Support dynamic outputs for XLA on demand ops.
PiperOrigin-RevId: 317902879 Change-Id: I6b6dfa54855d5996ac15d4b5c48a5db5dc230025
@@ -476,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
     stream->ThenRecordEvent(definition_event.get());
   }
 
+  std::vector<TensorShape> output_tensor_shapes;
+  output_tensor_shapes.reserve(ctx->num_outputs());
+  if (output.on_host_shape().is_dynamic()) {
+    TF_ASSIGN_OR_RETURN(
+        auto transfer_manager,
+        xla::TransferManager::GetForPlatform(stream->parent()->platform()));
+
+    xla::Shape output_host_shape = output.on_host_shape();
+    xla::Shape output_device_shape = output.on_device_shape();
+    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
+        stream, &output, &output_host_shape, &output_device_shape));
+
+    output.set_shapes(output_host_shape, output_device_shape);
+    for (int i = 0; i < ctx->num_outputs(); ++i) {
+      const xla::Shape& subshape =
+          xla::ShapeUtil::GetSubshape(output_host_shape, {i});
+      TensorShape shape;
+      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
+      output_tensor_shapes.push_back(shape);
+    }
+  } else {
+    for (int i = 0; i < ctx->num_outputs(); ++i) {
+      output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
+    }
+  }
+
   // Copy XLA results to the OpOutputList.
   int output_num = 0;
   for (int i = 0; i < ctx->num_outputs(); ++i) {
-    const TensorShape& shape = compilation_result->outputs[i].shape;
+    const TensorShape& shape = output_tensor_shapes[i];
     const DataType& type = compilation_result->outputs[i].type;
     VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
             << DataTypeString(type);
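Note on the hunk above: for dynamic results the compile-time shapes in compilation_result are only upper bounds, so PopulateOutputs now reads the true shapes back from the device before sizing the TF outputs. A minimal standalone sketch (ours, not commit code; assumes the XLA headers at this revision) of the bounded/padded shape relationship the `is_dynamic()` branch keys off:

#include <iostream>
#include "tensorflow/compiler/xla/shape_util.h"

int main() {
  // f32[<=10]: dimension 0 is dynamic, with an upper bound of 10.
  xla::Shape dynamic = xla::ShapeUtil::MakeShape(xla::F32, {10});
  dynamic.set_dynamic_dimension(0, true);
  // The static companion shape, f32[10], is what the output buffer is
  // padded to; is_dynamic() on the original selects the slow path above.
  xla::Shape padded = xla::ShapeUtil::MakeStaticShape(dynamic);
  std::cout << xla::ShapeUtil::HumanString(dynamic) << "\n";  // f32[<=10]
  std::cout << xla::ShapeUtil::HumanString(padded) << "\n";   // f32[10]
  return 0;
}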
@@ -1202,6 +1202,9 @@ cc_library(
     srcs = ["transfer_manager.cc"],
     hdrs = ["transfer_manager.h"],
     deps = [
+        ":compiler",
+        ":executable",
+        ":maybe_owning_device_memory",
         ":shaped_buffer",
         "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
@@ -1210,8 +1213,6 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
-        "//tensorflow/compiler/xla/service:executable",
-        "//tensorflow/compiler/xla/service:maybe_owning_device_memory",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/stream_executor:device_memory",
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -33,6 +34,7 @@ limitations under the License.
 using absl::StrCat;
 
 namespace xla {
+
 /* static */ tensorflow::mutex
     TransferManager::platform_transfer_manager_mutex_(
         tensorflow::LINKER_INITIALIZED);
@@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice(
       std::move(done), transfer_metadata);
 }
 
+Status TransferManager::ReadDynamicShapes(se::Stream* stream,
+                                          ShapedBuffer* device_buffer,
+                                          Shape* host_shape,
+                                          Shape* device_shape) {
+  DCHECK(device_shape->is_dynamic());
+  Shape original_device_shape = *device_shape;
+  Shape original_host_shape = *host_shape;
+  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+
+  TF_ASSIGN_OR_RETURN(auto compiler,
+                      Compiler::GetForPlatform(stream->parent()->platform()));
+  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
+      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
+        const Shape& buffer_shape =
+            ShapeUtil::GetSubshape(*device_shape, index);
+        if (buffer_shape.IsTuple()) {
+          return Status::OK();
+        }
+        Shape& host_sub_shape =
+            *ShapeUtil::GetMutableSubshape(host_shape, index);
+        Shape& device_sub_shape =
+            *ShapeUtil::GetMutableSubshape(device_shape, index);
+        if (device_sub_shape.is_static()) {
+          return Status::OK();
+        }
+
+        // Read the dynamic shape metadata from the device stream.
+        auto shape_size_fn = compiler->ShapeSizeBytesFunction();
+        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
+        const int64 offset = shape_size_fn(buffer_shape_static);
+        int64 metadata_size = shape_size_fn(buffer_shape) - offset;
+        if (metadata_size == 0) {
+          return InvalidArgument("Dynamic shape metadata size should not be 0");
+        }
+        auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
+        auto metadata_buffer =
+            stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
+        TF_ASSIGN_OR_RETURN(
+            auto metadata,
+            TransferArrayFromDevice(
+                stream,
+                ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
+                metadata_buffer));
+
+        // Update shape size from metadata.
+        for (int64 i = 0; i < metadata.element_count(); ++i) {
+          host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
+          device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
+        }
+        return Status::OK();
+      }));
+  host_shape->clear_dynamic_dimensions();
+  device_shape->clear_dynamic_dimensions();
+
+  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
+                                                   original_device_shape));
+  TF_RET_CHECK(
+      ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
+  return Status::OK();
+}
+
 /* static */ void TransferManager::RegisterTransferManager(
     se::Platform::Id platform_id,
     TransferManagerCreationFunction creation_function) {
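How the offset arithmetic in ReadDynamicShapes works: the backend lays out a bounded dynamic buffer as the payload padded to its bounds, followed by one int32 per dimension holding the true size. `shape_size_fn` applied to the static companion shape is therefore the metadata offset, and the difference against the dynamic shape's size is the metadata length. A self-contained worked example (ours, not commit code; SizeBytes stands in for Compiler::ShapeSizeBytesFunction(), which is backend-specific):

#include <cstdint>
#include <iostream>

// Stand-in for the backend's shape-size function: padded payload plus, for
// dynamic shapes, one int32 of size metadata per dimension.
int64_t SizeBytes(int64_t bound_elems, int64_t elem_bytes, int64_t rank,
                  bool dynamic) {
  return bound_elems * elem_bytes + (dynamic ? rank * 4 : 0);
}

int main() {
  // s32[<=3]: payload padded to 3 elements, one dynamic dimension.
  const int64_t offset = SizeBytes(3, 4, 1, /*dynamic=*/false);  // 12 bytes
  const int64_t metadata_size =
      SizeBytes(3, 4, 1, /*dynamic=*/true) - offset;             // 4 bytes
  // ReadDynamicShapes slices [offset, offset + metadata_size) out of the
  // device buffer and transfers it back as an S32 array of dimension sizes.
  std::cout << "offset=" << offset << " metadata=" << metadata_size << "\n";
  return 0;
}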
@@ -184,6 +184,15 @@ class TransferManager {
       const se::DeviceMemoryBase& source,
       const TransferMetadata* transfer_metadata = nullptr);
 
+  // Read from a device buffer and update the dynamic dimension sizes of
+  // `host_shape` and `device_shape`. The function takes in bounded dynamic
+  // shapes, and returns static shapes with dynamic shapes updated.
+  // The shape of the buffer also has to be compatible with the host shape
+  // and device shape.
+  virtual Status ReadDynamicShapes(se::Stream* stream,
+                                   ShapedBuffer* device_buffer,
+                                   Shape* host_shape, Shape* device_shape);
+
   // Transfers the given literal into the Infeed interface of the device,
   // using the given executor.
   virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
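A hypothetical caller (the helper name is ours; it mirrors the jit and XRT call sites elsewhere in this commit) showing the intended use of the new virtual method:

#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/status_macros.h"

xla::Status ResolveDynamicResultShapes(se::Stream* stream,
                                       xla::ShapedBuffer* result,
                                       xla::Shape* host_shape,
                                       xla::Shape* device_shape) {
  *host_shape = result->on_host_shape();
  *device_shape = result->on_device_shape();
  if (!device_shape->is_dynamic()) {
    return xla::Status::OK();  // Fast path: shapes are already concrete.
  }
  TF_ASSIGN_OR_RETURN(
      auto transfer_manager,
      xla::TransferManager::GetForPlatform(stream->parent()->platform()));
  // Blocks on the stream, reads each sub-buffer's size metadata, and
  // rewrites both shapes into static shapes with the true dimension sizes.
  return transfer_manager->ReadDynamicShapes(stream, result, host_shape,
                                             device_shape);
}

Per the contract in the comment above, callers pass in the bounded dynamic shapes and get back fully static ones.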
@@ -264,86 +264,28 @@ Status UpdateDynamicInputs(
   return Status::OK();
 }
 
-xla::StatusOr<xla::Literal> ReadMetadataLiteral(
-    se::Stream* stream, se::DeviceMemoryBase buffer,
-    const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
-  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
-                                         stream->parent()->platform()));
-  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
-  xla::Shape buffer_shape_static =
-      xla::ShapeUtil::MakeStaticShape(buffer_shape);
-  const int64 offset = shape_size_fn(buffer_shape_static);
-  int64 metadata_size = shape_size_fn(buffer_shape) - offset;
-  TF_RET_CHECK(metadata_size != 0);
-  auto buffer_8 = se::DeviceMemory<uint8>(buffer);
-  auto metadata_buffer =
-      stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
-  return transfer_manager->TransferArrayFromDevice(
-      stream,
-      xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
-      metadata_buffer);
-}
-
-// For each subshape in the result buffer that's dynamic, read the dynamic
-// dimension sizes from the metadata, and update output shapes. The result shape
-// is a static and concrete shape.
-xla::Status UpdateDynamicOutputs(se::Stream* stream,
-                                 const xla::ShapedBuffer& shaped_buffer,
-                                 xla::Shape* output_host_shape,
-                                 xla::Shape* output_device_shape) {
-  DCHECK(output_device_shape->is_dynamic());
-  TF_ASSIGN_OR_RETURN(
-      auto transfer_manager,
-      xla::TransferManager::GetForPlatform(stream->parent()->platform()));
-  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
-  TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
-      [&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
-        const xla::Shape& buffer_shape =
-            xla::ShapeUtil::GetSubshape(*output_device_shape, index);
-        if (buffer_shape.IsTuple()) {
-          return Status::OK();
-        }
-        xla::Shape& host_shape =
-            *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
-        xla::Shape& device_shape =
-            *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
-        if (device_shape.is_static()) {
-          return Status::OK();
-        }
-        TF_ASSIGN_OR_RETURN(auto metadata,
-                            ReadMetadataLiteral(stream, buffer, buffer_shape,
-                                                transfer_manager));
-        // Update shape size from metadata.
-        for (int64 i = 0; i < metadata.element_count(); ++i) {
-          host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
-          device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
-        }
-        return Status::OK();
-      }));
-  output_host_shape->clear_dynamic_dimensions();
-  output_device_shape->clear_dynamic_dimensions();
-  return Status::OK();
-}
-
 xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
     se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
     int device_ordinal) {
   XRTTupleAllocation* output_tuple;
-  const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
-  if (shaped_buffer.on_device_shape().is_dynamic()) {
+  xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
+  if (shaped_buffer->on_device_shape().is_dynamic()) {
     // Update dynamic shapes from output buffer, and create a XRT tensor with
     // dimension sizes read from metadata.
-    xla::Shape output_host_shape = shaped_buffer.on_host_shape();
-    xla::Shape output_device_shape = shaped_buffer.on_device_shape();
-    TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
+    xla::Shape output_host_shape = shaped_buffer->on_host_shape();
+    xla::Shape output_device_shape = shaped_buffer->on_device_shape();
+    TF_ASSIGN_OR_RETURN(
+        auto transfer_manager,
+        xla::TransferManager::GetForPlatform(stream->parent()->platform()));
+    TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
         stream, shaped_buffer, &output_host_shape, &output_device_shape));
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
-        shaped_buffer, output_host_shape, output_device_shape, backend,
+        *shaped_buffer, output_host_shape, output_device_shape, backend,
         device_ordinal, &output_tuple));
   } else {
     // Fast-path: Don't copy shapes of output buffer.
     TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
-        shaped_buffer, backend, device_ordinal, &output_tuple));
+        *shaped_buffer, backend, device_ordinal, &output_tuple));
   }
   // After the output tuple is created, we can release the output result
   // buffers, to make sure they won't be cleared by its destructor.
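The XRT hunk above is a deduplication: the local ReadMetadataLiteral/UpdateDynamicOutputs helpers are deleted in favor of the shared TransferManager::ReadDynamicShapes. One side effect: the shared method takes a non-const ShapedBuffer* (it walks the buffer tree mutably to slice out each metadata region), so CreateOutputTuple switches from run_result.Result() to run_result.MutableResult(). A toy model (ours, not TF code) of that const-ness constraint:

#include <iostream>

struct ShapedBufferToy { int dims = 0; };

struct ExecutionOutputToy {
  const ShapedBufferToy& Result() const { return result_; }  // old accessor
  ShapedBufferToy* MutableResult() { return &result_; }      // new accessor
  ShapedBufferToy result_;
};

// Stand-in for ReadDynamicShapes: requires mutable access to the buffer.
void ReadDynamicShapesToy(ShapedBufferToy* buffer) { buffer->dims = 2; }

int main() {
  ExecutionOutputToy run_result;
  // ReadDynamicShapesToy(&run_result.Result());   // error: const mismatch
  ReadDynamicShapesToy(run_result.MutableResult());  // compiles and works
  std::cout << run_result.Result().dims << "\n";  // prints 2
  return 0;
}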
@@ -123,6 +123,15 @@ class TPUTest(test.TestCase):
       result = bar() + 1
     self.assertAllEqual(result, 2)
 
+  def test_on_demand_op_with_dynamic_output(self):
+    with ops.device("/device:TPU:0"):
+      where_output = array_ops.where([True, False, True])
+    self.assertAllEqual(where_output, [[0], [2]])
+
+    with ops.device("/device:TPU:0"):
+      repeat_output = array_ops.repeat(math_ops.range(2), [1, 4])
+    self.assertAllEqual(repeat_output, [0, 1, 1, 1, 1])
+
 
 @parameterized.named_parameters([("PackedVar", True), ("", False)])
 class TPUStrategyTest(test.TestCase, parameterized.TestCase):