Implement the LLVM lowering of the custom calls SliceToDynamic and PadToStatic on XLA:GPU.
PiperOrigin-RevId: 315784895 Change-Id: Ibfda342ea7a0b616cb34c11198cfa38ce1cef6a9
commit 64a5248407 (parent 83af443dc7)
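Note: both custom calls rely on the XLA:GPU convention, visible throughout the hunks below, that a dynamic tensor is stored in a buffer sized for its static bounds with one i32 of metadata per dimension appended after the data. A minimal host-side sketch of that layout for a 2-D f32 tensor follows; the struct and helper names are illustrative only, not part of this change.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative model of the dynamic-tensor buffer convention assumed by
// PadToStatic / SliceToDynamic: data padded out to the static bounds,
// followed by one int32 per dimension holding the real (dynamic) size.
struct DynamicR2BufferF32 {
  std::vector<float> padded_data;  // static_dim0 * static_dim1 elements
  int32_t dyn_dim0_size = 0;       // metadata appended after the data
  int32_t dyn_dim1_size = 0;
};

// Serializes the struct the way the kernels expect it: raw data bytes
// followed immediately by the two i32 dimension sizes.
inline std::vector<uint8_t> Serialize(const DynamicR2BufferF32& t) {
  const size_t data_bytes = t.padded_data.size() * sizeof(float);
  std::vector<uint8_t> out(data_bytes + 2 * sizeof(int32_t));
  std::memcpy(out.data(), t.padded_data.data(), data_bytes);
  std::memcpy(out.data() + data_bytes, &t.dyn_dim0_size, sizeof(int32_t));
  std::memcpy(out.data() + data_bytes + sizeof(int32_t), &t.dyn_dim1_size,
              sizeof(int32_t));
  return out;
}
```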
@@ -1194,6 +1194,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:dot_decomposer",
        "//tensorflow/compiler/xla/service:dump",
        "//tensorflow/compiler/xla/service:dynamic_index_splitter",
        "//tensorflow/compiler/xla/service:dynamic_padder",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:flatten_call_graph",
        "//tensorflow/compiler/xla/service:hlo",
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@@ -157,6 +158,25 @@ Status GpuCompiler::OptimizeHloModule(
    // most ops.
    pipeline.AddPass<HloElementTypeConverter>(BF16, F32);

    // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
    // where possible. Not every batchnorm op can be implemented as a call to
    // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
    if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
      // Since BatchNorm inference is essentially pointwise operations, it is
      // always advantageous to use kernel fusion rather than cudnn.
      pipeline.AddPass<BatchNormExpander>(
          /*rewrite_training_op=*/false,
          /*rewrite_inference_op=*/true,
          /*rewrite_grad_op=*/false);
      pipeline.AddPass<CudnnBatchNormRewriter>();
    }
    pipeline.AddPass<BatchNormExpander>(
        /*rewrite_training_op=*/true,
        /*rewrite_inference_op=*/true,
        /*rewrite_grad_op=*/true);

    pipeline.AddPass<DynamicPadder>();

    {
      auto& pass =
          pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@@ -164,23 +184,6 @@ Status GpuCompiler::OptimizeHloModule(
                                            /*layout_sensitive=*/false,
                                            /*allow_mixed_precision=*/false);

      // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
      // where possible. Not every batchnorm op can be implemented as a call to
      // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
      if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
        // Since BatchNorm inference is essentially pointwise operations, it is
        // always advantageous to use kernel fusion rather than cudnn.
        pass.AddPass<BatchNormExpander>(
            /*rewrite_training_op=*/false,
            /*rewrite_inference_op=*/true,
            /*rewrite_grad_op=*/false);
        pass.AddPass<CudnnBatchNormRewriter>();
      }
      pass.AddPass<BatchNormExpander>(
          /*rewrite_training_op=*/true,
          /*rewrite_inference_op=*/true,
          /*rewrite_grad_op=*/true);

      pipeline.AddPass<HloGetDimensionSizeRewriter>();

      // BatchNormExpander can create zero-sized ops, so zero-sized HLO
@@ -93,9 +93,13 @@ class GpuCompiler : public LLVMCompiler {

  HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
    // Capture just the pointer size, not the entire GpuCompiler object.
    int64 pointer_size = pointer_size_;
    return [pointer_size](const Shape& shape) {
      return ShapeUtil::ByteSizeOf(shape, pointer_size);
    return [pointer_size = pointer_size_](const Shape& shape) {
      if (shape.is_static() || shape.IsTuple()) {
        return ShapeUtil::ByteSizeOf(shape, pointer_size);
      }
      // Each dynamic dimension size is represented as an S32.
      int64 metadata_size = sizeof(int32) * shape.dimensions_size();
      return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
    };
  }
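As a concrete illustration of the size rule added above (data bytes for the static bounds plus one i32 per dimension), here is a hedged stand-alone recomputation; the helper below is illustrative and does not use the XLA API.

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative recomputation of the rule above: a dynamic array shape is
// stored as the full static-bound payload plus one int32 of metadata per
// dimension.
int64_t DynamicArrayByteSize(int64_t element_bytes, int64_t dim0_bound,
                             int64_t dim1_bound) {
  const int64_t data_bytes = element_bytes * dim0_bound * dim1_bound;
  const int64_t metadata_bytes = 2 * sizeof(int32_t);  // one i32 per dimension
  return data_bytes + metadata_bytes;
}

int main() {
  // f32[2,<=5]: 2 * 5 * 4 data bytes + 2 * 4 metadata bytes = 48 bytes.
  std::printf("%lld\n", static_cast<long long>(DynamicArrayByteSize(4, 2, 5)));
  return 0;
}
```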
@@ -124,8 +124,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
    return bindings_.GetIrArray(inst, consumer, shape_index);
  }
  // A convenient helper for calling HloToIrBindings::GetBasePointer.
  llvm::Value* GetBasePointer(const HloInstruction& inst) const {
    return bindings_.GetBasePointer(inst);
  llvm::Value* GetBasePointer(const HloInstruction& inst,
                              ShapeIndexView shape_index = {}) const {
    return bindings_.GetBasePointer(inst, shape_index);
  }

  // Generates the IrArray for each output of an hlo instruction and returns
@@ -371,7 +371,233 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
  return IrEmitter::HandleConvolution(convolution);
}

// Input = {dynamic array(with dynamic dimension meta data at the end)}
// Output = {static array, dynamic_dim0, dynamic_dim1}
Status IrEmitterUnnested::HandlePadToStatic(HloInstruction* pad_to_static) {
  int unroll_factor = 1;
  string ir_name = IrName(pad_to_static);
  auto kernel_thunk = BuildKernelThunk(pad_to_static,
                                       /*implements_whole_instruction=*/true,
                                       /*unroll_factor=*/unroll_factor);
  // pseudo code for padToStatic on a 2d array
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  std::vector<llvm::Value*> dynamic_dims;
  const Shape& data_shape = ShapeUtil::GetSubshape(pad_to_static->shape(), {0});
  const Shape& input_shape = pad_to_static->operand(0)->shape();
  llvm_ir::IrArray data_array = GetIrArray(*pad_to_static, *pad_to_static, {0});
  llvm::Value* source_buffer = GetBasePointer(*pad_to_static->operand(0));
  llvm::Value* raw_buffer =
      b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
  int64 raw_data_size =
      ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));

  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
  for (int64 i = 1; i < pad_to_static->shape().tuple_shapes_size(); ++i) {
    // The dynamic size of each dimension is attached at the end of the source
    // array (operand(0)). We need to extract these values.
    const Shape& dim_shape =
        ShapeUtil::GetSubshape(pad_to_static->shape(), {i});
    TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));

    const int64 dim_index = i - 1;
    llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
        b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
    llvm::Value* dyn_dim_size = b_.CreateLoad(
        b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
        "dyn_dim_size");
    dynamic_dims.push_back(dyn_dim_size);
  }

  // Only one thread needs to store the dynamic dimension sizes:
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }
  KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
    for (int64 i = 1; i < pad_to_static->shape().tuple_shapes_size(); ++i) {
      llvm::Value* dest_dim_size_address = GetBasePointer(*pad_to_static, {i});
      // output[i] stores dynamic_dim_(i-1)
      b_.CreateStore(dynamic_dims[i - 1],
                     b_.CreateBitCast(dest_dim_size_address,
                                      b_.getInt32Ty()->getPointerTo()));
    }
  });

  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
  for (llvm::Value* dynamic_dim : dynamic_dims) {
    dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
                                     /*Name=*/"dyn_element_total");
  }

  //   linear_index = block_id * thread_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
  //     if (linearized_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  llvm_ir::LoopEmitter::BodyEmitter body_generator =
      [&](const llvm_ir::IrArray::Index& array_index) -> Status {
    llvm::Value* linearIndex =
        array_index.Linearize(input_shape.dimensions(), &b_);
    auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
        b_.CreateICmpULT(linearIndex, dyn_element_total),
        llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
    // Set IR builder insertion point to the body of the if structure.
    llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
    llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
                                      absl::MakeSpan(dynamic_dims), &b_);
    data_array.EmitWriteArrayElement(
        dyn_index,
        GetIrArray(*pad_to_static->operand(0), *pad_to_static)
            .EmitReadArrayElement(array_index, &b_, /*name=*/""),
        &b_, /*use_linear_index=*/false);
    return Status::OK();
  };

  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      input_shape, ir_emitter_context_->device_description(), unroll_factor);
  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                         ir_emitter_context_->llvm_module());
  TF_RETURN_IF_ERROR(
      ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
                          unroll_factor)
          .EmitLoop(ir_name,
                    GetIndexTypeForKernel(
                        pad_to_static, launch_dimensions.launch_bound(), &b_)));
  thunk_sequence_->emplace_back(std::move(kernel_thunk));
  return Status::OK();
}
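For readers following the emitted IR, the data movement above corresponds to the following hedged single-threaded host reference of padToStatic for a 2-D f32 array; the names and the assumed buffer layout mirror the pseudocode rather than any code in this change.

```cpp
#include <cstdint>

// Reference semantics of the PadToStatic kernel for a 2-D array: the source
// buffer holds the elements densely (row-major over the dynamic shape) with
// the two i32 dynamic sizes appended; the destination gets the same elements
// re-laid-out with the static strides, plus the sizes as separate outputs.
void PadToStaticReference(const float* source, int32_t static_dim0,
                          int32_t static_dim1, float* dest,
                          int32_t* out_dim0, int32_t* out_dim1) {
  const int64_t static_total = int64_t{static_dim0} * static_dim1;
  // The dynamic sizes live right after the padded data region.
  const int32_t* metadata =
      reinterpret_cast<const int32_t*>(source + static_total);
  const int32_t dyn_dim0 = metadata[0];
  const int32_t dyn_dim1 = metadata[1];
  *out_dim0 = dyn_dim0;
  *out_dim1 = dyn_dim1;

  const int64_t dyn_total = int64_t{dyn_dim0} * dyn_dim1;
  for (int64_t linear = 0; linear < dyn_total; ++linear) {
    // (dyn_i, dyn_j): logical coordinates of the element stored densely at
    // `linear` in the source; static strides place it at a padded offset.
    const int64_t dyn_i = linear / dyn_dim1;
    const int64_t dyn_j = linear % dyn_dim1;
    dest[dyn_i * static_dim1 + dyn_j] = source[linear];
  }
}
```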

// Input = {static array, dynamic_dim0, dynamic_dim1}
// Output = {dynamic array(with dynamic dimension meta data at the end)}
Status IrEmitterUnnested::HandleSliceToDynamic(
    HloInstruction* slice_to_dynamic) {
  int unroll_factor = 1;
  string ir_name = IrName(slice_to_dynamic);
  auto kernel_thunk = BuildKernelThunk(slice_to_dynamic,
                                       /*implements_whole_instruction=*/true,
                                       /*unroll_factor=*/unroll_factor);

  std::vector<llvm::Value*> dynamic_dims;
  const Shape& input_shape = slice_to_dynamic->operand(0)->shape();
  const Shape& data_shape = slice_to_dynamic->shape();
  int32 raw_data_size = ShapeUtil::ByteSizeOf(
      ShapeUtil::MakeStaticShape(slice_to_dynamic->shape()));
  // pseudo code for sliceToDynamic on a 2d array
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  llvm::Value* dest_buffer = GetBasePointer(*slice_to_dynamic);
  llvm::Value* raw_buffer =
      b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
  llvm_ir::IrArray data_array =
      GetIrArray(*slice_to_dynamic, *slice_to_dynamic);

  // Calculate the location where the metadata needs to be inserted:
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
  for (int64 i = 1; i < slice_to_dynamic->operand_count(); ++i) {
    llvm::Value* source_buffer = GetBasePointer(*slice_to_dynamic->operand(i));
    llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
    dynamic_dims.push_back(dyn_dim_size);
  }

  // Only one thread needs to store the dynamic dimension sizes:
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *input[1];
  //     *dyn_dim1_size = *input[2];
  //   }
  KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
    for (int64 i = 1; i < slice_to_dynamic->operand_count(); ++i) {
      const int64 dim_index = i - 1;
      llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
          b_.getInt8Ty(), raw_buffer,
          raw_data_size + dim_index * sizeof(int32));
      // Metadata slot dim_index stores dynamic_dim_(i-1), read from operand(i).
      b_.CreateStore(
          dynamic_dims[dim_index],
          b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
    }
  });

  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
  for (llvm::Value* dynamic_dim : dynamic_dims) {
    dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
                                     /*Name=*/"dyn_element_total");
  }

  //   linear_index = block_id * thread_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
  //     if (linearized_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  llvm_ir::LoopEmitter::BodyEmitter body_generator =
      [&](const llvm_ir::IrArray::Index& array_index) -> Status {
    llvm::Value* linearIndex =
        array_index.Linearize(input_shape.dimensions(), &b_);
    auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
        b_.CreateICmpULT(linearIndex, dyn_element_total),
        llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
    // Set IR builder insertion point to the body of the if structure.
    llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
    llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
                                      absl::MakeSpan(dynamic_dims), &b_);

    data_array.EmitWriteArrayElement(
        array_index,
        GetIrArray(*slice_to_dynamic->operand(0), *slice_to_dynamic)
            .EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
                                  /*use_linear_index=*/false),
        &b_);
    return Status::OK();
  };

  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
      input_shape, ir_emitter_context_->device_description(), unroll_factor);
  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                         ir_emitter_context_->llvm_module());

  TF_RETURN_IF_ERROR(
      ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
                          unroll_factor)
          .EmitLoop(ir_name, GetIndexTypeForKernel(
                                 slice_to_dynamic,
                                 launch_dimensions.launch_bound(), &b_)));
  thunk_sequence_->emplace_back(std::move(kernel_thunk));

  return Status::OK();
}
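Symmetrically, here is a hedged host reference of the sliceToDynamic data movement; again an illustrative sketch under the same assumed buffer layout, not code from this change.

```cpp
#include <cstdint>

// Reference semantics of the SliceToDynamic kernel for a 2-D array: read the
// valid elements from the padded static source, pack them densely into the
// destination, and append the dynamic sizes as i32 metadata.
void SliceToDynamicReference(const float* source, int32_t static_dim0,
                             int32_t static_dim1, int32_t dyn_dim0,
                             int32_t dyn_dim1, float* dest) {
  const int64_t static_total = int64_t{static_dim0} * static_dim1;
  // Write the metadata right after the padded data region of `dest`.
  int32_t* metadata = reinterpret_cast<int32_t*>(dest + static_total);
  metadata[0] = dyn_dim0;
  metadata[1] = dyn_dim1;

  const int64_t dyn_total = int64_t{dyn_dim0} * dyn_dim1;
  for (int64_t linear = 0; linear < dyn_total; ++linear) {
    // The element stored densely at `linear` in the destination comes from
    // the padded (static-stride) position of its logical coordinates.
    const int64_t dyn_i = linear / dyn_dim1;
    const int64_t dyn_j = linear % dyn_dim1;
    dest[linear] = source[dyn_i * static_dim1 + dyn_j];
  }
}
```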

Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
  if (custom_call->custom_call_target() == "PadToStatic") {
    return HandlePadToStatic(custom_call);
  }
  if (custom_call->custom_call_target() == "SliceToDynamic") {
    return HandleSliceToDynamic(custom_call);
  }
  return ThunkEmitter(this).HandleCustomCall(custom_call);
}
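The dispatch above keys on the custom-call target strings "PadToStatic" and "SliceToDynamic", which are normally inserted by the dynamic-shape passes (DynamicPadder is enabled for GPU in the pipeline hunk above). As a hedged illustration only, a graph that reaches HandleSliceToDynamic could be built by hand roughly as follows; the shapes, parameter names, and helper are assumptions, not part of this change.

```cpp
#include <memory>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Hypothetical helper: builds a computation whose root is a "SliceToDynamic"
// custom call producing f32[2,<=5], fed by a static f32[2,5] operand and one
// s32 size per dimension.
std::unique_ptr<HloComputation> BuildSliceToDynamicExample() {
  HloComputation::Builder builder("slice_to_dynamic_example");
  const Shape static_shape = ShapeUtil::MakeShape(F32, {2, 5});
  const Shape dynamic_shape =
      ShapeUtil::MakeShape(F32, {2, 5}, /*dynamic_dimensions=*/{false, true});
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, static_shape, "data"));
  HloInstruction* dim0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}),
                                      "dim0_size"));
  HloInstruction* dim1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, ShapeUtil::MakeShape(S32, {}),
                                      "dim1_size"));
  // The target string is what HandleCustomCall above dispatches on.
  builder.AddInstruction(HloInstruction::CreateCustomCall(
      dynamic_shape, {data, dim0, dim1}, "SliceToDynamic"));
  return builder.Build();
}

}  // namespace xla
```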
@@ -146,6 +146,98 @@ class IrEmitterUnnested : public IrEmitter,
    thunk_sequence_->emplace_back(std::move(thunk));
  }

  // Input = {dynamic array(with dynamic dimension meta data at the end)}
  // Output = {static array, dynamic_dim0, dynamic_dim1}
  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
  // (`_` stands for padding)
  // Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  // Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  //
  // pseudo code for padToStatic on a 2d array
  // ```
  // void padToStatic(int** input, int** output, int thread_per_block,
  //                  int meta_data_offset, int max_num_element,
  //                  int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // extract the dynamic dimensions from the source array's metadata
  //   int* dyn_dim0_size = source_array + meta_data_offset;
  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
  //
  //   // only one thread needs to store the dynamic dimension sizes
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *output[1] = *dyn_dim0_size;
  //     *output[2] = *dyn_dim1_size;
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   linear_index = block_id * thread_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
  //     if (linearized_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
  //           source_array[static_index.dim0][static_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status HandlePadToStatic(HloInstruction* pad_to_static);

  // Input = {static array, dynamic_dim0, dynamic_dim1}
  // Output = {dynamic array(with dynamic dimension meta data at the end)}
  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
  // (`_` stands for padding)
  // Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
  // Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
  //
  // pseudo code for sliceToDynamic on a 2d array
  // ```
  // void sliceToDynamic(int** input, int** output, int thread_per_block,
  //                     int meta_data_offset, int max_num_element,
  //                     int static_dim0_size, int static_dim1_size) {
  //   int* source_array = input[0];
  //   int* dest_array = output[0];
  //
  //   // calculate the location where the metadata needs to be inserted
  //   int* dyn_dim0_size = dest_array + meta_data_offset;
  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
  //
  //   // only one thread needs to store the dynamic dimension sizes
  //   int thread_id = GetThreadId();
  //   int block_id = GetBlockId();
  //   if (thread_id == 0 && block_id == 0) {
  //     *dyn_dim0_size = *input[1];
  //     *dyn_dim1_size = *input[2];
  //   }
  //
  //   int dyn_element_total = 1;
  //   dyn_element_total *= *dyn_dim0_size;
  //   dyn_element_total *= *dyn_dim1_size;
  //   linear_index = block_id * thread_per_block + thread_id;
  //   if (linear_index < max_num_element) {
  //     Index static_index =
  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
  //     if (linearized_index < dyn_element_total) {
  //       Index dyn_index =
  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
  //       dest_array[static_index.dim0][static_index.dim1] =
  //           source_array[dyn_index.dim0][dyn_index.dim1];
  //     }
  //   }
  //   return;
  // }
  // ```
  Status HandleSliceToDynamic(HloInstruction* slice_to_dynamic);

  // A convenient helper for calling BufferAssignment::GetUniqueSlice.
  StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
      const HloInstruction& hlo, const ShapeIndex& index) const override {
@@ -249,6 +249,21 @@ tf_cc_test(
    ],
)

tf_cc_test(
    name = "gpu_dyn_shape_test",
    srcs = ["gpu_dyn_shape_test.cc"],
    tags = tf_cuda_tests_tags(),
    deps = [
        ":gpu_codegen_test",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_module_config",
        "//tensorflow/compiler/xla/service:hlo_parser",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

tf_cc_test(
    name = "gpu_ftz_test",
    srcs = ["gpu_ftz_test.cc"],
@@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"

namespace xla {
namespace gpu {

class GpuDynamicShapeTest : public GpuCodegenTest {};

TEST_F(GpuDynamicShapeTest, DynamicShapeR2) {
  HloComputation::Builder builder(TestName());

  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_input_shape.set_dynamic_dimension(0, true);
  HloInstruction* param_x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, dyn_input_shape, "x"));

  builder.AddInstruction(HloInstruction::CreateUnary(
      dyn_input_shape, HloOpcode::kNegate, param_x));
  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(builder.Build());

  CompileAndVerifyIr(std::move(hlo_module),
                     R"(
; CHECK-LABEL: is_thread_0-true
; CHECK_LABEL: custom-call.in_dyn_bounds-true
; CHECK_LABEL: custom-call.in_bounds-true
; CHECK: %[[dyn_dim_size:.*]] = load i32, i32*
; CHECK: %[[dyn_element_total:.*]] = mul i32 1, %[[dyn_dim_size:.*]]
; CHECK: %[[linear_index:.*]] = add nuw nsw i32
; CHECK: %[[linear_index_in_range:.*]] = icmp ult i32 %[[linear_index:.*]],
; CHECK: store i32 %[[dyn_dim_size:.*]], i32*
)",
                     /*match_optimized_ir=*/false);
}

}  // namespace gpu
}  // namespace xla
@@ -71,6 +71,32 @@ void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
  }
}

void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
                                 llvm::Value* linear, const Shape& shape,
                                 absl::Span<llvm::Value*> dynamic_dims,
                                 llvm::IRBuilder<>* b) const {
  CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
  CHECK_EQ(multidim_.size(), shape.rank());
  llvm::Value* divisor = GetConstantWithIndexType(1);
  const Layout& layout = shape.layout();
  for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
    int64 dimension = layout.minor_to_major(i);

    // If i is not the last dimension, compute
    //   (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
    auto* quot = b->CreateUDiv(linear, divisor, "quot");
    if (i < layout.minor_to_major_size() - 1) {
      (*multidim)[dimension] =
          b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
      divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
    } else {
      (*multidim)[dimension] = quot;
    }
  }
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
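For reference, a hedged host-side model of the new dynamic-dimension Delinearize overload, useful for checking the divisor/quotient/remainder logic; plain C++ with illustrative names, not XLA code.

```cpp
#include <cstdint>
#include <vector>

// Host-side model of the overload above: walk the dimensions in
// minor-to-major order, dividing by the product of the dynamic sizes already
// consumed; the most-major dimension skips the modulo because the linear
// index is assumed to be in bounds.
std::vector<int64_t> DelinearizeDynamic(
    int64_t linear, const std::vector<int64_t>& minor_to_major,
    const std::vector<int64_t>& dynamic_dims) {
  std::vector<int64_t> multidim(dynamic_dims.size());
  int64_t divisor = 1;
  for (size_t i = 0; i < minor_to_major.size(); ++i) {
    const int64_t dimension = minor_to_major[i];
    const int64_t quot = linear / divisor;
    if (i < minor_to_major.size() - 1) {
      multidim[dimension] = quot % dynamic_dims[dimension];
      divisor *= dynamic_dims[dimension];
    } else {
      multidim[dimension] = quot;
    }
  }
  return multidim;
}
```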
@@ -85,6 +111,21 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
  Delinearize(&multidim_, linear, shape, b);
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      absl::Span<llvm::Value*> dynamic_dims,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, dynamic_dims, b);
}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      absl::Span<int64 const> dimensions,
                      llvm::Type* index_type)
@@ -66,6 +66,11 @@ class IrArray {
    // Precondition: "shape" has a layout.
    Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);

    // Similar to the above constructor, except that it uses "dynamic_dims"
    // instead of the shape's static dimensions to construct the index.
    Index(llvm::Value* linear, const Shape& shape,
          absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b);

    // Constructs an index from a multi-dimensional index. 'shape' is the shape
    // for which the multi-dimensional index is used. 'index_type' is the type
    // of the index.

@@ -180,6 +185,11 @@ class IrArray {
    void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
                     const Shape& shape, llvm::IRBuilder<>* b) const;

    // Delinearizes the linear index using the given dynamic dimensions.
    void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
                     const Shape& shape, absl::Span<llvm::Value*> dynamic_dims,
                     llvm::IRBuilder<>* b) const;

    std::vector<llvm::Value*> multidim_;

    // These values are purely for efficiency; `multidim_` is enough to find the
@@ -146,11 +146,6 @@ class XrtClientSession : public ClientSession {
string* xla_test_device_ptr;  // initial value set in main()
string* xla_platform_ptr;     // initial value set in main()

bool SupportDynamicShapes() {
  // TODO(jackcao): Support dynamic shapes on XLA GPU.
  return *xla_test_device_ptr != "XLA_GPU";
}

string DeviceFromFlag() {
  string xla_test_device = *xla_test_device_ptr;
  return absl::StrCat("/device:", xla_test_device, ":0");

@@ -1126,10 +1121,6 @@ TEST(RawApiTest, CompileAndExecute) {
}

TEST(RawApiTest, DynamicR1Test) {
  if (!SupportDynamicShapes()) {
    GTEST_SKIP()
        << "Skipping the test if backend doesn't support dynamic shapes";
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
  xrt::XLAAllocation p1;

@@ -1182,10 +1173,6 @@ TEST(RawApiTest, DynamicR1Test) {
}

TEST(RawApiTest, DynamicR2Test) {
  if (!SupportDynamicShapes()) {
    GTEST_SKIP()
        << "Skipping the test if backend doesn't support dynamic shapes";
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
                                                    {1.5f, 2.5f, 3.0f, -2.0f}})

@@ -1243,10 +1230,6 @@ TEST(RawApiTest, DynamicR2Test) {
}

TEST(RawApiTest, DynamicR1TupleTest) {
  if (!SupportDynamicShapes()) {
    GTEST_SKIP()
        << "Skipping the test if backend doesn't support dynamic shapes";
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
  xrt::XLAAllocation p1;

@@ -1307,10 +1290,6 @@ TEST(RawApiTest, DynamicR1TupleTest) {
}

TEST(RawApiTest, AcceptDynamicR1TupleTest) {
  if (!SupportDynamicShapes()) {
    GTEST_SKIP()
        << "Skipping the test if backend doesn't support dynamic shapes";
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;

@@ -1373,10 +1352,6 @@ TEST(RawApiTest, AcceptDynamicR1TupleTest) {
}

TEST(RawApiTest, AcceptDynamicR1Test) {
  if (!SupportDynamicShapes()) {
    GTEST_SKIP()
        << "Skipping the test if backend doesn't support dynamic shapes";
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;

@@ -1424,13 +1399,9 @@ TEST(RawApiTest, AcceptDynamicR1Test) {
}

TEST(RawApiTest, AcceptDynamicR2Test) {
  if (!SupportDynamicShapes()) {
    GTEST_SKIP()
        << "Skipping the test if backend doesn't support dynamic shapes";
  }
  xrt::XLAAllocation p0;
  *p0.mutable_value() =
      xla::LiteralUtil::CreateR2({{-1.0f, 3.0f, 1.0f}, {-2.0f, -1.0f, 3.0f}})
      xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
          .ToProto();

  xrt::XLAComputation c;

@@ -1468,7 +1439,7 @@ TEST(RawApiTest, AcceptDynamicR2Test) {
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));

  auto expected = xla::LiteralUtil::CreateR2<float>(
      {{1.0f, -3.0f, -1.0f}, {2.0f, 1.0f, -3.0f}});
      {{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}