[XLA] Basic (R1) support for CPU bounded dynamic shapes.

- Add dynamic tensor metadata read/write in XRT.
- Implement two custom calls: PadToStatic and SliceToDynamic -- R1 only.
- Add helper functions to ShapeUtil for sanity checks.
- Tests -- R1 only.
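
As a rough illustration of what "bounded dynamic" means here (an editorial sketch, not part of the change): a bounded dynamic R1 shape carries a compile-time bound plus a runtime size, and its device buffer is the statically padded payload followed by one S32 of metadata per dimension.

    // Minimal sketch using XLA's ShapeUtil API, with f32[<=4] as a running
    // example; illustrative only.
    #include "tensorflow/compiler/xla/shape_util.h"

    xla::Shape MakeBoundedDynamicR1() {
      // f32[<=4]: a bound of 4 elements, actual size known only at runtime.
      xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
      shape.set_dynamic_dimension(0, true);
      // Device layout: 16 bytes of padded f32 payload, then 4 bytes of
      // metadata holding the runtime size as one S32.
      return shape;
    }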

PiperOrigin-RevId: 311398639
Change-Id: I7129fd13f4e0a2b7a14efb52eb814f753a15e05e
Yunxing Dai 2020-05-13 13:57:37 -07:00 committed by TensorFlower Gardener
parent 8d1e8b350c
commit 5fee245d9f
12 changed files with 726 additions and 23 deletions


@ -331,6 +331,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",


@ -146,6 +146,7 @@ cc_library(
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:dynamic_padder",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",


@ -72,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@ -239,7 +240,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
HloPassPipeline pipeline("HLO passes through layout assignment");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
@ -273,6 +273,13 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<ConvolutionGroupConverter>(
cost_model,
/*convert_batch_groups_only=*/false);
pipeline.AddPass<ScatterExpander>();
pipeline.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
pipeline.AddPass<DynamicPadder>();
pipeline.AddPass<HloGetDimensionSizeRewriter>();
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
@ -281,12 +288,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
/*allow_mixed_precision=*/false);
pass.AddPass<TreeReductionRewriter>();
pass.AddPass<ScatterExpander>();
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
pipeline.AddPass<HloGetDimensionSizeRewriter>();
AlgebraicSimplifierOptions options;
options.set_enable_dot_strength_reduction(false);
pass.AddPass<AlgebraicSimplifier>(options);


@ -363,7 +363,12 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
if (shape.IsOpaque()) {
return sizeof(void*);
}
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
if (shape.is_static() || shape.IsTuple()) {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
// Each dynamic dimension size is represented as an S32.
int64 metadata_size = sizeof(int32) * shape.dimensions_size();
return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size;
}
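A hedged worked example of the size computation above (not part of the change): for the bounded dynamic shape f32[<=4],

    padded payload : ShapeUtil::ByteSizeOf(f32[4], sizeof(void*)) = 4 * 4 = 16 bytes
    metadata       : sizeof(int32) * dimensions_size              = 4 * 1 =  4 bytes
    total                                                                 = 20 bytes

whereas a static f32[4] takes the earlier branch and returns 16 bytes.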
const InstructionValueSet& CpuExecutable::GetRootValueSet() const {


@ -2357,7 +2357,95 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
return Status::OK();
}
Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
// TODO(jackcao): Generalize this to generic llvm emitter.
TF_RET_CHECK(hlo->shape().rank() == 1);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
for (int64 i = 1; i < hlo->operand_count(); ++i) {
const int64 dim_index = i - 1;
llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
llvm::LoadInst* dim_size = b_.CreateLoad(source_buffer, "dim_size");
llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
llvm::Value* raw_buffer =
b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
int32 raw_data_size =
ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
b_.CreateStore(dim_size,
b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
}
return EmitTargetElementLoop(hlo,
[=](const llvm_ir::IrArray::Index& dest_index) {
// TODO(jackcao): Properly linearize dest_index
// and delinearize to source index.
return GetIrArrayFor(hlo->operand(0))
.EmitReadArrayElement(dest_index, &b_);
});
}
Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
// TODO(jackcao): Generalize this to generic llvm emitter.
TF_RET_CHECK(hlo->operand(0)->shape().rank() == 1);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
assignment_.GetUniqueSlice(hlo, {0}));
const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
llvm_ir::IrArray data_array(data_address, data_shape);
TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(
[=](const llvm_ir::IrArray::Index& dest_index) {
// TODO(jackcao): Properly linearize dest_index and
// delinearize to source index.
return GetIrArrayFor(hlo->operand(0))
.EmitReadArrayElement(dest_index, &b_);
},
llvm_ir::IrArray(data_address, data_shape), &b_)
.EmitLoop(IrName(hlo)));
std::vector<llvm::Value*> tuple_operand_ptrs;
tuple_operand_ptrs.push_back(data_array.GetBasePointer());
// PadToStatic has a dynamic tensor as input and a variadic number of outputs:
// (static_tensor, dynamic_dim_0, dynamic_dim_1, ... )
// Dynamic dimension sizes start at output index 1.
for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
// Read from the metadata section of the dynamic input (operand 0).
const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
assignment_.GetUniqueSlice(hlo, {i}));
llvm::Value* dest_dim_size_address =
EmitBufferPointer(dim_size_slice, data_shape);
const int64 dim_index = i - 1;
llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
llvm::Value* raw_buffer =
b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
int32 raw_data_size = ShapeUtil::ByteSizeOf(
ShapeUtil::MakeStaticShape(hlo->operand(0)->shape()));
llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
llvm::Value* dim_size = b_.CreateLoad(
b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
b_.CreateStore(dim_size, b_.CreateBitCast(dest_dim_size_address,
b_.getInt32Ty()->getPointerTo()));
tuple_operand_ptrs.push_back(dest_dim_size_address);
}
// Emit static tensor and dynamic sizes as one tuple.
llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
return Status::OK();
}
Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (custom_call->custom_call_target() == "PadToStatic") {
return HandlePadToStatic(custom_call);
}
if (custom_call->custom_call_target() == "SliceToDynamic") {
return HandleSliceToDynamic(custom_call);
}
absl::Span<HloInstruction* const> operands(custom_call->operands());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =


@ -183,6 +183,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
}
private:
Status HandleSliceToDynamic(HloInstruction* hlo);
Status HandlePadToStatic(HloInstruction* hlo);
Status HandleAllReduceSingleReplica(HloInstruction* crs);
Status HandleAllReduceMultipleReplica(HloInstruction* crs);
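To make the calling convention of the two custom calls handled above concrete, here is a hedged builder-side sketch (illustrative only; these calls are normally inserted by the DynamicPadder pass rather than written by hand, and the helper name is made up): operand 0 of SliceToDynamic is the statically padded data, the remaining operands are the dynamic dimension sizes, and the declared result shape is the bounded dynamic shape.

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Hypothetical sketch of emitting a SliceToDynamic custom call for R1.
    xla::XlaOp BuildSliceToDynamic(xla::XlaBuilder* b, xla::XlaOp padded_data,
                                   xla::XlaOp dynamic_size) {
      // Result shape f32[<=4]: same bound as the padded data, dim 0 dynamic.
      xla::Shape result_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
      result_shape.set_dynamic_dimension(0, true);
      return xla::CustomCall(b, "SliceToDynamic", {padded_data, dynamic_size},
                             result_shape);
    }

PadToStatic goes the other way: one dynamic operand in, a tuple of (static_tensor, dynamic_dim_0, ...) out.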


@ -22,6 +22,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -93,6 +94,18 @@ class ShapedBuffer {
buffers_.replace_shape_ptr(&on_device_shape_);
}
// Reset the shape of this shaped buffer and underlying buffer structure.
//
// Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
<< "Structures are not the same. new: " << on_device_shape
<< ", old: " << on_device_shape_;
on_host_shape_ = on_host_shape;
on_device_shape_ = on_device_shape;
buffers_.replace_shape_ptr(&on_device_shape_);
}
// Returns the underlying ShapeTree containing all the device addresses in the
// ShapedBuffer.
const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }


@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
@ -150,6 +151,19 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return equal;
}
/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
const Shape& rhs) {
bool equal = true;
ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
equal &= IndexIsValid(rhs, index);
});
ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
equal &= IndexIsValid(lhs, index);
});
return equal;
}
/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
int64 accum = 0;
for (int64 dimension : shape.dimensions()) {
@ -261,6 +275,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return ValidateShape(*shape);
}
/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
Shape result = original;
result.clear_dynamic_dimensions();
return result;
}
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
Shape result;
result.set_element_type(TUPLE);
@ -626,8 +646,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
if (shape.element_type() == TUPLE) {
return ByteSizeOfTupleIndexTable(shape, pointer_size);
} else if (shape.IsArray()) {
int64 byte_size = ByteSizeOfElements(shape);
return byte_size;
return ByteSizeOfElements(shape);
} else if (shape.element_type() == TOKEN) {
return 0;
} else if (shape.element_type() == OPAQUE_TYPE) {
@ -1441,6 +1460,19 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
return shape;
}
/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
if (dynamic_shape.rank() != bounded_shape.rank()) {
return false;
}
for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
return false;
}
}
return true;
}
/* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) {
CHECK(shape.IsArray());


@ -298,6 +298,16 @@ class ShapeUtil {
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
// Two shapes have the same structure if all subshape indices of lhs are present
// in rhs and vice versa.
// A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
// (S32, (F32[3], S32[2])) as their structures are both (,(,)).
//
// In contrast, (F32, (F32, F32)) is structurally different from
// ((F32, F32), F32) as the former has structure (,(,)) while the latter has
// ((,),).
static bool EqualStructure(const Shape& lhs, const Shape& rhs);
// Returns the number of dimensions for which the dimension is not (trivially)
// 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
// fluff. Note that zero dimensions are included in the true rank, e.g.,
@ -339,6 +349,9 @@ class ShapeUtil {
// element type changed to type.
static Shape ChangeElementType(const Shape& original, PrimitiveType type);
// Returns a shape with the same dimensions but with all dimensions set to static.
static Shape MakeStaticShape(const Shape& original);
// Creates a tuple shape from a slice of element shapes within the tuple.
static Shape MakeTupleShape(absl::Span<const Shape> shapes);
@ -643,12 +656,16 @@ class ShapeUtil {
static Shape FilterDimensions(const std::function<bool(int64)>& p,
Shape shape);
// Iterates through all the shape indexes, in minor to major order, starting
// from the base indexes, incrementing by the incr steps, up to count
// (index[i] < base[i] + count[i]), and calls the visitor_function with the
// current index.
// The visitor_function visitor function should return true if it wants to
// continue, or false otherwise.
// Returns true if every dimension of `dynamic_shape` is less than or equal to
// the corresponding dimension of `bounded_shape`.
static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
const xla::Shape& bounded_shape);
// Iterates through all the shape indexes, in minor to major order,
// starting from the base indexes, incrementing by the incr steps, up to
// count (index[i] < base[i] + count[i]), and calls the visitor_function
// with the current index. The visitor_function visitor function should
// return true if it wants to continue, or false otherwise.
//
// visitor_function must be a callable of type
// StatusOr<bool>(absl::Span<int64>) or compatible.
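A hedged usage sketch of the new ShapeUtil helpers declared above (illustrative only, not from the change):

    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/core/platform/logging.h"

    void ShapeHelperSketch() {
      // Compile-time bounded shape f32[<=4] and a concrete runtime shape f32[2].
      xla::Shape bounded = xla::ShapeUtil::MakeShape(xla::F32, {4});
      bounded.set_dynamic_dimension(0, true);
      xla::Shape runtime = xla::ShapeUtil::MakeShape(xla::F32, {2});
      // Every runtime dimension fits within its bound.
      CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(runtime, bounded));
      // Dropping dynamism yields the padded static shape f32[4].
      xla::Shape padded = xla::ShapeUtil::MakeStaticShape(bounded);
      CHECK(padded.is_static());
      // Only tuple nesting is compared, not element types or dimension sizes.
      CHECK(xla::ShapeUtil::EqualStructure(bounded, padded));
    }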


@ -49,6 +49,7 @@ cc_library(
deps = [
":xrt_state_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",


@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@ -38,7 +39,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
@ -146,6 +151,231 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
return std::move(input_buffers);
}
// Given a shape, returns an array representing its shape metadata. The shape
// metadata contains the dimension sizes stored as contiguous S32 values.
std::vector<int32> PrepareMetadata(const xla::Shape& shape) {
DCHECK(shape.is_static());
DCHECK(shape.IsArray());
// Each dimension size is stored as an S32.
std::vector<int32> result(shape.dimensions_size());
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
result[i] = shape.dimensions(i);
}
return result;
}
// Given a buffer with a dynamic shape, write the shape metadata at the correct
// offset within that buffer.
//
// +-----------+
// |Payload |
// +-----------+
// | Padding |
// +-----------+
// |dim_size_0 | (each dim_size is an S32)
// +-----------+
// |dim_size_1 |
// +-----------+
// ..........
// +-----------+
//
// Size of payload = ByteSizeOf(runtime_shape)
// Size of payload + padding = ByteSizeOf(compile_time_shape_static)
// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape)
Status UpdateMetadata(se::Stream* stream, se::DeviceMemory<uint8>* buffer,
const xla::Shape& compile_time_shape,
const xla::Shape& runtime_shape) {
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform()));
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
xla::Shape compile_time_shape_static =
xla::ShapeUtil::MakeStaticShape(compile_time_shape);
uint64 offset = shape_size_fn(compile_time_shape_static);
uint64 metadata_size = shape_size_fn(compile_time_shape) - offset;
auto metadata_buffer =
stream->parent()->GetSubBuffer(buffer, offset, metadata_size);
auto metadata_literal = std::make_shared<xla::Literal>(
xla::LiteralUtil::CreateR1<int32>(PrepareMetadata(runtime_shape)));
TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync(
stream, *metadata_literal, metadata_buffer));
// Retain the literal until the end of the transfer.
stream->ThenDoHostCallback([metadata_literal]() { return Status::OK(); });
return Status::OK();
}
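A hedged worked example of the size relations used above: with compile_time_shape = f32[<=4] and runtime_shape = f32[2],

    size of payload                      = ByteSizeOf(f32[2])     =  8 bytes
    size of payload + padding            = ByteSizeOf(f32[4])     = 16 bytes
    size of payload + padding + metadata = 16 + 1 * sizeof(int32) = 20 bytes

so offset is 16, metadata_size is 4, and the literal {2} is written as a single S32 at byte 16 of the buffer.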
// Given a static input buffer, convert it to dynamic form by expanding it to
// the bounded size and attaching metadata filled with the dimension sizes.
//
// From:
// +--------+
// |Payload |
// +--------+
//
// To:
//
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// As we can't expand the size of an existing memory allocation, a reallocation
// is required. A list of new allocations is returned by this function. The
// caller is responsible for maintaining those allocations.
xla::StatusOr<std::vector<se::OwningDeviceMemory>> UpdateDynamicInputs(
se::Stream* stream, se::DeviceMemoryAllocator* allocator,
std::vector<xla::ShapedBuffer*> runtime_inputs,
const std::vector<xla::ShapeLayout>& compile_time_shapes) {
std::vector<se::OwningDeviceMemory> new_allocations;
TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size());
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
for (int64 i = 0; i < compile_time_shapes.size(); i++) {
const xla::Shape& compile_time_shape = compile_time_shapes[i].shape();
if (compile_time_shape.is_static()) {
continue;
}
auto* runtime_input = runtime_inputs[i];
bool element_modified = false;
TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
compile_time_shape,
[&](const xla::Shape& compile_time_shape,
const xla::ShapeIndex& index) -> Status {
if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
return Status::OK();
}
const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape(
runtime_input->on_device_shape(), index);
TF_RET_CHECK(!runtime_shape.IsTuple());
TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(
runtime_shape, compile_time_shape));
se::DeviceMemoryBase* static_input =
runtime_input->buffers().mutable_element(index);
TF_ASSIGN_OR_RETURN(
auto dynamic_input,
allocator->Allocate(stream->parent()->device_ordinal(),
shape_size_fn(compile_time_shape)));
new_allocations.emplace_back(std::move(dynamic_input));
se::DeviceMemory<uint8>* dynamic_input_base =
new_allocations.back().ptr();
// Send the original data to the new location.
stream->ThenMemcpyD2D(dynamic_input_base, *static_input,
static_input->size());
TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base,
compile_time_shape, runtime_shape));
// Modify the memory location in the input shape tree to point to the
// new input.
runtime_input->set_buffer(*dynamic_input_base, index);
element_modified = true;
return Status::OK();
}));
if (element_modified) {
runtime_input->set_shapes(compile_time_shape, compile_time_shape);
// The input location has been modified; the tuple index table needs to be
// updated to point to the correct addresses.
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(
transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input));
}
}
return std::move(new_allocations);
}
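A hedged trace of the conversion above for a single input: with a compile-time parameter shape of f32[<=4] and a runtime allocation holding f32[3],

    compile-time shape : f32[<=4] -> 16-byte padded payload + 4-byte metadata = 20 bytes
    runtime allocation : f32[3]   -> 12-byte payload
    action             : allocate 20 bytes, D2D-copy the 12-byte payload,
                         write metadata {3}, repoint the buffer in the shape tree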
xla::StatusOr<xla::Literal> ReadMetadataLiteral(
se::Stream* stream, se::DeviceMemoryBase* buffer,
const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
xla::Shape buffer_shape_static =
xla::ShapeUtil::MakeStaticShape(buffer_shape);
const int64 offset = shape_size_fn(buffer_shape_static);
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
TF_RET_CHECK(metadata_size != 0);
auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
auto metadata_buffer =
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
return transfer_manager->TransferArrayFromDevice(
stream,
xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
metadata_buffer);
}
// For each dynamic subshape in the result buffer, read the dynamic dimension
// sizes from the metadata and update the output shapes. The resulting shapes
// are static and concrete.
xla::Status UpdateDynamicOutputs(se::Stream* stream,
xla::ShapedBuffer* shaped_buffer,
xla::Shape* output_host_shape,
xla::Shape* output_device_shape) {
DCHECK(output_device_shape->is_dynamic());
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
const xla::Shape& buffer_shape =
xla::ShapeUtil::GetSubshape(*output_device_shape, index);
if (buffer_shape.IsTuple()) {
return Status::OK();
}
xla::Shape& host_shape =
*xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
xla::Shape& device_shape =
*xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
if (device_shape.is_static()) {
return Status::OK();
}
TF_ASSIGN_OR_RETURN(auto metadata,
ReadMetadataLiteral(stream, buffer, buffer_shape,
transfer_manager));
// Update shape size from metadata.
for (int64 i = 0; i < metadata.element_count(); ++i) {
host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
}
return Status::OK();
}));
output_host_shape->clear_dynamic_dimensions();
output_device_shape->clear_dynamic_dimensions();
return Status::OK();
}
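A hedged trace of the output update above for a single dynamic result of compiled shape f32[<=4] whose device metadata holds {2}:

    before : output_host_shape = f32[<=4], output_device_shape = f32[<=4]
    read   : metadata literal  = s32[1] {2}
    after  : both shapes become the concrete static shape f32[2]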
// Create output tuple from run_result.
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
se::Stream* stream, xla::ScopedShapedBuffer run_result,
xla::Backend* backend, int device_ordinal) {
XRTTupleAllocation* output_tuple;
xla::ShapedBuffer shaped_buffer = run_result.release();
if (shaped_buffer.on_device_shape().is_dynamic()) {
// Update dynamic shapes from the output buffer, and create an XRT tensor with
// dimension sizes read from the metadata.
xla::Shape output_host_shape = shaped_buffer.on_host_shape();
xla::Shape output_device_shape = shaped_buffer.on_device_shape();
TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
stream, &shaped_buffer, &output_host_shape, &output_device_shape));
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, output_host_shape, output_device_shape, backend,
device_ordinal, &output_tuple));
} else {
// Fast-path: Don't copy shapes of output buffer.
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, backend, device_ordinal, &output_tuple));
}
return RefPtr<XRTTupleAllocation>(output_tuple);
}
xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
@ -191,18 +421,31 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
Env* env = Env::Default();
auto start_time = env->NowMicros();
const std::vector<xla::ShapeLayout>& shape_layouts =
executable->executable()
->module_config()
.entry_computation_layout()
.parameter_layouts();
TF_ASSIGN_OR_RETURN(auto new_allocations,
UpdateDynamicInputs(stream, run_options.allocator(),
input_buffers.input_pointers,
shape_layouts));
auto new_allocations_ptr =
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(new_allocations));
TF_ASSIGN_OR_RETURN(
xla::ScopedShapedBuffer run_result,
executable->Run(input_buffers.input_pointers, run_options));
// Retain the new allocations for input memory until the end of execution.
stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); });
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
auto shaped_buffer = run_result.release();
XRTTupleAllocation* output_tuple;
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, device_ref->backend(), device_ref->device_ordinal(),
&output_tuple));
RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);
TF_ASSIGN_OR_RETURN(
RefPtr<XRTTupleAllocation> output_tuple_ptr,
CreateOutputTuple(stream, std::move(run_result), device_ref->backend(),
device_ref->device_ordinal()));
// The ScopedShapedBuffer returned by the executable Run() API, in case of
// input/output buffer aliasing, might have holes in it, which need to be
@ -215,7 +458,7 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size());
return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
? output_tuple->AliasBufferFrom(
? output_tuple_ptr->AliasBufferFrom(
*input_buffers.input_tuples[alias.parameter_number],
alias.parameter_index, output_index)
: Status::OK();


@ -49,6 +49,67 @@ limitations under the License.
namespace tensorflow {
namespace {
xla::XlaComputation ReturnDynamicR1() {
xla::XlaBuilder builder("ReturnDynamicR1");
auto p0 = xla::Parameter(&builder, 0,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
auto p1 = xla::Parameter(&builder, 1,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
"P2");
auto sum = xla::Add(p0, p1);
auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
return builder.Build(pad_sum).ValueOrDie();
}
xla::XlaComputation AcceptDynamicR1() {
xla::XlaBuilder builder("AcceptDynamicR1");
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1");
auto sum = xla::Add(p0, p1);
return builder.Build(sum).ValueOrDie();
}
xla::XlaComputation ReturnDynamicR1Tuple() {
xla::XlaBuilder builder("ReturnDynamicR1Tuple");
auto p0 = xla::Parameter(&builder, 0,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
auto p1 = xla::Parameter(&builder, 1,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
"P2");
auto sum = xla::Add(p0, p1);
auto sub = xla::Sub(p0, p1);
auto one = xla::One(&builder, xla::S32);
auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie();
}
xla::XlaComputation AcceptDynamicR1Tuple() {
xla::XlaBuilder builder("AcceptDynamicR1");
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
xla::Shape tuple_shape =
xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
xla::Shape nest_tuple_shape =
xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
auto p = xla::Parameter(&builder, 0, tuple_shape, "P0");
auto p0 = xla::GetTupleElement(p, 0);
auto p1 = xla::GetTupleElement(p, 1);
auto sum = xla::Add(p0, p1);
return builder.Build(sum).ValueOrDie();
}
template <typename T>
xla::LiteralProto CreateR0(T v) {
auto array = xla::LiteralUtil::CreateR0<T>(v);
return array.ToProto();
}
class XrtClientSession : public ClientSession {
public:
explicit XrtClientSession(const Scope& scope) : ClientSession(scope) {
@ -61,6 +122,11 @@ class XrtClientSession : public ClientSession {
string* xla_test_device_ptr; // initial value set in main()
string* xla_platform_ptr; // initial value set in main()
bool SupportDynamicShapes() {
// TODO(jackcao): Support dynamic shapes on XLA GPU.
return *xla_test_device_ptr != "XLA_GPU";
}
string DeviceFromFlag() {
string xla_test_device = *xla_test_device_ptr;
return absl::StrCat("/device:", xla_test_device, ":0");
@ -1035,6 +1101,239 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_EQ(program_shape.parameters_size(), 2);
}
TEST(RawApiTest, DynamicR1Test) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f});
xrt::XLAAllocation p2;
*p2.mutable_value() = CreateR0<xla::int32>(2);
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() = dyn_shape.ToProto();
StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);
auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
auto p2_handle = ops::XRTAllocate(root, p2_value);
auto result = ops::XRTExecute(
root, c_handle.handle, e_config,
{Output(p0_handle), Output(p1_handle), Output(p2_handle)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());
XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
TEST(RawApiTest, DynamicR1TupleTest) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f});
xrt::XLAAllocation p2;
*p2.mutable_value() = CreateR0<xla::int32>(2);
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() =
xla::ShapeUtil::MakeTupleShape(
{dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape})
.ToProto();
StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);
auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
auto p2_handle = ops::XRTAllocate(root, p2_value);
auto result = ops::XRTExecute(
root, c_handle.handle, e_config,
{Output(p0_handle), Output(p1_handle), Output(p2_handle)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());
XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
auto expected0 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
auto expected1 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f, 0.0f});
auto expected2 = xla::LiteralUtil::CreateR1<float>({0.0f, 3.0f, 1.0f});
auto expected =
xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
TEST(RawApiTest, AcceptDynamicR1TupleTest) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});
xrt::XLATupleNode tuple_desc;
auto subdesc_10 = tuple_desc.add_tuples();
auto subdesc_11 = tuple_desc.add_tuples();
subdesc_10->set_input_index(0);
subdesc_10->set_release_input_handle(true);
subdesc_11->set_input_index(1);
subdesc_11->set_release_input_handle(true);
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_input_shape.set_dynamic_dimension(0, true);
xla::Shape dyn_tuple_shape =
xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape});
*shapes->add_parameters() = dyn_tuple_shape.ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() = dyn_shape.ToProto();
StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);
auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"),
tuple_desc.SerializeAsString());
auto t0_handle = ops::XRTMakeTuple(
root, tuple_0,
{static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)});
auto result = ops::XRTExecute(root, c_handle.handle, e_config,
{static_cast<Output>(t0_handle)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());
XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
TEST(RawApiTest, AcceptDynamicR1Test) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_input_shape.set_dynamic_dimension(0, true);
*shapes->add_parameters() = dyn_input_shape.ToProto();
*shapes->add_parameters() = dyn_input_shape.ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() = dyn_shape.ToProto();
StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto allocate_op_0 = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto allocate_op_1 = ops::XRTAllocate(root, p1_value);
auto result = ops::XRTExecute(root, c_handle.handle, e_config,
{Output(allocate_op_0), Output(allocate_op_1)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());
XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f});