[XLA] Basic (R1) support for CPU bounded dynamic shapes.
- Add dynamic tensor metadata read/write in XRT.
- Implement two custom calls: PadToStatic and SliceToDynamic (R1 only).
- Add helper functions in shape util for sanity checks.
- Tests (R1 only).

PiperOrigin-RevId: 311398639
Change-Id: I7129fd13f4e0a2b7a14efb52eb814f753a15e05e
This commit is contained in: parent 8d1e8b350c, commit 5fee245d9f.
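For context on the buffer convention this change introduces (an illustrative sketch, not part of the diff): a bounded dynamic array is stored as its full bounded payload followed by one S32 per dimension holding the runtime size. The helper below mirrors the CpuExecutable::ShapeSizeBytes change further down; the function name is hypothetical and only the Shape/ShapeUtil calls that appear in this commit are assumed.

int64_t DynamicBufferSizeBytes(const xla::Shape& shape, int64_t pointer_size) {
  // Static shapes and tuples keep their usual size.
  if (shape.is_static() || shape.IsTuple()) {
    return xla::ShapeUtil::ByteSizeOf(shape, pointer_size);
  }
  // A dynamic array is sized for its bound, plus one trailing S32 per
  // dimension recording the real (runtime) size.
  int64_t metadata_size = sizeof(int32_t) * shape.dimensions_size();
  return xla::ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
}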
@@ -331,6 +331,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
@@ -146,6 +146,7 @@ cc_library(
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:dynamic_padder",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
@@ -72,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -239,7 +240,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
HloPassPipeline pipeline("HLO passes through layout assignment");
pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);

// Expand random number generation.
pipeline.AddPass<RngExpander>();
pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);
@@ -273,6 +273,13 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<ConvolutionGroupConverter>(
cost_model,
/*convert_batch_groups_only=*/false);
pipeline.AddPass<ScatterExpander>();
pipeline.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
pipeline.AddPass<DynamicPadder>();
pipeline.AddPass<HloGetDimensionSizeRewriter>();
pipeline.AddPass<ConvCanonicalization>(target_machine_features);
{
auto& pass =
@@ -281,12 +288,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
/*allow_mixed_precision=*/false);

pass.AddPass<TreeReductionRewriter>();
pass.AddPass<ScatterExpander>();
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
pipeline.AddPass<HloGetDimensionSizeRewriter>();
AlgebraicSimplifierOptions options;
options.set_enable_dot_strength_reduction(false);
pass.AddPass<AlgebraicSimplifier>(options);
@@ -363,7 +363,12 @@ StatusOr<ExecutionOutput> CpuExecutable::ExecuteAsyncOnStream(
if (shape.IsOpaque()) {
return sizeof(void*);
}
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
if (shape.is_static() || shape.IsTuple()) {
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
}
// Each dynamic dimension size is represented as an S32.
int64 metadata_size = sizeof(int32) * shape.dimensions_size();
return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size;
}
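// Illustrative sketch (not part of this change): the size computation above
// for a concrete bounded dynamic R1 shape.
//
//   xla::Shape s = xla::ShapeUtil::MakeShape(xla::F32, {4});
//   s.set_dynamic_dimension(0, true);   // s is f32[<=4]
//   // payload (sized for the bound): 4 * sizeof(float) = 16 bytes
//   // metadata: 1 dimension * sizeof(int32) = 4 bytes
//   // ShapeSizeBytes(s) == 20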

const InstructionValueSet& CpuExecutable::GetRootValueSet() const {
@@ -2357,7 +2357,95 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
return Status::OK();
}

Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) {
// TODO(jackcao): Generalize this to generic llvm emitter.
TF_RET_CHECK(hlo->shape().rank() == 1);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));
for (int64 i = 1; i < hlo->operand_count(); ++i) {
const int64 dim_index = i - 1;
llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i));
llvm::LoadInst* dim_size = b_.CreateLoad(source_buffer, "dim_size");
llvm::Value* dest_buffer = GetEmittedValueFor(hlo);
llvm::Value* raw_buffer =
b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());

int32 raw_data_size =
ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape()));
llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
b_.CreateStore(dim_size,
b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
}

return EmitTargetElementLoop(hlo,
[=](const llvm_ir::IrArray::Index& dest_index) {
// TODO(jackcao): Properly linearize dest_index
// and delinearize to source index.
return GetIrArrayFor(hlo->operand(0))
.EmitReadArrayElement(dest_index, &b_);
});
}

Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) {
// TODO(jackcao): Generalize this to generic llvm emitter.
TF_RET_CHECK(hlo->operand(0)->shape().rank() == 1);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo));

TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice,
assignment_.GetUniqueSlice(hlo, {0}));
const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0});
llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape);
llvm_ir::IrArray data_array(data_address, data_shape);
TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter(
[=](const llvm_ir::IrArray::Index& dest_index) {
// TODO(jackcao): Properly linearize dest_index and
// delinearize to source index.
return GetIrArrayFor(hlo->operand(0))
.EmitReadArrayElement(dest_index, &b_);
},
llvm_ir::IrArray(data_address, data_shape), &b_)
.EmitLoop(IrName(hlo)));
std::vector<llvm::Value*> tuple_operand_ptrs;
tuple_operand_ptrs.push_back(data_array.GetBasePointer());

// PadToStatic has a dynamic tensor as input and a variadic number of outputs:
// (static_tensor, dynamic_dim_0, dynamic_dim_1, ... )
// Dynamic dimension sizes start at output index 1.
for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) {
// Read from the metadata section of the dynamic input (operand 0).
const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i});
TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice,
assignment_.GetUniqueSlice(hlo, {i}));
llvm::Value* dest_dim_size_address =
EmitBufferPointer(dim_size_slice, data_shape);
const int64 dim_index = i - 1;
llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0));
llvm::Value* raw_buffer =
b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
int32 raw_data_size = ShapeUtil::ByteSizeOf(
ShapeUtil::MakeStaticShape(hlo->operand(0)->shape()));
llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
llvm::Value* dim_size = b_.CreateLoad(
b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
b_.CreateStore(dim_size, b_.CreateBitCast(dest_dim_size_address,
b_.getInt32Ty()->getPointerTo()));
tuple_operand_ptrs.push_back(dest_dim_size_address);
}

// Emit static tensor and dynamic sizes as one tuple.
llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_);
return Status::OK();
}
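// Illustrative sketch (not part of this change): the tuple shape PadToStatic
// produces for an R1 operand, built with the ShapeUtil helpers this commit
// touches. For an f32[<=4] input the output is (f32[4], s32[]).
//
//   xla::Shape dyn = xla::ShapeUtil::MakeShape(xla::F32, {4});
//   dyn.set_dynamic_dimension(0, true);                           // f32[<=4]
//   xla::Shape padded = xla::ShapeUtil::MakeStaticShape(dyn);     // f32[4]
//   xla::Shape dim = xla::ShapeUtil::MakeScalarShape(xla::S32);   // s32[]
//   xla::Shape out = xla::ShapeUtil::MakeTupleShape({padded, dim});
//   // Dynamic dimension sizes start at tuple index 1, as checked above.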

Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
if (custom_call->custom_call_target() == "PadToStatic") {
return HandlePadToStatic(custom_call);
}
if (custom_call->custom_call_target() == "SliceToDynamic") {
return HandleSliceToDynamic(custom_call);
}
absl::Span<HloInstruction* const> operands(custom_call->operands());
llvm::Type* i8_ptr_type = b_.getInt8PtrTy();
llvm::AllocaInst* operands_alloca =
@@ -183,6 +183,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
}

private:
Status HandleSliceToDynamic(HloInstruction* hlo);
Status HandlePadToStatic(HloInstruction* hlo);
Status HandleAllReduceSingleReplica(HloInstruction* crs);
Status HandleAllReduceMultipleReplica(HloInstruction* crs);
@@ -22,6 +22,7 @@ limitations under the License.

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -93,6 +94,18 @@ class ShapedBuffer {
buffers_.replace_shape_ptr(&on_device_shape_);
}

// Reset the shape of this shaped buffer and underlying buffer structure.
//
// Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
<< "Structures are not the same. new: " << on_device_shape
<< ", old: " << on_device_shape_;
on_host_shape_ = on_host_shape;
on_device_shape_ = on_device_shape;
buffers_.replace_shape_ptr(&on_device_shape_);
}
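// Illustrative sketch (not part of this change): UpdateDynamicInputs in the
// XRT changes below calls this to swap a dynamic input's shapes for the
// bounded compile-time shapes after reallocating its buffers:
//
//   runtime_input->set_shapes(compile_time_shape, compile_time_shape);
//
// EqualStructure only compares tuple structure, so replacing f32[2] with
// f32[<=4] at the same tuple index passes the CHECK above.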

// Returns the underlying ShapeTree containing all the device addresses in the
// ShapedBuffer.
const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
@@ -150,6 +151,19 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
return equal;
}

/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs,
const Shape& rhs) {
bool equal = true;
ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
equal &= IndexIsValid(rhs, index);
});
ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) {
equal &= IndexIsValid(lhs, index);
});

return equal;
}
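// Illustrative sketch (not part of this change): what the new structural
// equality accepts and rejects, mirroring the shape_util.h comment; the
// wrapper function is hypothetical.
void EqualStructureExample() {
  xla::Shape a = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {}),
       xla::ShapeUtil::MakeTupleShape(
           {xla::ShapeUtil::MakeShape(xla::S32, {2}),
            xla::ShapeUtil::MakeShape(xla::F32, {2, 2})})});
  xla::Shape b = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {}),
       xla::ShapeUtil::MakeTupleShape(
           {xla::ShapeUtil::MakeShape(xla::F32, {3}),
            xla::ShapeUtil::MakeShape(xla::S32, {2})})});
  // Both have structure (,(,)) even though element types and dims differ.
  CHECK(xla::ShapeUtil::EqualStructure(a, b));
  // An array has structure (), so it differs structurally from any tuple.
  CHECK(!xla::ShapeUtil::EqualStructure(
      a, xla::ShapeUtil::MakeShape(xla::F32, {})));
}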

/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
int64 accum = 0;
for (int64 dimension : shape.dimensions()) {
@@ -261,6 +275,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return ValidateShape(*shape);
}

/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) {
Shape result = original;
result.clear_dynamic_dimensions();
return result;
}
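// Illustrative sketch (not part of this change): MakeStaticShape keeps the
// dimension bounds and only clears the dynamic bits.
//
//   xla::Shape dyn = xla::ShapeUtil::MakeShape(xla::F32, {4});
//   dyn.set_dynamic_dimension(0, true);                      // f32[<=4]
//   xla::Shape st = xla::ShapeUtil::MakeStaticShape(dyn);    // f32[4]
//   // st.is_static() == true, st.dimensions(0) == 4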

/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
Shape result;
result.set_element_type(TUPLE);
@@ -626,8 +646,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
if (shape.element_type() == TUPLE) {
return ByteSizeOfTupleIndexTable(shape, pointer_size);
} else if (shape.IsArray()) {
int64 byte_size = ByteSizeOfElements(shape);
return byte_size;
return ByteSizeOfElements(shape);
} else if (shape.element_type() == TOKEN) {
return 0;
} else if (shape.element_type() == OPAQUE_TYPE) {
@@ -1441,6 +1460,19 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified(
return shape;
}

/* static */ bool ShapeUtil::DynamicShapeIsCompatible(
const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) {
if (dynamic_shape.rank() != bounded_shape.rank()) {
return false;
}
for (int64 i = 0; i < dynamic_shape.rank(); ++i) {
if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) {
return false;
}
}
return true;
}
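// Illustrative sketch (not part of this change): the compatibility check for
// an R1 bound of 4; the wrapper function is hypothetical.
void DynamicShapeIsCompatibleExample() {
  xla::Shape bound = xla::ShapeUtil::MakeShape(xla::F32, {4});
  bound.set_dynamic_dimension(0, true);  // f32[<=4]
  // A runtime shape fits iff ranks match and every dimension is within bound.
  CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(
      xla::ShapeUtil::MakeShape(xla::F32, {2}), bound));
  CHECK(!xla::ShapeUtil::DynamicShapeIsCompatible(
      xla::ShapeUtil::MakeShape(xla::F32, {5}), bound));
}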

/* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) {
CHECK(shape.IsArray());
@@ -298,6 +298,16 @@ class ShapeUtil {
// As Equal, but allow one of lhs and rhs to be F16 while the other is F32.
static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);

// Two shapes have the same structure if all subshape indices of lhs are
// present in rhs and vice versa.
// A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to
// (S32, (F32[3], S32[2])) as their structures are both (,(,))
//
// In contrast, (F32, (F32, F32)) is structurally different from
// ((F32, F32), F32) as the former has structure (,(,)) while the latter has
// ((,),)
static bool EqualStructure(const Shape& lhs, const Shape& rhs);

// Returns the number of dimensions for which the dimension is not (trivially)
// 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just
// fluff. Note that zero dimensions are included in the true rank, e.g.,
@@ -339,6 +349,9 @@ class ShapeUtil {
// element type changed to type.
static Shape ChangeElementType(const Shape& original, PrimitiveType type);

// Returns a shape with the same dimensions but with all dimensions set to static.
static Shape MakeStaticShape(const Shape& original);

// Creates a tuple shape from a slice of element shapes within the tuple.
static Shape MakeTupleShape(absl::Span<const Shape> shapes);

@@ -643,12 +656,16 @@ class ShapeUtil {
static Shape FilterDimensions(const std::function<bool(int64)>& p,
Shape shape);

// Iterates through all the shape indexes, in minor to major order, starting
// from the base indexes, incrementing by the incr steps, up to count
// (index[i] < base[i] + count[i]), and calls the visitor_function with the
// current index.
// The visitor_function should return true if it wants to
// continue, or false otherwise.
// Returns true if `dynamic_shape` has dimensions that are less than or equal
// to those of `bounded_shape`.
static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape,
const xla::Shape& bounded_shape);

// Iterates through all the shape indexes, in minor to major order,
// starting from the base indexes, incrementing by the incr steps, up to
// count (index[i] < base[i] + count[i]), and calls the visitor_function
// with the current index. The visitor_function should
// return true if it wants to continue, or false otherwise.
//
// visitor_function must be a callable of type
// StatusOr<bool>(absl::Span<int64>) or compatible.
@@ -49,6 +49,7 @@ cc_library(
deps = [
":xrt_state_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
@@ -38,7 +39,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/monitoring/timed.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream_executor.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

@@ -146,6 +151,231 @@ xla::StatusOr<InputBuffers> GetChainedOpInputs(
return std::move(input_buffers);
}

// Given a shape, returns the shape metadata for that shape: the dimension
// sizes stored as contiguous S32 values.
std::vector<int32> PrepareMetadata(const xla::Shape& shape) {
DCHECK(shape.is_static());
DCHECK(shape.IsArray());
// Each dimension size is stored as an S32.
std::vector<int32> result(shape.dimensions_size());
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
result[i] = shape.dimensions(i);
}
return result;
}
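// Illustrative sketch (not part of this change): for a resolved runtime shape
// f32[2] the metadata written after the payload is a single S32.
//
//   std::vector<int32> meta =
//       PrepareMetadata(xla::ShapeUtil::MakeShape(xla::F32, {2}));
//   // meta == {2}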

// Given a buffer with dynamic shape, update buffer metadata at the correct
// offset starting from that buffer.
//
// +-----------+
// |Payload    |
// +-----------+
// | Padding   |
// +-----------+
// |dim_size_0 | (each dim_size is an S32)
// +-----------+
// |dim_size_1 |
// +-----------+
// ..........
// +-----------+
//
// Size of payload = ByteSizeOf(runtime_shape)
// Size of payload + padding = ByteSizeOf(compile_time_shape_static)
// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape)
Status UpdateMetadata(se::Stream* stream, se::DeviceMemory<uint8>* buffer,
const xla::Shape& compile_time_shape,
const xla::Shape& runtime_shape) {
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform()));
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
xla::Shape compile_time_shape_static =
xla::ShapeUtil::MakeStaticShape(compile_time_shape);
uint64 offset = shape_size_fn(compile_time_shape_static);
uint64 metadata_size = shape_size_fn(compile_time_shape) - offset;
auto metadata_buffer =
stream->parent()->GetSubBuffer(buffer, offset, metadata_size);

auto metadata_literal = std::make_shared<xla::Literal>(
xla::LiteralUtil::CreateR1<int32>(PrepareMetadata(runtime_shape)));
TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync(
stream, *metadata_literal, metadata_buffer));
// Retain the literal until the end of the transfer.
stream->ThenDoHostCallback([metadata_literal]() { return Status::OK(); });
return Status::OK();
}

// Given a static input buffer, convert it to dynamic form by expanding it to
// the bounded size and attaching metadata filled with dimension sizes.
//
// From:
// +--------+
// |Payload |
// +--------+
//
// To:
//
// +--------+
// |Payload |
// +--------+
// | Padding|
// +--------+
// |Metadata|
// +--------+
//
// As we can't expand the size of an existing memory allocation, a reallocation
// is required. A list of new allocations is returned by this function. The
// caller is responsible for maintaining those allocations.
xla::StatusOr<std::vector<se::OwningDeviceMemory>> UpdateDynamicInputs(
se::Stream* stream, se::DeviceMemoryAllocator* allocator,
std::vector<xla::ShapedBuffer*> runtime_inputs,
const std::vector<xla::ShapeLayout>& compile_time_shapes) {
std::vector<se::OwningDeviceMemory> new_allocations;
TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size());
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
for (int64 i = 0; i < compile_time_shapes.size(); i++) {
const xla::Shape& compile_time_shape = compile_time_shapes[i].shape();
if (compile_time_shape.is_static()) {
continue;
}
auto* runtime_input = runtime_inputs[i];

bool element_modified = false;
TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
compile_time_shape,
[&](const xla::Shape& compile_time_shape,
const xla::ShapeIndex& index) -> Status {
if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) {
return Status::OK();
}
const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape(
runtime_input->on_device_shape(), index);
TF_RET_CHECK(!runtime_shape.IsTuple());
TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(
runtime_shape, compile_time_shape));
se::DeviceMemoryBase* static_input =
runtime_input->buffers().mutable_element(index);
TF_ASSIGN_OR_RETURN(
auto dynamic_input,
allocator->Allocate(stream->parent()->device_ordinal(),
shape_size_fn(compile_time_shape)));
new_allocations.emplace_back(std::move(dynamic_input));
se::DeviceMemory<uint8>* dynamic_input_base =
new_allocations.back().ptr();
// Send the original data to the new location.
stream->ThenMemcpyD2D(dynamic_input_base, *static_input,
static_input->size());
TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base,
compile_time_shape, runtime_shape));
// Modify the memory location in the input shape tree to point to the
// new input.
runtime_input->set_buffer(*dynamic_input_base, index);
element_modified = true;
return Status::OK();
}));
if (element_modified) {
runtime_input->set_shapes(compile_time_shape, compile_time_shape);
// The input location has been modified, need to fix tuple table to
// point to the correct address.
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(
transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input));
}
}
return std::move(new_allocations);
}
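// Illustrative sketch (assumed sizes, not part of this change): a static
// f32[2] input (8 bytes) being expanded to its bound f32[<=4]:
//
//   bytes  0..7    original payload, copied with ThenMemcpyD2D
//   bytes  8..15   padding up to the bounded payload size (16 bytes)
//   bytes 16..19   S32 metadata {2}, written by UpdateMetadata
//
// The 20-byte total comes from ShapeSizeBytesFunction() on the bounded shape.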

xla::StatusOr<xla::Literal> ReadMetadataLiteral(
se::Stream* stream, se::DeviceMemoryBase* buffer,
const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
stream->parent()->platform()));
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
xla::Shape buffer_shape_static =
xla::ShapeUtil::MakeStaticShape(buffer_shape);
const int64 offset = shape_size_fn(buffer_shape_static);
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
TF_RET_CHECK(metadata_size != 0);
auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
auto metadata_buffer =
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
return transfer_manager->TransferArrayFromDevice(
stream,
xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
metadata_buffer);
}

// For each subshape in the result buffer that's dynamic, read the dynamic
// dimension sizes from the metadata, and update output shapes. The result shape
// is a static and concrete shape.
xla::Status UpdateDynamicOutputs(se::Stream* stream,
xla::ShapedBuffer* shaped_buffer,
xla::Shape* output_host_shape,
xla::Shape* output_device_shape) {
DCHECK(output_device_shape->is_dynamic());
TF_ASSIGN_OR_RETURN(
auto transfer_manager,
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
const xla::Shape& buffer_shape =
xla::ShapeUtil::GetSubshape(*output_device_shape, index);
if (buffer_shape.IsTuple()) {
return Status::OK();
}
xla::Shape& host_shape =
*xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
xla::Shape& device_shape =
*xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
if (device_shape.is_static()) {
return Status::OK();
}
TF_ASSIGN_OR_RETURN(auto metadata,
ReadMetadataLiteral(stream, buffer, buffer_shape,
transfer_manager));
// Update shape size from metadata.
for (int64 i = 0; i < metadata.element_count(); ++i) {
host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
}
return Status::OK();
}));
output_host_shape->clear_dynamic_dimensions();
output_device_shape->clear_dynamic_dimensions();
return Status::OK();
}
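// Illustrative sketch (not part of this change): if the compiled result shape
// is f32[<=4] and the metadata read back is {2}, both output shapes are
// rewritten in place to the concrete static f32[2] before the dynamic bits
// are cleared above.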

// Create output tuple from run_result.
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
se::Stream* stream, xla::ScopedShapedBuffer run_result,
xla::Backend* backend, int device_ordinal) {
XRTTupleAllocation* output_tuple;
xla::ShapedBuffer shaped_buffer = run_result.release();
if (shaped_buffer.on_device_shape().is_dynamic()) {
// Update dynamic shapes from output buffer, and create an XRT tensor with
// dimension sizes read from metadata.
xla::Shape output_host_shape = shaped_buffer.on_host_shape();
xla::Shape output_device_shape = shaped_buffer.on_device_shape();
TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
stream, &shaped_buffer, &output_host_shape, &output_device_shape));
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, output_host_shape, output_device_shape, backend,
device_ordinal, &output_tuple));
} else {
// Fast-path: Don't copy shapes of output buffer.
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, backend, device_ordinal, &output_tuple));
}
return RefPtr<XRTTupleAllocation>(output_tuple);
}

xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref,
xla::LocalExecutable* executable, const InputBuffers& input_buffers,
@@ -191,18 +421,31 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(

Env* env = Env::Default();
auto start_time = env->NowMicros();
const std::vector<xla::ShapeLayout>& shape_layouts =
executable->executable()
->module_config()
.entry_computation_layout()
.parameter_layouts();
TF_ASSIGN_OR_RETURN(auto new_allocations,
UpdateDynamicInputs(stream, run_options.allocator(),
input_buffers.input_pointers,
shape_layouts));
auto new_allocations_ptr =
std::make_shared<std::vector<se::OwningDeviceMemory>>(
std::move(new_allocations));
TF_ASSIGN_OR_RETURN(
xla::ScopedShapedBuffer run_result,
executable->Run(input_buffers.input_pointers, run_options));
// Retain the new allocation for input memory until the end of execution.
stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); });

auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";

auto shaped_buffer = run_result.release();
XRTTupleAllocation* output_tuple;
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
shaped_buffer, device_ref->backend(), device_ref->device_ordinal(),
&output_tuple));
RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);
TF_ASSIGN_OR_RETURN(
RefPtr<XRTTupleAllocation> output_tuple_ptr,
CreateOutputTuple(stream, std::move(run_result), device_ref->backend(),
device_ref->device_ordinal()));

// The ScopedShapedBuffer returned by the executable Run() API, in case of
// input/output buffer aliasing, might have holes in it, which need to be
@@ -215,7 +458,7 @@ xla::StatusOr<RefPtr<XRTTupleAllocation>> RunExecutable(
const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size());
return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
? output_tuple->AliasBufferFrom(
? output_tuple_ptr->AliasBufferFrom(
*input_buffers.input_tuples[alias.parameter_number],
alias.parameter_index, output_index)
: Status::OK();
@@ -49,6 +49,67 @@ limitations under the License.
namespace tensorflow {
namespace {

xla::XlaComputation ReturnDynamicR1() {
xla::XlaBuilder builder("ReturnDynamicR1");
auto p0 = xla::Parameter(&builder, 0,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
auto p1 = xla::Parameter(&builder, 1,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
"P2");
auto sum = xla::Add(p0, p1);
auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
return builder.Build(pad_sum).ValueOrDie();
}

xla::XlaComputation AcceptDynamicR1() {
xla::XlaBuilder builder("AcceptDynamicR1");
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1");
auto sum = xla::Add(p0, p1);
return builder.Build(sum).ValueOrDie();
}

xla::XlaComputation ReturnDynamicR1Tuple() {
xla::XlaBuilder builder("ReturnDynamicR1Tuple");
auto p0 = xla::Parameter(&builder, 0,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
auto p1 = xla::Parameter(&builder, 1,
xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
"P2");
auto sum = xla::Add(p0, p1);
auto sub = xla::Sub(p0, p1);
auto one = xla::One(&builder, xla::S32);
auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie();
}

xla::XlaComputation AcceptDynamicR1Tuple() {
xla::XlaBuilder builder("AcceptDynamicR1");
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
xla::Shape tuple_shape =
xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
xla::Shape nest_tuple_shape =
xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
auto p = xla::Parameter(&builder, 0, tuple_shape, "P0");
auto p0 = xla::GetTupleElement(p, 0);
auto p1 = xla::GetTupleElement(p, 1);
auto sum = xla::Add(p0, p1);
return builder.Build(sum).ValueOrDie();
}

template <typename T>
xla::LiteralProto CreateR0(T v) {
auto array = xla::LiteralUtil::CreateR0<T>(v);
return array.ToProto();
}

class XrtClientSession : public ClientSession {
public:
explicit XrtClientSession(const Scope& scope) : ClientSession(scope) {
@@ -61,6 +122,11 @@ class XrtClientSession : public ClientSession {
string* xla_test_device_ptr; // initial value set in main()
string* xla_platform_ptr; // initial value set in main()

bool SupportDynamicShapes() {
// TODO(jackcao): Support dynamic shapes on XLA GPU.
return *xla_test_device_ptr != "XLA_GPU";
}

string DeviceFromFlag() {
string xla_test_device = *xla_test_device_ptr;
return absl::StrCat("/device:", xla_test_device, ":0");
@@ -1035,6 +1101,239 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_EQ(program_shape.parameters_size(), 2);
}

TEST(RawApiTest, DynamicR1Test) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f});
xrt::XLAAllocation p2;
*p2.mutable_value() = CreateR0<xla::int32>(2);

xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() = dyn_shape.ToProto();
StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot());

xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);

Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);
auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
auto p2_handle = ops::XRTAllocate(root, p2_value);
auto result = ops::XRTExecute(
root, c_handle.handle, e_config,
{Output(p0_handle), Output(p1_handle), Output(p2_handle)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());

XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, DynamicR1TupleTest) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f});
xrt::XLAAllocation p2;
*p2.mutable_value() = CreateR0<xla::int32>(2);

xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() =
xla::ShapeUtil::MakeTupleShape(
{dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape})
.ToProto();
StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot());

xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);

Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);
auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
auto p2_handle = ops::XRTAllocate(root, p2_value);
auto result = ops::XRTExecute(
root, c_handle.handle, e_config,
{Output(p0_handle), Output(p1_handle), Output(p2_handle)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());

XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));

auto expected0 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
auto expected1 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f, 0.0f});
auto expected2 = xla::LiteralUtil::CreateR1<float>({0.0f, 3.0f, 1.0f});
auto expected =
xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, AcceptDynamicR1TupleTest) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

xrt::XLATupleNode tuple_desc;
auto subdesc_10 = tuple_desc.add_tuples();
auto subdesc_11 = tuple_desc.add_tuples();
subdesc_10->set_input_index(0);
subdesc_10->set_release_input_handle(true);
subdesc_11->set_input_index(1);
subdesc_11->set_release_input_handle(true);

xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_input_shape.set_dynamic_dimension(0, true);
xla::Shape dyn_tuple_shape =
xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape});
*shapes->add_parameters() = dyn_tuple_shape.ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() = dyn_shape.ToProto();
StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot());

xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);

Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);

auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"),
tuple_desc.SerializeAsString());
auto t0_handle = ops::XRTMakeTuple(
root, tuple_0,
{static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)});
auto result = ops::XRTExecute(root, c_handle.handle, e_config,
{static_cast<Output>(t0_handle)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());

XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));

auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, AcceptDynamicR1Test) {
if (!SupportDynamicShapes()) {
return;
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
xrt::XLAAllocation p1;
*p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_input_shape.set_dynamic_dimension(0, true);
*shapes->add_parameters() = dyn_input_shape.ToProto();
*shapes->add_parameters() = dyn_input_shape.ToProto();
xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
dyn_shape.set_dynamic_dimension(0, true);
*shapes->mutable_result() = dyn_shape.ToProto();
StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot());

xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);

Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
Scope cpu_root = root.WithDevice("/device:CPU:0");
auto e_config = ops::Const(cpu_root, e.SerializeAsString());
auto computation = ops::Const(cpu_root, c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
auto allocate_op_0 = ops::XRTAllocate(root, p0_value);
auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
auto allocate_op_1 = ops::XRTAllocate(root, p1_value);
auto result = ops::XRTExecute(root, c_handle.handle, e_config,
{Output(allocate_op_0), Output(allocate_op_1)});
auto read_back = ops::XRTReadLiteralAndRelease(root, result);
TF_ASSERT_OK(root.status());

XrtClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));

auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}

TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f});