From 5fee245d9ff76b9fae2b7404f47aae83c84b8564 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Wed, 13 May 2020 13:57:37 -0700 Subject: [PATCH] [XLA] Basic (R1) support for CPU bounded dynamic shapes. - Add dynamic tensor metadata read/write in XRT. - Implement two custom calls: PadToStatic and SliceToDynamic -- R1 only. - Some helper functions in shape util to do sanity check. - Tests -- R1 Only. PiperOrigin-RevId: 311398639 Change-Id: I7129fd13f4e0a2b7a14efb52eb814f753a15e05e --- tensorflow/compiler/xla/BUILD | 1 + tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 15 +- .../xla/service/cpu/cpu_executable.cc | 7 +- .../compiler/xla/service/cpu/ir_emitter.cc | 88 ++++++ .../compiler/xla/service/cpu/ir_emitter.h | 2 + .../compiler/xla/service/shaped_buffer.h | 13 + tensorflow/compiler/xla/shape_util.cc | 36 ++- tensorflow/compiler/xla/shape_util.h | 29 +- tensorflow/compiler/xrt/kernels/BUILD | 1 + .../compiler/xrt/kernels/xrt_execute_op.cc | 257 ++++++++++++++- tensorflow/compiler/xrt/tests/raw_api_test.cc | 299 ++++++++++++++++++ 12 files changed, 726 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 0193bea9d6d..45f49cee328 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -331,6 +331,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 121bdedf2dd..2f432cd9356 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -146,6 +146,7 @@ cc_library( "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", + "//tensorflow/compiler/xla/service:dynamic_padder", "//tensorflow/compiler/xla/service:dynamic_index_splitter", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index b04237138e8..fe769bbdd2a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -72,6 +72,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/dump.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" +#include "tensorflow/compiler/xla/service/dynamic_padder.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -239,7 +240,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( HloPassPipeline pipeline("HLO passes through layout assignment"); pipeline.AddInvariantChecker(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - // Expand random number generation. 
pipeline.AddPass(); pipeline.AddPass(RandomAlgorithm::RNG_PHILOX); @@ -273,6 +273,13 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/false); + pipeline.AddPass(); + pipeline.AddPass( + /*rewrite_training_op=*/true, + /*rewrite_inference_op=*/true, + /*rewrite_grad_op=*/true); + pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(target_machine_features); { auto& pass = @@ -281,12 +288,6 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*allow_mixed_precision=*/false); pass.AddPass(); - pass.AddPass(); - pass.AddPass( - /*rewrite_training_op=*/true, - /*rewrite_inference_op=*/true, - /*rewrite_grad_op=*/true); - pipeline.AddPass(); AlgebraicSimplifierOptions options; options.set_enable_dot_strength_reduction(false); pass.AddPass(options); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 8c1ae0179c0..f031daecb1f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -363,7 +363,12 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( if (shape.IsOpaque()) { return sizeof(void*); } - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + if (shape.is_static() || shape.IsTuple()) { + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + // Each dynamic dimension size is represented as a S32. + int64 metadata_size = sizeof(int32) * shape.dimensions_size(); + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)) + metadata_size; } const InstructionValueSet& CpuExecutable::GetRootValueSet() const { diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 2b715bfa17a..f516a1538d3 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -2357,7 +2357,95 @@ Status IrEmitter::HandleCall(HloInstruction* call) { return Status::OK(); } +Status IrEmitter::HandleSliceToDynamic(HloInstruction* hlo) { + // TODO(jackcao): Generalize this to generic llvm emitter. + TF_RET_CHECK(hlo->shape().rank() == 1); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + for (int64 i = 1; i < hlo->operand_count(); ++i) { + const int64 dim_index = i - 1; + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(i)); + llvm::LoadInst* dim_size = b_.CreateLoad(source_buffer, "dim_size"); + llvm::Value* dest_buffer = GetEmittedValueFor(hlo); + llvm::Value* raw_buffer = + b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo()); + + int32 raw_data_size = + ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(hlo->shape())); + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + b_.CreateStore(dim_size, + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + } + + return EmitTargetElementLoop(hlo, + [=](const llvm_ir::IrArray::Index& dest_index) { + // TODO(jackcao): Properly linearize dest_index + // and delinearize to source index. + return GetIrArrayFor(hlo->operand(0)) + .EmitReadArrayElement(dest_index, &b_); + }); +} + +Status IrEmitter::HandlePadToStatic(HloInstruction* hlo) { + // TODO(jackcao): Generalize this to generic llvm emitter. 
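For orientation before HandlePadToStatic continues below: a minimal standalone sketch, with illustrative sizes, of the buffer convention that the BufferSizeBytesFunction change in cpu_executable.cc and the HandleSliceToDynamic emitter above rely on. Nothing here is taken from the patch beyond the layout rule itself (padded payload followed by one int32 of metadata per dimension).

#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  // Bounded-dynamic f32[<=4] whose runtime size is 2 (sizes are illustrative).
  constexpr int64_t kBoundedElements = 4;
  constexpr int64_t kPayloadBytes = kBoundedElements * sizeof(float);  // 16
  constexpr int64_t kRank = 1;
  // Allocation = padded payload + one int32 of metadata per dimension.
  std::vector<uint8_t> buffer(kPayloadBytes + kRank * sizeof(int32_t));

  // Valid elements live at the front; the rest of the payload is padding.
  const float payload[kBoundedElements] = {2.0f, 1.0f, 0.0f, 0.0f};
  std::memcpy(buffer.data(), payload, kPayloadBytes);

  // SliceToDynamic stores each dim_size right after the static payload,
  // at offset kPayloadBytes + dim_index * sizeof(int32).
  const int32_t dim_size_0 = 2;
  std::memcpy(buffer.data() + kPayloadBytes, &dim_size_0, sizeof(dim_size_0));
  return 0;
}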
+ TF_RET_CHECK(hlo->operand(0)->shape().rank() == 1); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(hlo)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice data_slice, + assignment_.GetUniqueSlice(hlo, {0})); + const Shape& data_shape = ShapeUtil::GetSubshape(hlo->shape(), {0}); + llvm::Value* data_address = EmitBufferPointer(data_slice, data_shape); + llvm_ir::IrArray data_array(data_address, data_shape); + TF_RETURN_IF_ERROR(llvm_ir::LoopEmitter( + [=](const llvm_ir::IrArray::Index& dest_index) { + // TODO(jackcao): Properly linearize dest_index and + // delinearize to source index. + return GetIrArrayFor(hlo->operand(0)) + .EmitReadArrayElement(dest_index, &b_); + }, + llvm_ir::IrArray(data_address, data_shape), &b_) + .EmitLoop(IrName(hlo))); + std::vector tuple_operand_ptrs; + tuple_operand_ptrs.push_back(data_array.GetBasePointer()); + + // PadToStatic has a dynamic tensor as input and variadic size of outputs: + // (static_tensor, dynamic_dim_0, dynamic_dim_1, ... ) + // Dynamic dimension sizes starts from output index 1. + for (int64 i = 1; i < hlo->shape().tuple_shapes_size(); ++i) { + // Read from the metadata section of the dynamic input (operand 0). + const Shape& dim_shape = ShapeUtil::GetSubshape(hlo->shape(), {i}); + TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dim_size_slice, + assignment_.GetUniqueSlice(hlo, {i})); + llvm::Value* dest_dim_size_address = + EmitBufferPointer(dim_size_slice, data_shape); + const int64 dim_index = i - 1; + llvm::Value* source_buffer = GetEmittedValueFor(hlo->operand(0)); + llvm::Value* raw_buffer = + b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo()); + int32 raw_data_size = ShapeUtil::ByteSizeOf( + ShapeUtil::MakeStaticShape(hlo->operand(0)->shape())); + llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32( + b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32)); + llvm::Value* dim_size = b_.CreateLoad( + b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo())); + b_.CreateStore(dim_size, b_.CreateBitCast(dest_dim_size_address, + b_.getInt32Ty()->getPointerTo())); + tuple_operand_ptrs.push_back(dest_dim_size_address); + } + + // Emit static tensor and dynamic sizes as one tuple. 
+ llvm_ir::EmitTuple(GetIrArrayFor(hlo), tuple_operand_ptrs, &b_); + return Status::OK(); +} + Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + if (custom_call->custom_call_target() == "PadToStatic") { + return HandlePadToStatic(custom_call); + } + if (custom_call->custom_call_target() == "SliceToDynamic") { + return HandleSliceToDynamic(custom_call); + } absl::Span operands(custom_call->operands()); llvm::Type* i8_ptr_type = b_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 24524c67b11..9b0d11e9f3f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -183,6 +183,8 @@ class IrEmitter : public DfsHloVisitorWithDefault, } private: + Status HandleSliceToDynamic(HloInstruction* hlo); + Status HandlePadToStatic(HloInstruction* hlo); Status HandleAllReduceSingleReplica(HloInstruction* crs); Status HandleAllReduceMultipleReplica(HloInstruction* crs); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index a1872330648..b7a67b4e66e 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -93,6 +94,18 @@ class ShapedBuffer { buffers_.replace_shape_ptr(&on_device_shape_); } + // Reset the shape of this shaped buffer and underlying buffer structure. + // + // Precondition: EqualStructure(this->on_device_shape_, on_device_shape). + void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) { + CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_)) + << "Structures are not the same. new: " << on_device_shape + << ", old: " << on_device_shape_; + on_host_shape_ = on_host_shape; + on_device_shape_ = on_device_shape; + buffers_.replace_shape_ptr(&on_device_shape_); + } + // Returns the underlying ShapeTree containing all the device addresses in the // ShapedBuffer. const ShapeTree& buffers() const { return buffers_; } diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 22ee5a16a30..52cbb8f95ac 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -22,6 +22,7 @@ limitations under the License. 
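Looking back at the HandleCustomCall dispatch added in ir_emitter.cc above: "PadToStatic" and "SliceToDynamic" are ordinary custom calls in the lowered HLO graph. A hedged builder-side sketch of how such calls could be constructed; the shapes, builder name, and function name are illustrative assumptions, and in practice the dynamic padder pass inserts these calls rather than user code.

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaComputation PadToStaticExample() {
  xla::XlaBuilder b("pad_to_static_example");
  xla::Shape dyn = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn.set_dynamic_dimension(0, true);  // bounded-dynamic f32[<=4]
  auto p = xla::Parameter(&b, 0, dyn, "p");

  // "PadToStatic": dynamic input -> (static_tensor, dim_size_0) for R1.
  auto padded = xla::CustomCall(
      &b, "PadToStatic", {p},
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::F32, {4}),
           xla::ShapeUtil::MakeScalarShape(xla::S32)}));
  auto data = xla::GetTupleElement(padded, 0);
  auto size = xla::GetTupleElement(padded, 1);

  // "SliceToDynamic": static data plus one size per dimension -> dynamic R1.
  xla::CustomCall(&b, "SliceToDynamic", {data, size}, dyn);
  return b.Build().ValueOrDie();
}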
#include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/strings/ascii.h" #include "absl/strings/numbers.h" @@ -150,6 +151,19 @@ StatusOr MakeShapeWithLayoutInternal( return equal; } +/* static */ bool ShapeUtil::EqualStructure(const Shape& lhs, + const Shape& rhs) { + bool equal = true; + ForEachSubshape(lhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(rhs, index); + }); + ForEachSubshape(rhs, [&](const Shape& /*subshape*/, const ShapeIndex& index) { + equal &= IndexIsValid(lhs, index); + }); + + return equal; +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -261,6 +275,12 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ValidateShape(*shape); } +/* static */ Shape ShapeUtil::MakeStaticShape(const Shape& original) { + Shape result = original; + result.clear_dynamic_dimensions(); + return result; +} + /* static */ Shape ShapeUtil::MakeTupleShape(absl::Span shapes) { Shape result; result.set_element_type(TUPLE); @@ -626,8 +646,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( if (shape.element_type() == TUPLE) { return ByteSizeOfTupleIndexTable(shape, pointer_size); } else if (shape.IsArray()) { - int64 byte_size = ByteSizeOfElements(shape); - return byte_size; + return ByteSizeOfElements(shape); } else if (shape.element_type() == TOKEN) { return 0; } else if (shape.element_type() == OPAQUE_TYPE) { @@ -1441,6 +1460,19 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( return shape; } +/* static */ bool ShapeUtil::DynamicShapeIsCompatible( + const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { + if (dynamic_shape.rank() != bounded_shape.rank()) { + return false; + } + for (int64 i = 0; i < dynamic_shape.rank(); ++i) { + if (dynamic_shape.dimensions(i) > bounded_shape.dimensions(i)) { + return false; + } + } + return true; +} + /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { CHECK(shape.IsArray()); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 7e05e17865d..dde56587482 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -298,6 +298,16 @@ class ShapeUtil { // As Equal, but allow one of lhs and rhs to be F16 while the other is F32. static bool EqualIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Two shapes have same structure if all subshape indices of lhs are presented + // on rhs and vice versa. + // A nested tuple shape of (F32, (S32[2], F32[2, 2])) is structurally equal to + // (S32, (F32[3], S32[2])) as their structures are both (,(,)) + // + // In contrast, (F32, (F32, F32)) is structurally different from + // ((F32, F32), F32) as the former has structure (,(,)) while the latter has + // ((,),) + static bool EqualStructure(const Shape& lhs, const Shape& rhs); + // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just // fluff. Note that zero dimensions are included in the true rank, e.g., @@ -339,6 +349,9 @@ class ShapeUtil { // element type changed to type. static Shape ChangeElementType(const Shape& original, PrimitiveType type); + // Retursn a shape with same dimensions but with all dimensions set to static. 
+ static Shape MakeStaticShape(const Shape& original); + // Creates a tuple shape from a slice of element shapes within the tuple. static Shape MakeTupleShape(absl::Span shapes); @@ -643,12 +656,16 @@ class ShapeUtil { static Shape FilterDimensions(const std::function& p, Shape shape); - // Iterates through all the shape indexes, in minor to major order, starting - // from the base indexes, incrementing by the incr steps, up to count - // (index[i] < base[i] + count[i]), and calls the visitor_function with the - // current index. - // The visitor_function visitor function should return true if it wants to - // continue, or false otherwise. + // Returns true if `dynamic_shape` has dimensions that are less-equal to the + // "bounded_shape". + static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, + const xla::Shape& bounded_shape); + + // Iterates through all the shape indexes, in minor to major order, + // starting from the base indexes, incrementing by the incr steps, up to + // count (index[i] < base[i] + count[i]), and calls the visitor_function + // with the current index. The visitor_function visitor function should + // return true if it wants to continue, or false otherwise. // // visitor_function must be a callable of type // StatusOr(absl::Span) or compatible. diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD index d71e6e2cc73..494ba29e981 100644 --- a/tensorflow/compiler/xrt/kernels/BUILD +++ b/tensorflow/compiler/xrt/kernels/BUILD @@ -49,6 +49,7 @@ cc_library( deps = [ ":xrt_state_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index d39b37387f2..2fc599e42df 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" @@ -38,7 +39,11 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/monitoring/timed.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/device_memory_allocator.h" +#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/stream_executor.h" #include "tensorflow/stream_executor/stream_executor_internal.h" @@ -146,6 +151,231 @@ xla::StatusOr GetChainedOpInputs( return std::move(input_buffers); } +// Given a shape, returns a byte array representing the shape metadata of the +// shape. The shape metadata contains dimensions sizes stored as contiguous S32. +std::vector PrepareMetadata(const xla::Shape& shape) { + DCHECK(shape.is_static()); + DCHECK(shape.IsArray()); + // Each dimension size is stored as a S32. 
+  std::vector result(shape.dimensions_size());
+  for (int64 i = 0; i < shape.dimensions_size(); ++i) {
+    result[i] = shape.dimensions(i);
+  }
+  return result;
+}
+
+// Given a buffer with dynamic shape, update buffer metadata at the correct
+// offset starting from that buffer.
+//
+// +-----------+
+// |Payload    |
+// +-----------+
+// | Padding   |
+// +-----------+
+// |dim_size_0 | (each dim_size is an S32)
+// +-----------+
+// |dim_size_1 |
+// +-----------+
+//  ..........
+// +-----------+
+//
+// Size of payload = ByteSizeOf(runtime_shape)
+// Size of payload + padding = ByteSizeOf(compile_time_shape_static)
+// Size of payload + padding + metadata = ByteSizeOf(compile_time_shape)
+Status UpdateMetadata(se::Stream* stream, se::DeviceMemory* buffer,
+                      const xla::Shape& compile_time_shape,
+                      const xla::Shape& runtime_shape) {
+  TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
+                                         stream->parent()->platform()));
+  TF_ASSIGN_OR_RETURN(
+      auto transfer_manager,
+      xla::TransferManager::GetForPlatform(stream->parent()->platform()));
+  auto shape_size_fn = compiler->ShapeSizeBytesFunction();
+  xla::Shape compile_time_shape_static =
+      xla::ShapeUtil::MakeStaticShape(compile_time_shape);
+  uint64 offset = shape_size_fn(compile_time_shape_static);
+  uint64 metadata_size = shape_size_fn(compile_time_shape) - offset;
+  auto metadata_buffer =
+      stream->parent()->GetSubBuffer(buffer, offset, metadata_size);
+
+  auto metadata_literal = std::make_shared(
+      xla::LiteralUtil::CreateR1(PrepareMetadata(runtime_shape)));
+  TF_RETURN_IF_ERROR(transfer_manager->TransferArrayToDeviceAsync(
+      stream, *metadata_literal, metadata_buffer));
+  // Retain the literal until the end of the transfer.
+  stream->ThenDoHostCallback([metadata_literal]() { return Status::OK(); });
+  return Status::OK();
+}
+
+// Given a static input buffer, convert it to dynamic form by expanding it to
+// the bounded size and attaching metadata filled with dimension sizes.
+//
+// From:
+// +--------+
+// |Payload |
+// +--------+
+//
+// To:
+//
+// +--------+
+// |Payload |
+// +--------+
+// | Padding|
+// +--------+
+// |Metadata|
+// +--------+
+//
+// As we can't expand the size of an existing memory allocation, a reallocation
+// is required. A list of new allocations is returned by this function. The
+// caller is responsible for maintaining those allocations.
+xla::StatusOr> UpdateDynamicInputs( + se::Stream* stream, se::DeviceMemoryAllocator* allocator, + std::vector runtime_inputs, + const std::vector& compile_time_shapes) { + std::vector new_allocations; + TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size()); + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + for (int64 i = 0; i < compile_time_shapes.size(); i++) { + const xla::Shape& compile_time_shape = compile_time_shapes[i].shape(); + if (compile_time_shape.is_static()) { + continue; + } + auto* runtime_input = runtime_inputs[i]; + + bool element_modified = false; + TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( + compile_time_shape, + [&](const xla::Shape& compile_time_shape, + const xla::ShapeIndex& index) -> Status { + if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) { + return Status::OK(); + } + const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape( + runtime_input->on_device_shape(), index); + TF_RET_CHECK(!runtime_shape.IsTuple()); + TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible( + runtime_shape, compile_time_shape)); + se::DeviceMemoryBase* static_input = + runtime_input->buffers().mutable_element(index); + TF_ASSIGN_OR_RETURN( + auto dynamic_input, + allocator->Allocate(stream->parent()->device_ordinal(), + shape_size_fn(compile_time_shape))); + new_allocations.emplace_back(std::move(dynamic_input)); + se::DeviceMemory* dynamic_input_base = + new_allocations.back().ptr(); + // Send the original data to the new location. + stream->ThenMemcpyD2D(dynamic_input_base, *static_input, + static_input->size()); + TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, + compile_time_shape, runtime_shape)); + // Modify the memory location in the input shape tree to point to the + // new input. + runtime_input->set_buffer(*dynamic_input_base, index); + element_modified = true; + return Status::OK(); + })); + if (element_modified) { + runtime_input->set_shapes(compile_time_shape, compile_time_shape); + // The input location has been modified, need to fix tuple table to + // point to the correct address. + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR( + transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input)); + } + } + return std::move(new_allocations); +} + +xla::StatusOr ReadMetadataLiteral( + se::Stream* stream, se::DeviceMemoryBase* buffer, + const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { + TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( + stream->parent()->platform())); + auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + xla::Shape buffer_shape_static = + xla::ShapeUtil::MakeStaticShape(buffer_shape); + const int64 offset = shape_size_fn(buffer_shape_static); + int64 metadata_size = shape_size_fn(buffer_shape) - offset; + TF_RET_CHECK(metadata_size != 0); + auto buffer_8 = se::DeviceMemory(*buffer); + auto metadata_buffer = + stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + return transfer_manager->TransferArrayFromDevice( + stream, + xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}), + metadata_buffer); +} + +// For each subshape in the result buffer that's dynamic, read the dynamic +// dimension sizes from the metadata, and update output shapes. The result shape +// is a static and concrete shape. 
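Both directions of the conversion lean on the ShapeUtil helpers this patch introduces. A hedged usage sketch (the dimensions are illustrative) before UpdateDynamicOutputs below:

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

void ShapeHelperExamples() {
  xla::Shape bounded = xla::ShapeUtil::MakeShape(xla::F32, {4});
  bounded.set_dynamic_dimension(0, true);                         // f32[<=4]
  xla::Shape runtime = xla::ShapeUtil::MakeShape(xla::F32, {2});  // f32[2]

  // A runtime shape fits a bounded shape iff every dimension is within bound.
  CHECK(xla::ShapeUtil::DynamicShapeIsCompatible(runtime, bounded));

  // MakeStaticShape keeps the bounds but drops the dynamic bits: f32[4].
  CHECK(xla::ShapeUtil::MakeStaticShape(bounded).is_static());

  // EqualStructure compares only tuple nesting, not element types or sizes,
  // which is what ShapedBuffer::set_shapes checks before swapping shapes.
  xla::Shape t0 = xla::ShapeUtil::MakeTupleShape({bounded, runtime});
  xla::Shape t1 = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::S32, {}), bounded});
  CHECK(xla::ShapeUtil::EqualStructure(t0, t1));
}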
+xla::Status UpdateDynamicOutputs(se::Stream* stream, + xla::ShapedBuffer* shaped_buffer, + xla::Shape* output_host_shape, + xla::Shape* output_device_shape) { + DCHECK(output_device_shape->is_dynamic()); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const xla::Shape& buffer_shape = + xla::ShapeUtil::GetSubshape(*output_device_shape, index); + if (buffer_shape.IsTuple()) { + return Status::OK(); + } + xla::Shape& host_shape = + *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index); + xla::Shape& device_shape = + *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index); + if (device_shape.is_static()) { + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(auto metadata, + ReadMetadataLiteral(stream, buffer, buffer_shape, + transfer_manager)); + // Update shape size from metadata. + for (int64 i = 0; i < metadata.element_count(); ++i) { + host_shape.mutable_dimensions()[i] = metadata.Get({i}); + device_shape.mutable_dimensions()[i] = metadata.Get({i}); + } + return Status::OK(); + })); + output_host_shape->clear_dynamic_dimensions(); + output_device_shape->clear_dynamic_dimensions(); + return Status::OK(); +} + +// Create output tuple from run_result. +xla::StatusOr> CreateOutputTuple( + se::Stream* stream, xla::ScopedShapedBuffer run_result, + xla::Backend* backend, int device_ordinal) { + XRTTupleAllocation* output_tuple; + xla::ShapedBuffer shaped_buffer = run_result.release(); + if (shaped_buffer.on_device_shape().is_dynamic()) { + // Update dynamic shapes from output buffer, and create a XRT tensor with + // dimension sizes read from metadata. + xla::Shape output_host_shape = shaped_buffer.on_host_shape(); + xla::Shape output_device_shape = shaped_buffer.on_device_shape(); + TF_RETURN_IF_ERROR(UpdateDynamicOutputs( + stream, &shaped_buffer, &output_host_shape, &output_device_shape)); + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, output_host_shape, output_device_shape, backend, + device_ordinal, &output_tuple)); + } else { + // Fast-path: Don't copy shapes of output buffer. + TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( + shaped_buffer, backend, device_ordinal, &output_tuple)); + } + return RefPtr(output_tuple); +} + xla::StatusOr> RunExecutable( OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, xla::LocalExecutable* executable, const InputBuffers& input_buffers, @@ -191,18 +421,31 @@ xla::StatusOr> RunExecutable( Env* env = Env::Default(); auto start_time = env->NowMicros(); + const std::vector& shape_layouts = + executable->executable() + ->module_config() + .entry_computation_layout() + .parameter_layouts(); + TF_ASSIGN_OR_RETURN(auto new_allocations, + UpdateDynamicInputs(stream, run_options.allocator(), + input_buffers.input_pointers, + shape_layouts)); + auto new_allocations_ptr = + std::make_shared>( + std::move(new_allocations)); TF_ASSIGN_OR_RETURN( xla::ScopedShapedBuffer run_result, executable->Run(input_buffers.input_pointers, run_options)); + // Retain the new allocation for input memory until the end of execution. 
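The comment above and the ThenDoHostCallback call that follows use the same lifetime trick as UpdateMetadata: capturing a shared_ptr in a host callback keeps the allocation alive until the stream has consumed it. A minimal standalone illustration of the idiom; the names and the std::thread stand-in for the stream are assumptions.

#include <memory>
#include <thread>
#include <vector>

int main() {
  auto new_allocations = std::make_shared<std::vector<int>>(1024, 7);
  // The worker owns a reference through the captured shared_ptr, so the
  // buffer cannot be freed before the asynchronous work completes -- the
  // same reason the callback below captures new_allocations_ptr.
  std::thread worker([new_allocations]() {
    volatile int first = (*new_allocations)[0];
    (void)first;
  });
  worker.join();  // In XRT, the stream (not a join) provides this ordering.
  return 0;
}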
+ stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); }); + auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; - auto shaped_buffer = run_result.release(); - XRTTupleAllocation* output_tuple; - TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, device_ref->backend(), device_ref->device_ordinal(), - &output_tuple)); - RefPtr output_tuple_ptr(output_tuple); + TF_ASSIGN_OR_RETURN( + RefPtr output_tuple_ptr, + CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), + device_ref->device_ordinal())); // The ScopedShapedBuffer returned by the executable Run() API, in case of // input/output buffer aliasing, might have holes in it, which need to be @@ -215,7 +458,7 @@ xla::StatusOr> RunExecutable( const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size()); return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? output_tuple->AliasBufferFrom( + ? output_tuple_ptr->AliasBufferFrom( *input_buffers.input_tuples[alias.parameter_number], alias.parameter_index, output_index) : Status::OK(); diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 243289c8821..fbf9dfd0a17 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -49,6 +49,67 @@ limitations under the License. namespace tensorflow { namespace { +xla::XlaComputation ReturnDynamicR1() { + xla::XlaBuilder builder("ReturnDynamicR1"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + return builder.Build(pad_sum).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0"); + auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1"); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +xla::XlaComputation ReturnDynamicR1Tuple() { + xla::XlaBuilder builder("ReturnDynamicR1Tuple"); + auto p0 = xla::Parameter(&builder, 0, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0"); + auto p1 = xla::Parameter(&builder, 1, + xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1"); + auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}), + "P2"); + auto sum = xla::Add(p0, p1); + auto sub = xla::Sub(p0, p1); + auto one = xla::One(&builder, xla::S32); + auto pad_sum = xla::SetDimensionSize(sum, p2, 0); + auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0); + auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub}); + return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie(); +} + +xla::XlaComputation AcceptDynamicR1Tuple() { + xla::XlaBuilder builder("AcceptDynamicR1"); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + xla::Shape tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + xla::Shape nest_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape}); + auto p = 
xla::Parameter(&builder, 0, tuple_shape, "P0"); + auto p0 = xla::GetTupleElement(p, 0); + auto p1 = xla::GetTupleElement(p, 1); + auto sum = xla::Add(p0, p1); + return builder.Build(sum).ValueOrDie(); +} + +template +xla::LiteralProto CreateR0(T v) { + auto array = xla::LiteralUtil::CreateR0(v); + return array.ToProto(); +} + class XrtClientSession : public ClientSession { public: explicit XrtClientSession(const Scope& scope) : ClientSession(scope) { @@ -61,6 +122,11 @@ class XrtClientSession : public ClientSession { string* xla_test_device_ptr; // initial value set in main() string* xla_platform_ptr; // initial value set in main() +bool SupportDynamicShapes() { + // TODO(jackcao): Support dynamic shapes on XLA GPU. + return *xla_test_device_ptr != "XLA_GPU"; +} + string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; return absl::StrCat("/device:", xla_test_device, ":0"); @@ -1035,6 +1101,239 @@ TEST(RawApiTest, CompileAndExecute) { EXPECT_EQ(program_shape.parameters_size(), 2); } +TEST(RawApiTest, DynamicR1Test) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f}); + xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, DynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f}); + 
xrt::XLAAllocation p2; + *p2.mutable_value() = CreateR0(2); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = + xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto(); + *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = + xla::ShapeUtil::MakeTupleShape( + {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape}) + .ToProto(); + StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + auto p2_value = ops::Const(cpu_root, p2.SerializeAsString()); + auto p2_handle = ops::XRTAllocate(root, p2_value); + auto result = ops::XRTExecute( + root, c_handle.handle, e_config, + {Output(p0_handle), Output(p1_handle), Output(p2_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected0 = xla::LiteralUtil::CreateR1({2.0f, 1.0f}); + auto expected1 = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f, 0.0f}); + auto expected2 = xla::LiteralUtil::CreateR1({0.0f, 3.0f, 1.0f}); + auto expected = + xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1TupleTest) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLATupleNode tuple_desc; + auto subdesc_10 = tuple_desc.add_tuples(); + auto subdesc_11 = tuple_desc.add_tuples(); + subdesc_10->set_input_index(0); + subdesc_10->set_release_input_handle(true); + subdesc_11->set_input_index(1); + subdesc_11->set_release_input_handle(true); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + xla::Shape dyn_tuple_shape = + xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape}); + *shapes->add_parameters() = dyn_tuple_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + 
e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto p0_handle = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto p1_handle = ops::XRTAllocate(root, p1_value); + + auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"), + tuple_desc.SerializeAsString()); + auto t0_handle = ops::XRTMakeTuple( + root, tuple_0, + {static_cast(p0_handle), static_cast(p1_handle)}); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {static_cast(t0_handle)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + +TEST(RawApiTest, AcceptDynamicR1Test) { + if (!SupportDynamicShapes()) { + return; + } + xrt::XLAAllocation p0; + *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f}); + xrt::XLAAllocation p1; + *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f}); + + xrt::XLAComputation c; + auto config = c.mutable_config(); + auto shapes = config->mutable_program_shape(); + xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_input_shape.set_dynamic_dimension(0, true); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + *shapes->add_parameters() = dyn_input_shape.ToProto(); + xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4}); + dyn_shape.set_dynamic_dimension(0, true); + *shapes->mutable_result() = dyn_shape.ToProto(); + StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot()); + + xrt::XRTExecutionConfig e; + e.set_release_input_handles(true); + e.set_release_compilation_handle(true); + + Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag()); + Scope cpu_root = root.WithDevice("/device:CPU:0"); + auto e_config = ops::Const(cpu_root, e.SerializeAsString()); + auto computation = ops::Const(cpu_root, c.SerializeAsString()); + auto c_handle = ops::XRTCompile(root, computation); + auto p0_value = ops::Const(cpu_root, p0.SerializeAsString()); + auto allocate_op_0 = ops::XRTAllocate(root, p0_value); + auto p1_value = ops::Const(cpu_root, p1.SerializeAsString()); + auto allocate_op_1 = ops::XRTAllocate(root, p1_value); + auto result = ops::XRTExecute(root, c_handle.handle, e_config, + {Output(allocate_op_0), Output(allocate_op_1)}); + auto read_back = ops::XRTReadLiteralAndRelease(root, result); + TF_ASSERT_OK(root.status()); + + XrtClientSession session(root); + std::vector outputs; + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); + + xla::LiteralProto response; + EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); + + auto expected = xla::LiteralUtil::CreateR1({2.0f, 1.0f, 0.0f}); + EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); +} + TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { xrt::XLAAllocation p0; *p0.mutable_value() 
= FloatVector({1.0f, 2.0f});