Implement the LLVM lowering for the custom calls SliceToDynamic and PadToStatic on XLA:GPU.

PiperOrigin-RevId: 315784895
Change-Id: Ibfda342ea7a0b616cb34c11198cfa38ce1cef6a9
A. Unique TensorFlower 2020-06-10 15:47:48 -07:00 committed by TensorFlower Gardener
parent 83af443dc7
commit 64a5248407
11 changed files with 470 additions and 53 deletions


@ -1194,6 +1194,7 @@ cc_library(
"//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",
"//tensorflow/compiler/xla/service:dynamic_padder",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",


@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@ -157,6 +158,25 @@ Status GpuCompiler::OptimizeHloModule(
// most ops.
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
// cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
// Since BatchNorm inference is essentially pointwise operations, it is
// always advantageous to use kernel fusion rather than cudnn.
pipeline.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/false,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/false);
pipeline.AddPass<CudnnBatchNormRewriter>();
}
pipeline.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
pipeline.AddPass<DynamicPadder>();
{
auto& pass =
pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@ -164,23 +184,6 @@ Status GpuCompiler::OptimizeHloModule(
/*layout_sensitive=*/false,
/*allow_mixed_precision=*/false);
// If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
// where possible. Not every batchnorm op can be implemented as a call to
// cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
// Since BatchNorm inference is essentially pointwise operations, it is
// always advantageous to use kernel fusion rather than cudnn.
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/false,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/false);
pass.AddPass<CudnnBatchNormRewriter>();
}
pass.AddPass<BatchNormExpander>(
/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
pipeline.AddPass<HloGetDimensionSizeRewriter>();
// BatchNormExpander can create zero-sized ops, so zero-sized HLO


@ -93,9 +93,13 @@ class GpuCompiler : public LLVMCompiler {
HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
// Capture just the pointer size, not the entire GpuCompiler object.
int64 pointer_size = pointer_size_;
return [pointer_size](const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape, pointer_size);
return [pointer_size = pointer_size_](const Shape& shape) {
if (shape.is_static() || shape.IsTuple()) {
return ShapeUtil::ByteSizeOf(shape, pointer_size);
}
// Each dynamic dimension size is represented as an S32.
int64 metadata_size = sizeof(int32) * shape.dimensions_size();
return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
};
}

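The size rule added above budgets a dynamically shaped (non-tuple) array as its fully padded static payload plus one trailing S32 per dimension. A minimal standalone sketch of that arithmetic, with `DynamicBufferBytes` as a hypothetical stand-in for what `ShapeUtil::ByteSizeOf` plus the metadata term computes for a real `xla::Shape`:

```
#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the size rule: fully padded static payload plus one int32 of
// metadata per dimension. `element_size` and `static_bounds` stand in for
// what ShapeUtil::ByteSizeOf derives from a real xla::Shape.
int64_t DynamicBufferBytes(int64_t element_size,
                           const std::vector<int64_t>& static_bounds) {
  int64_t data_bytes = element_size;
  for (int64_t bound : static_bounds) data_bytes *= bound;
  int64_t metadata_bytes = sizeof(int32_t) * static_bounds.size();
  return data_bytes + metadata_bytes;
}

int main() {
  // f32[2, <=5]: 2 * 5 * 4 = 40 payload bytes + 2 * 4 = 8 metadata bytes.
  std::cout << DynamicBufferBytes(/*element_size=*/4, {2, 5}) << "\n";  // 48
}
```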

@ -124,8 +124,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
return bindings_.GetIrArray(inst, consumer, shape_index);
}
// A convenient helper for calling HloToIrBindings::GetBasePointer.
llvm::Value* GetBasePointer(const HloInstruction& inst) const {
return bindings_.GetBasePointer(inst);
llvm::Value* GetBasePointer(const HloInstruction& inst,
ShapeIndexView shape_index = {}) const {
return bindings_.GetBasePointer(inst, shape_index);
}
// Generates the IrArray for each output of an hlo instruction and returns


@ -371,7 +371,233 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
return IrEmitter::HandleConvolution(convolution);
}
// Input = {dynamic array(with dynamic dimension meta data at the end)}
// Output = {static array, dynamic_dim0, dynamic_dim1}
Status IrEmitterUnnested::HandlePadToStatic(HloInstruction* pad_to_static) {
int unroll_factor = 1;
string ir_name = IrName(pad_to_static);
auto kernel_thunk = BuildKernelThunk(pad_to_static,
/*implements_whole_instruction=*/true,
/*unroll_factor=*/unroll_factor);
// pseudo code for padToStatic on a 2d array
// int* source_array = input[0];
// int* dest_array = output[0];
std::vector<llvm::Value*> dynamic_dims;
const Shape& data_shape = ShapeUtil::GetSubshape(pad_to_static->shape(), {0});
const Shape& input_shape = pad_to_static->operand(0)->shape();
llvm_ir::IrArray data_array = GetIrArray(*pad_to_static, *pad_to_static, {0});
llvm::Value* source_buffer = GetBasePointer(*pad_to_static->operand(0));
llvm::Value* raw_buffer =
b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
int64 raw_data_size =
ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
// int* dyn_dim0_size = source_array + meta_data_offset;
// int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
for (int64 i = 1; i < pad_to_static->shape().tuple_shapes_size(); ++i) {
// The dynamic size of each dimension is attached at the end of the source
// array (operand(0)). We need to extract these values.
const Shape& dim_shape =
ShapeUtil::GetSubshape(pad_to_static->shape(), {i});
TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
const int64 dim_index = i - 1;
llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
llvm::Value* dyn_dim_size = b_.CreateLoad(
b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
"dyn_dim_size");
dynamic_dims.push_back(dyn_dim_size);
}
// Only one thread needs to store the dynamic dimension sizes.
// int thread_id = GetThreadId();
// int block_id = GetBlockId();
// if (thread_id == 0 && block_id == 0) {
// *output[1] = *dyn_dim0_size;
// *output[2] = *dyn_dim1_size;
// }
KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] {
for (int64 i = 1; i < pad_to_static->shape().tuple_shapes_size(); ++i) {
llvm::Value* dest_dim_size_address = GetBasePointer(*pad_to_static, {i});
// output[i] stores dynamic_dim_(i-1)
b_.CreateStore(dynamic_dims[i - 1],
b_.CreateBitCast(dest_dim_size_address,
b_.getInt32Ty()->getPointerTo()));
}
});
// int dyn_element_total = 1;
// dyn_element_total *= *dyn_dim0_size;
// dyn_element_total *= *dyn_dim1_size;
llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
for (llvm::Value* dynamic_dim : dynamic_dims) {
dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
/*Name=*/"dyn_element_total");
}
// linear_index = block_id * thread_per_block + thread_id;
// if (linear_index < max_num_element) {
// Index static_index =
// delinearized(linear_index, static_dim0_size, static_dim1_size);
// if (linear_index < dyn_element_total) {
// Index dyn_index =
// delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
// dest_array[dyn_index.dim0][dyn_index.dim1] =
// source_array[static_index.dim0][static_index.dim1];
// }
// }
llvm_ir::LoopEmitter::BodyEmitter body_generator =
[&](const llvm_ir::IrArray::Index& array_index) -> Status {
llvm::Value* linearIndex =
array_index.Linearize(input_shape.dimensions(), &b_);
auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
b_.CreateICmpULT(linearIndex, dyn_element_total),
llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
// Set IR builder insertion point to the body of the if structure.
llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
absl::MakeSpan(dynamic_dims), &b_);
data_array.EmitWriteArrayElement(
dyn_index,
GetIrArray(*pad_to_static->operand(0), *pad_to_static)
.EmitReadArrayElement(array_index, &b_, /*name=*/""),
&b_, /*use_linear_index=*/false);
return Status::OK();
};
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
input_shape, ir_emitter_context_->device_description(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
unroll_factor)
.EmitLoop(ir_name,
GetIndexTypeForKernel(
pad_to_static, launch_dimensions.launch_bound(), &b_)));
thunk_sequence_->emplace_back(std::move(kernel_thunk));
return Status::OK();
}
// Input = {static array, dynamic_dim0, dynamic_dim1}
// Output = {dynamic array(with dynamic dimension meta data at the end)}
Status IrEmitterUnnested::HandleSliceToDynamic(
HloInstruction* slice_to_dynamic) {
int unroll_factor = 1;
string ir_name = IrName(slice_to_dynamic);
auto kernel_thunk = BuildKernelThunk(slice_to_dynamic,
/*implements_whole_instruction=*/true,
/*unroll_factor=*/unroll_factor);
std::vector<llvm::Value*> dynamic_dims;
const Shape& input_shape = slice_to_dynamic->operand(0)->shape();
const Shape& data_shape = slice_to_dynamic->shape();
int32 raw_data_size = ShapeUtil::ByteSizeOf(
ShapeUtil::MakeStaticShape(slice_to_dynamic->shape()));
// pseudo code for sliceToDynamic on a 2d array
// int* source_array = input[0];
// int* dest_array = output[0];
llvm::Value* dest_buffer = GetBasePointer(*slice_to_dynamic);
llvm::Value* raw_buffer =
b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
llvm_ir::IrArray data_array =
GetIrArray(*slice_to_dynamic, *slice_to_dynamic);
// calculate the location where metadata needs to be inserted
// int* dyn_dim0_size = dest_array + meta_data_offset;
// int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
for (int64 i = 1; i < slice_to_dynamic->operand_count(); ++i) {
// const int64 dim_index = i - 1;
llvm::Value* source_buffer = GetBasePointer(*slice_to_dynamic->operand(i));
llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
dynamic_dims.push_back(dyn_dim_size);
}
// Only one thread needs to store the dynamic dimension sizes.
// int thread_id = GetThreadId();
// int block_id = GetBlockId();
// if (thread_id == 0 && block_id == 0) {
// *dyn_dim0_size = *input[1];
// *dyn_dim1_size = *input[2];
// }
KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] {
for (int64 i = 1; i < slice_to_dynamic->operand_count(); ++i) {
const int64 dim_index = i - 1;
llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
b_.getInt8Ty(), raw_buffer,
raw_data_size + dim_index * sizeof(int32));
// The output's metadata slot (i-1) stores dynamic_dim_(i-1)
b_.CreateStore(
dynamic_dims[dim_index],
b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
}
});
// int dyn_element_total = 1;
// dyn_element_total *= dyn_dim0_size;
// dyn_element_total *= dyn_dim1_size;
llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
for (llvm::Value* dynamic_dim : dynamic_dims) {
dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
/*Name=*/"dyn_element_total");
}
// linear_index = block_id * thread_per_block + thread_id;
// if (linear_index < max_num_element) {
// Index static_index =
// delinearized(linear_index, static_dim0_size, static_dim1_size);
// if (linear_index < dyn_element_total) {
// Index dyn_index =
// delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
// dest_array[static_index.dim0][static_index.dim1] =
// source_array[dyn_index.dim0][dyn_index.dim1];
// }
// }
llvm_ir::LoopEmitter::BodyEmitter body_generator =
[&](const llvm_ir::IrArray::Index& array_index) -> Status {
llvm::Value* linearIndex =
array_index.Linearize(input_shape.dimensions(), &b_);
auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
b_.CreateICmpULT(linearIndex, dyn_element_total),
llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
// Set IR builder insertion point to the body of the if structure.
llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
absl::MakeSpan(dynamic_dims), &b_);
data_array.EmitWriteArrayElement(
array_index,
GetIrArray(*slice_to_dynamic->operand(0), *slice_to_dynamic)
.EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
/*use_linear_index=*/false),
&b_);
return Status::OK();
};
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
input_shape, ir_emitter_context_->device_description(), unroll_factor);
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
ir_emitter_context_->llvm_module());
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
unroll_factor)
.EmitLoop(ir_name, GetIndexTypeForKernel(
slice_to_dynamic,
launch_dimensions.launch_bound(), &b_)));
thunk_sequence_->emplace_back(std::move(kernel_thunk));
return Status::OK();
}
Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
if (custom_call->custom_call_target() == "PadToStatic") {
return HandlePadToStatic(custom_call);
}
if (custom_call->custom_call_target() == "SliceToDynamic") {
return HandleSliceToDynamic(custom_call);
}
return ThunkEmitter(this).HandleCustomCall(custom_call);
}

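Both handlers above address the per-dimension metadata at byte offset `raw_data_size + dim_index * sizeof(int32)` into the dynamic buffer; the emitted IR does that arithmetic with `CreateConstInBoundsGEP1_32`, a bitcast to `i32*`, and a load or store. A hedged host-side model of the same addressing (plain `memcpy` standing in for the emitted IR; names are illustrative only):

```
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Host-side model of the metadata addressing: the size of dimension dim_index
// lives at byte offset raw_data_size + dim_index * sizeof(int32) in the
// dynamic buffer, where raw_data_size is the padded static payload size.
int32_t ReadDynamicDimSize(const char* dynamic_buffer, int64_t raw_data_size,
                           int64_t dim_index) {
  int32_t size = 0;
  std::memcpy(&size,
              dynamic_buffer + raw_data_size + dim_index * sizeof(int32_t),
              sizeof(int32_t));
  return size;
}

int main() {
  // f32[2, <=5] with runtime shape [2, 3]: 40 payload bytes, then {2, 3}.
  std::vector<char> buffer(40 + 2 * sizeof(int32_t), 0);
  const int32_t dims[2] = {2, 3};
  std::memcpy(buffer.data() + 40, dims, sizeof(dims));
  std::cout << ReadDynamicDimSize(buffer.data(), 40, 0) << " "
            << ReadDynamicDimSize(buffer.data(), 40, 1) << "\n";  // 2 3
}
```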

@ -146,6 +146,98 @@ class IrEmitterUnnested : public IrEmitter,
thunk_sequence_->emplace_back(std::move(thunk));
}
// Input = {dynamic array(with dynamic dimension meta data at the end)}
// Output = {static array, dynamic_dim0, dynamic_dim1}
// For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
// (`_` stands for padding)
// Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
// Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
// pseudo code for padToStatic on a 2d array
// ```
// void padToStatic(int** input, int** output, int thread_per_block,
// int meta_data_offset, int max_num_element,
// int static_dim0_size, int static_dim1_size) {
// int* source_array = input[0];
// int* dest_array = output[0];
// // extract the dynamic dimension from the source array's metadata
// int* dyn_dim0_size = source_array + meta_data_offset;
// int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
// // Only one thread needs to store the dynamic dimension sizes.
// int thread_id = GetThreadId();
// int block_id = GetBlockId();
// if (thread_id == 0 && block_id == 0) {
// *output[1] = *dyn_dim0_size;
// *output[2] = *dyn_dim1_size;
// }
// int dyn_element_total = 1;
// dyn_element_total *= *dyn_dim0_size;
// dyn_element_total *= *dyn_dim1_size;
// linear_index = block_id * thread_per_block + thread_id;
// if (linear_index < max_num_element) {
// Index static_index =
// delinearized(linear_index, static_dim0_size, static_dim1_size);
// if (linear_index < dyn_element_total) {
// Index dyn_index =
// delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
// dest_array[dyn_index.dim0][dyn_index.dim1] =
// source_array[static_index.dim0][static_index.dim1];
// }
// }
// return;
// }
// ```
Status HandlePadToStatic(HloInstruction* pad_to_static);
// Input = {static array, dynamic_dim0, dynamic_dim1}
// Output = {dynamic array(with dynamic dimension meta data at the end)}
// For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
// (`_` stands for padding)
// Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
// Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
// pseudo code for sliceToDynamic on a 2d array
// ```
// void sliceToDynamic(int** input, int** output, int thread_per_block,
// int meta_data_offset, int max_num_element,
// int static_dim0_size, int static_dim1_size) {
// int* source_array = input[0];
// int* dest_array = output[0];
// // calculate the location where metadata needs to be inserted
// int* dyn_dim0_size = dest_array + meta_data_offset;
// int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
// // Only one thread needs to store the dynamic dimension sizes.
// int thread_id = GetThreadId();
// int block_id = GetBlockId();
// if (thread_id == 0 && block_id == 0) {
// *dyn_dim0_size = *input[1];
// *dyn_dim1_size = *input[2];
// }
// int dyn_element_total = 1;
// dyn_element_total *= *dyn_dim0_size;
// dyn_element_total *= *dyn_dim1_size;
// linear_index = block_id * thread_per_block + thread_id;
// if (linear_index < max_num_element) {
// Index static_index =
// delinearized(linear_index, static_dim0_size, static_dim1_size);
// if (linear_index < dyn_element_total) {
// Index dyn_index =
// delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
// dest_array[static_index.dim0][static_index.dim1] =
// source_array[dyn_index.dim0][dyn_index.dim1];
// }
// }
// return;
// }
// ```
Status HandleSliceToDynamic(HloInstruction* slice_to_dynamic);
// A convenient helper for calling BufferAssignment::GetUniqueSlice.
StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index) const override {

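The two pseudocode blocks above can be exercised end to end on the host. Below is a hedged, single-threaded C++ reference of the copy step of padToStatic and sliceToDynamic for the 2-D example in the comments (static bounds [2][<=5], runtime shape [2][3]); it works on the compact data only, ignores the metadata plumbing and the thread/block structure of the real kernels, and all names are illustrative:

```
#include <cstdint>
#include <iostream>
#include <vector>

// Map a compact linear index k (laid out by the runtime dims) to the offset of
// the same logical element in the padded layout (laid out by the static
// bounds). Row-major, 2-D, as in the [2][<=5] example above.
int64_t PaddedOffset(int64_t k, int64_t dyn_dim1, int64_t static_dim1) {
  int64_t row = k / dyn_dim1;
  int64_t col = k % dyn_dim1;
  return row * static_dim1 + col;
}

// padToStatic: compact dynamic data (metadata stripped) -> padded static array.
std::vector<int> PadToStatic(const std::vector<int>& dyn, int64_t dyn0,
                             int64_t dyn1, int64_t s0, int64_t s1, int pad) {
  std::vector<int> out(s0 * s1, pad);
  for (int64_t k = 0; k < dyn0 * dyn1; ++k) {
    out[PaddedOffset(k, dyn1, s1)] = dyn[k];
  }
  return out;
}

// sliceToDynamic: padded static array -> compact dynamic data.
std::vector<int> SliceToDynamic(const std::vector<int>& padded, int64_t dyn0,
                                int64_t dyn1, int64_t s1) {
  std::vector<int> out(dyn0 * dyn1);
  for (int64_t k = 0; k < dyn0 * dyn1; ++k) {
    out[k] = padded[PaddedOffset(k, dyn1, s1)];
  }
  return out;
}

int main() {
  // The example from the comments: static [2][<=5], runtime [2][3], `_` == 0.
  std::vector<int> dyn = {1, 2, 3, 4, 5, 6};
  std::vector<int> padded = PadToStatic(dyn, 2, 3, 2, 5, /*pad=*/0);
  for (int v : padded) std::cout << v << ' ';  // 1 2 3 0 0 4 5 6 0 0
  std::cout << '\n';
  for (int v : SliceToDynamic(padded, 2, 3, 5)) std::cout << v << ' ';  // 1..6
  std::cout << '\n';
}
```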

@ -249,6 +249,21 @@ tf_cc_test(
],
)
tf_cc_test(
name = "gpu_dyn_shape_test",
srcs = ["gpu_dyn_shape_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cc_test(
name = "gpu_ftz_test",
srcs = ["gpu_ftz_test.cc"],


@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
namespace xla {
namespace gpu {
class GpuDynamicShapeTest : public GpuCodegenTest {};
TEST_F(GpuDynamicShapeTest, DynamicShapeR2) {
HloComputation::Builder builder(TestName());
xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
dyn_input_shape.set_dynamic_dimension(0, true);
HloInstruction* param_x = builder.AddInstruction(
HloInstruction::CreateParameter(0, dyn_input_shape, "x"));
builder.AddInstruction(HloInstruction::CreateUnary(
dyn_input_shape, HloOpcode::kNegate, param_x));
auto hlo_module = CreateNewVerifiedModule();
hlo_module->AddEntryComputation(builder.Build());
CompileAndVerifyIr(std::move(hlo_module),
R"(
; CHECK-LABEL: is_thred_0-true
; CHECK_LABEL: custom-call.in_dyn_bounds-true
; CHECK_LABEL: custom-call.in_bounds-true
; CHECK: %[[dyn_dim_size:.*]] = load i32, i32*
; CHECK: %[[dyn_element_total:.*]] = mul i32 1, %[[dyn_dim_size:.*]]
; CHECK: %[[linear_index:.*]] = add nuw nsw i32
; CHECK: %[[linear_index_in_range:.*]] = icmp ult i32 %[[linear_index:.*]],
; CHECK: store i32 %[[dyn_dim_size:.*]], i32*
)",
/*match_optimized_ir=*/false);
}
} // namespace gpu
} // namespace xla


@ -71,6 +71,32 @@ void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
}
}
void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
llvm::Value* linear, const Shape& shape,
absl::Span<llvm::Value*> dynamic_dims,
llvm::IRBuilder<>* b) const {
CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
CHECK_EQ(multidim_.size(), shape.rank());
llvm::Value* divisor = GetConstantWithIndexType(1);
const Layout& layout = shape.layout();
for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
int64 dimension = layout.minor_to_major(i);
// If i is not the last dimension, compute
// (linear_index / divisor) % current_dimension.
// If i is the last dimension, we can skip the mod, because we assume that
// linear is in bounds.
auto* quot = b->CreateUDiv(linear, divisor, "quot");
if (i < layout.minor_to_major_size() - 1) {
(*multidim)[dimension] =
b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
} else {
(*multidim)[dimension] = quot;
}
}
}
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
llvm::IRBuilder<>* b)
: multidim_(shape.rank()),
@ -85,6 +111,21 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
Delinearize(&multidim_, linear, shape, b);
}
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
absl::Span<llvm::Value*> dynamic_dims,
llvm::IRBuilder<>* b)
: multidim_(shape.rank()),
linear_(linear),
layout_(shape.layout()),
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
CHECK_NE(linear, nullptr);
index_type_ = linear->getType();
CHECK(LayoutUtil::HasLayout(shape))
<< "Shape " << ShapeUtil::HumanStringWithLayout(shape)
<< " should have a layout.";
Delinearize(&multidim_, linear, shape, dynamic_dims, b);
}
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
absl::Span<int64 const> dimensions,
llvm::Type* index_type)

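The dynamic-dimension Delinearize overload above walks the layout minor-to-major, dividing the linear index by a running product of the runtime sizes and taking a remainder for every dimension except the most-major one. A standalone sketch of the same arithmetic on host integers (the emitted IR performs the equivalent udiv/urem/mul chain; function and variable names here are illustrative):

```
#include <cstdint>
#include <iostream>
#include <vector>

// Delinearize `linear` into per-dimension indices using runtime dimension
// sizes, visiting dimensions in minor-to-major order. The most-major
// dimension skips the remainder, assuming `linear` is in bounds.
std::vector<int64_t> Delinearize(int64_t linear,
                                 const std::vector<int64_t>& minor_to_major,
                                 const std::vector<int64_t>& dynamic_dims) {
  std::vector<int64_t> multidim(dynamic_dims.size());
  int64_t divisor = 1;
  for (size_t i = 0; i < minor_to_major.size(); ++i) {
    int64_t dimension = minor_to_major[i];
    int64_t quot = linear / divisor;
    if (i + 1 < minor_to_major.size()) {
      multidim[dimension] = quot % dynamic_dims[dimension];
      divisor *= dynamic_dims[dimension];
    } else {
      multidim[dimension] = quot;
    }
  }
  return multidim;
}

int main() {
  // Runtime shape [2, 3], row-major layout (minor_to_major = {1, 0}):
  // linear index 4 maps to multi-index (1, 1).
  for (int64_t d : Delinearize(4, {1, 0}, {2, 3})) std::cout << d << ' ';
  std::cout << '\n';
}
```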

@ -66,6 +66,11 @@ class IrArray {
// Precondition: "shape" has a layout.
Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
// Similar to the above constructor, except it uses "dynamic_dims" instead of
// the shape's static dimensions to construct the index.
Index(llvm::Value* linear, const Shape& shape,
absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b);
// Constructs an index from a multi-dimensional index. 'shape' is the shape
// for which the multi-dimensional index is used. 'index_type' is the type
// of the index.
@ -180,6 +185,11 @@ class IrArray {
void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
const Shape& shape, llvm::IRBuilder<>* b) const;
// Delinearize the linear index with the dynamic dimensions.
void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
const Shape& shape, absl::Span<llvm::Value*> dynamic_dims,
llvm::IRBuilder<>* b) const;
std::vector<llvm::Value*> multidim_;
// These values are purely for efficiency; `multidim_` is enough to find the


@ -146,11 +146,6 @@ class XrtClientSession : public ClientSession {
string* xla_test_device_ptr; // initial value set in main()
string* xla_platform_ptr; // initial value set in main()
bool SupportDynamicShapes() {
// TODO(jackcao): Support dynamic shapes on XLA GPU.
return *xla_test_device_ptr != "XLA_GPU";
}
string DeviceFromFlag() {
string xla_test_device = *xla_test_device_ptr;
return absl::StrCat("/device:", xla_test_device, ":0");
@ -1126,10 +1121,6 @@ TEST(RawApiTest, CompileAndExecute) {
}
TEST(RawApiTest, DynamicR1Test) {
if (!SupportDynamicShapes()) {
GTEST_SKIP()
<< "Skipping the test if backend doesn't support dynamic shapes";
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
xrt::XLAAllocation p1;
@ -1182,10 +1173,6 @@ TEST(RawApiTest, DynamicR1Test) {
}
TEST(RawApiTest, DynamicR2Test) {
if (!SupportDynamicShapes()) {
GTEST_SKIP()
<< "Skipping the test if backend doesn't support dynamic shapes";
}
xrt::XLAAllocation p0;
*p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
{1.5f, 2.5f, 3.0f, -2.0f}})
@ -1243,10 +1230,6 @@ TEST(RawApiTest, DynamicR2Test) {
}
TEST(RawApiTest, DynamicR1TupleTest) {
if (!SupportDynamicShapes()) {
GTEST_SKIP()
<< "Skipping the test if backend doesn't support dynamic shapes";
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
xrt::XLAAllocation p1;
@ -1307,10 +1290,6 @@ TEST(RawApiTest, DynamicR1TupleTest) {
}
TEST(RawApiTest, AcceptDynamicR1TupleTest) {
if (!SupportDynamicShapes()) {
GTEST_SKIP()
<< "Skipping the test if backend doesn't support dynamic shapes";
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
xrt::XLAAllocation p1;
@ -1373,10 +1352,6 @@ TEST(RawApiTest, AcceptDynamicR1TupleTest) {
}
TEST(RawApiTest, AcceptDynamicR1Test) {
if (!SupportDynamicShapes()) {
GTEST_SKIP()
<< "Skipping the test if backend doesn't support dynamic shapes";
}
xrt::XLAAllocation p0;
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
xrt::XLAAllocation p1;
@ -1424,13 +1399,9 @@ TEST(RawApiTest, AcceptDynamicR1Test) {
}
TEST(RawApiTest, AcceptDynamicR2Test) {
if (!SupportDynamicShapes()) {
GTEST_SKIP()
<< "Skipping the test if backend doesn't support dynamic shapes";
}
xrt::XLAAllocation p0;
*p0.mutable_value() =
xla::LiteralUtil::CreateR2({{-1.0f, 3.0f, 1.0f}, {-2.0f, -1.0f, 3.0f}})
xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
.ToProto();
xrt::XLAComputation c;
@ -1468,7 +1439,7 @@ TEST(RawApiTest, AcceptDynamicR2Test) {
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
auto expected = xla::LiteralUtil::CreateR2<float>(
{{1.0f, -3.0f, -1.0f}, {2.0f, 1.0f, -3.0f}});
{{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}