Implement the LLVM lowering for the custom calls SliceToDynamic and PadToStatic on XLA:GPU.
PiperOrigin-RevId: 315784895
Change-Id: Ibfda342ea7a0b616cb34c11198cfa38ce1cef6a9
Parent: 83af443dc7
Commit: 64a5248407
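For orientation before the diff: both custom calls operate on a single buffer that carries the dynamic dimension sizes as trailing i32 metadata after the statically padded data. Below is a minimal host-side sketch of that layout, based on the `[2][<=5]` example in the new header comments; the function name and the int32 element type are illustrative only, not part of the change.

```
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch: buffer for a [2][<=5] tensor whose dynamic value is [2][3].
// The packed payload {1,2,3,4,5,6} comes first, then padding, then one
// int32 of metadata per dimension ({2, 3}) after the padded static data.
std::vector<int32_t> MakeDynamicR2Buffer() {
  const int64_t static_elements = 2 * 5;  // padded extent of the static shape
  const int64_t rank = 2;
  std::vector<int32_t> buffer(static_elements + rank, 0);
  const int32_t payload[] = {1, 2, 3, 4, 5, 6};
  std::memcpy(buffer.data(), payload, sizeof(payload));
  buffer[static_elements + 0] = 2;  // dynamic size of dimension 0
  buffer[static_elements + 1] = 3;  // dynamic size of dimension 1
  return buffer;
}
```

PadToStatic consumes such a buffer and produces {static array, dyn_dim0, dyn_dim1}; SliceToDynamic builds one from a static array plus the dimension sizes. This is also why the updated ShapeSizeBytesFunction below adds sizeof(int32) per dimension for dynamic shapes.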
@@ -1194,6 +1194,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:dot_decomposer",
         "//tensorflow/compiler/xla/service:dump",
         "//tensorflow/compiler/xla/service:dynamic_index_splitter",
+        "//tensorflow/compiler/xla/service:dynamic_padder",
         "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
         "//tensorflow/compiler/xla/service:hlo",
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
+#include "tensorflow/compiler/xla/service/dynamic_padder.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/gpu/alias_passthrough_params.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
@@ -157,6 +158,25 @@ Status GpuCompiler::OptimizeHloModule(
     // most ops.
     pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
 
+    // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
+    // where possible. Not every batchnorm op can be implemented as a call to
+    // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
+    if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
+      // Since BatchNorm inference is essentially pointwise operations, it is
+      // always advantageous to use kernel fusion rather than cudnn.
+      pipeline.AddPass<BatchNormExpander>(
+          /*rewrite_training_op=*/false,
+          /*rewrite_inference_op=*/true,
+          /*rewrite_grad_op=*/false);
+      pipeline.AddPass<CudnnBatchNormRewriter>();
+    }
+    pipeline.AddPass<BatchNormExpander>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_inference_op=*/true,
+        /*rewrite_grad_op=*/true);
+
+    pipeline.AddPass<DynamicPadder>();
+
     {
       auto& pass =
           pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
@@ -164,23 +184,6 @@ Status GpuCompiler::OptimizeHloModule(
           /*layout_sensitive=*/false,
           /*allow_mixed_precision=*/false);
 
-      // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls
-      // where possible. Not every batchnorm op can be implemented as a call to
-      // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs.
-      if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) {
-        // Since BatchNorm inference is essentially pointwise operations, it is
-        // always advantageous to use kernel fusion rather than cudnn.
-        pass.AddPass<BatchNormExpander>(
-            /*rewrite_training_op=*/false,
-            /*rewrite_inference_op=*/true,
-            /*rewrite_grad_op=*/false);
-        pass.AddPass<CudnnBatchNormRewriter>();
-      }
-      pass.AddPass<BatchNormExpander>(
-          /*rewrite_training_op=*/true,
-          /*rewrite_inference_op=*/true,
-          /*rewrite_grad_op=*/true);
-
       pipeline.AddPass<HloGetDimensionSizeRewriter>();
 
       // BatchNormExpander can create zero-sized ops, so zero-sized HLO
@@ -93,9 +93,13 @@ class GpuCompiler : public LLVMCompiler {
 
   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override {
     // Capture just the pointer size, not the entire GpuCompiler object.
-    int64 pointer_size = pointer_size_;
-    return [pointer_size](const Shape& shape) {
+    return [pointer_size = pointer_size_](const Shape& shape) {
+      if (shape.is_static() || shape.IsTuple()) {
         return ShapeUtil::ByteSizeOf(shape, pointer_size);
+      }
+      // Each dynamic dimension size is represented as a S32.
+      int64 metadata_size = sizeof(int32) * shape.dimensions_size();
+      return ShapeUtil::ByteSizeOf(shape, pointer_size) + metadata_size;
     };
   }
 
@@ -124,8 +124,9 @@ class IrEmitter : public DfsHloVisitorWithDefault,
     return bindings_.GetIrArray(inst, consumer, shape_index);
   }
   // A convenient helper for calling HloToIrBindings::GetBasePointer.
-  llvm::Value* GetBasePointer(const HloInstruction& inst) const {
-    return bindings_.GetBasePointer(inst);
+  llvm::Value* GetBasePointer(const HloInstruction& inst,
+                              ShapeIndexView shape_index = {}) const {
+    return bindings_.GetBasePointer(inst, shape_index);
   }
 
   // Generates the IrArray for each output of an hlo instruction and returns
@@ -371,7 +371,233 @@ Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
   return IrEmitter::HandleConvolution(convolution);
 }
 
+// Input = {dynamic array(with dynamic dimension meta data at the end)}
+// Output = {static array, dynamic_dim0, dynamic_dim1}
+Status IrEmitterUnnested::HandlePadToStatic(HloInstruction* pad_to_static) {
+  int unroll_factor = 1;
+  string ir_name = IrName(pad_to_static);
+  auto kernel_thunk = BuildKernelThunk(pad_to_static,
+                                       /*implements_whole_instruction=*/true,
+                                       /*unroll_factor=*/unroll_factor);
+  // pseudo code for padToStatic on a 2d array
+  //   int* source_array = input[0];
+  //   int* dest_array = output[0];
+  std::vector<llvm::Value*> dynamic_dims;
+  const Shape& data_shape = ShapeUtil::GetSubshape(pad_to_static->shape(), {0});
+  const Shape& input_shape = pad_to_static->operand(0)->shape();
+  llvm_ir::IrArray data_array = GetIrArray(*pad_to_static, *pad_to_static, {0});
+  llvm::Value* source_buffer = GetBasePointer(*pad_to_static->operand(0));
+  llvm::Value* raw_buffer =
+      b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
+  int64 raw_data_size =
+      ShapeUtil::ByteSizeOf(ShapeUtil::MakeStaticShape(input_shape));
+
+  //   int* dyn_dim0_size = source_array + meta_data_offset;
+  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
+  for (int64 i = 1; i < pad_to_static->shape().tuple_shapes_size(); ++i) {
+    // Dynamic size of each dimension is attached at the end of the source
+    // array(operand(0)). We need to extract these values.
+    const Shape& dim_shape =
+        ShapeUtil::GetSubshape(pad_to_static->shape(), {i});
+    TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
+
+    const int64 dim_index = i - 1;
+    llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
+        b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
+    llvm::Value* dyn_dim_size = b_.CreateLoad(
+        b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
+        "dyn_dim_size");
+    dynamic_dims.push_back(dyn_dim_size);
+  }
+
+  // only one thread needs to store the dynamic index
+  //   int thread_id = GetThreadId();
+  //   int block_id = GetBlockId();
+  //   if (thread_id == 0 && block_id == 0) {
+  //     *output[1] = *dyn_dim0_size;
+  //     *output[2] = *dyn_dim1_size;
+  //   }
+  KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] {
+    for (int64 i = 1; i < pad_to_static->shape().tuple_shapes_size(); ++i) {
+      llvm::Value* dest_dim_size_address = GetBasePointer(*pad_to_static, {i});
+      // output[i] stores dynamic_dim_(i-1)
+      b_.CreateStore(dynamic_dims[i - 1],
+                     b_.CreateBitCast(dest_dim_size_address,
+                                      b_.getInt32Ty()->getPointerTo()));
+    }
+  });
+
+  //   int dyn_element_total = 1;
+  //   dyn_element_total *= *dyn_dim0_size;
+  //   dyn_element_total *= *dyn_dim1_size;
+  llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
+  for (llvm::Value* dynamic_dim : dynamic_dims) {
+    dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
+                                     /*Name=*/"dyn_element_total");
+  }
+
+  //   linear_index = block_id * thread_per_block + thread_id;
+  //   if (linear_index < max_num_element) {
+  //     Index static_index =
+  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
+  //     if (linearized_index < dyn_element_total) {
+  //       Index dyn_index =
+  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
+  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
+  //           source_array[static_index.dim0][static_index.dim1];
+  //     }
+  //   }
+  llvm_ir::LoopEmitter::BodyEmitter body_generator =
+      [&](const llvm_ir::IrArray::Index& array_index) -> Status {
+    llvm::Value* linearIndex =
+        array_index.Linearize(input_shape.dimensions(), &b_);
+    auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
+        b_.CreateICmpULT(linearIndex, dyn_element_total),
+        llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
+    // Set IR builder insertion point to the body of the if structure.
+    llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
+    llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
+                                      absl::MakeSpan(dynamic_dims), &b_);
+    data_array.EmitWriteArrayElement(
+        dyn_index,
+        GetIrArray(*pad_to_static->operand(0), *pad_to_static)
+            .EmitReadArrayElement(array_index, &b_, /*name=*/""),
+        &b_, /*use_linear_index=*/false);
+    return Status::OK();
+  };
+
+  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+      input_shape, ir_emitter_context_->device_description(), unroll_factor);
+  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
+                         ir_emitter_context_->llvm_module());
+  TF_RETURN_IF_ERROR(
+      ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
+                          unroll_factor)
+          .EmitLoop(ir_name,
+                    GetIndexTypeForKernel(
+                        pad_to_static, launch_dimensions.launch_bound(), &b_)));
+  thunk_sequence_->emplace_back(std::move(kernel_thunk));
+  return Status::OK();
+}
+
+// Input = {static array, dynamic_dim0, dynamic_dim1}
+// Output = {dynamic array(with dynamic dimension meta data at the end)}
+Status IrEmitterUnnested::HandleSliceToDynamic(
+    HloInstruction* slice_to_dynamic) {
+  int unroll_factor = 1;
+  string ir_name = IrName(slice_to_dynamic);
+  auto kernel_thunk = BuildKernelThunk(slice_to_dynamic,
+                                       /*implements_whole_instruction=*/true,
+                                       /*unroll_factor=*/unroll_factor);
+
+  std::vector<llvm::Value*> dynamic_dims;
+  const Shape& input_shape = slice_to_dynamic->operand(0)->shape();
+  const Shape& data_shape = slice_to_dynamic->shape();
+  int32 raw_data_size = ShapeUtil::ByteSizeOf(
+      ShapeUtil::MakeStaticShape(slice_to_dynamic->shape()));
+  // pseudo code for sliceToDynamic on a 2d array
+  //   int* source_array = input[0];
+  //   int* dest_array = output[0];
+  llvm::Value* dest_buffer = GetBasePointer(*slice_to_dynamic);
+  llvm::Value* raw_buffer =
+      b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
+  llvm_ir::IrArray data_array =
+      GetIrArray(*slice_to_dynamic, *slice_to_dynamic);
+
+  // calculate the location where metadata needs to be inserted
+  //   int* dyn_dim0_size = dest_array + meta_data_offset;
+  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
+  for (int64 i = 1; i < slice_to_dynamic->operand_count(); ++i) {
+    // const int64 dim_index = i - 1;
+    llvm::Value* source_buffer = GetBasePointer(*slice_to_dynamic->operand(i));
+    llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
+    dynamic_dims.push_back(dyn_dim_size);
+  }
+
+  // only one thread needs to store the dynamic index
+  //   int thread_id = GetThreadId();
+  //   int block_id = GetBlockId();
+  //   if (thread_id == 0 && block_id == 0) {
+  //     *dyn_dim0_size = *output[1];
+  //     *dyn_dim1_size = *output[2];
+  //   }
+  KernelSupportLibrary{&b_}.If("is_thred_0", IsBlock0Thread0(&b_), [&] {
+    for (int64 i = 1; i < slice_to_dynamic->operand_count(); ++i) {
+      const int64 dim_index = i - 1;
+      llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
+          b_.getInt8Ty(), raw_buffer,
+          raw_data_size + dim_index * sizeof(int32));
+      // output[i] stores dynamic_dim_(i-1)
+      b_.CreateStore(
+          dynamic_dims[dim_index],
+          b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
+    }
+  });
+
+  //   int dyn_element_total = 1;
+  //   dyn_element_total *= dyn_dim0_size;
+  //   dyn_element_total *= dyn_dim1_size;
+  llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
+  for (llvm::Value* dynamic_dim : dynamic_dims) {
+    dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
+                                     /*Name=*/"dyn_element_total");
+  }
+
+  //   linear_index = block_id * thread_per_block + thread_id;
+  //   if (linear_index < max_num_element) {
+  //     Index static_index =
+  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
+  //     if (linearized_index < dyn_element_total) {
+  //       Index dyn_index =
+  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
+  //       dest_array[static_index.dim0][static_index.dim1] =
+  //           source_array[dyn_index.dim0][dyn_index.dim1];
+  //     }
+  //   }
+  llvm_ir::LoopEmitter::BodyEmitter body_generator =
+      [&](const llvm_ir::IrArray::Index& array_index) -> Status {
+    llvm::Value* linearIndex =
+        array_index.Linearize(input_shape.dimensions(), &b_);
+    auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
+        b_.CreateICmpULT(linearIndex, dyn_element_total),
+        llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
+    // Set IR builder insertion point to the body of the if structure.
+    llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
+    llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
+                                      absl::MakeSpan(dynamic_dims), &b_);
+
+    data_array.EmitWriteArrayElement(
+        array_index,
+        GetIrArray(*slice_to_dynamic->operand(0), *slice_to_dynamic)
+            .EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
+                                  /*use_linear_index=*/false),
+        &b_);
+    return Status::OK();
+  };
+
+  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+      input_shape, ir_emitter_context_->device_description(), unroll_factor);
+  UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
+                         ir_emitter_context_->llvm_module());
+
+  TF_RETURN_IF_ERROR(
+      ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
+                          unroll_factor)
+          .EmitLoop(ir_name, GetIndexTypeForKernel(
+                                 slice_to_dynamic,
+                                 launch_dimensions.launch_bound(), &b_)));
+  thunk_sequence_->emplace_back(std::move(kernel_thunk));
+
+  return Status::OK();
+}
+
 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
+  if (custom_call->custom_call_target() == "PadToStatic") {
+    return HandlePadToStatic(custom_call);
+  }
+  if (custom_call->custom_call_target() == "SliceToDynamic") {
+    return HandleSliceToDynamic(custom_call);
+  }
   return ThunkEmitter(this).HandleCustomCall(custom_call);
 }
 
@@ -146,6 +146,98 @@ class IrEmitterUnnested : public IrEmitter,
     thunk_sequence_->emplace_back(std::move(thunk));
   }
 
+  // Input = {dynamic array(with dynamic dimension meta data at the end)}
+  // Output = {static array, dynamic_dim0, dynamic_dim1}
+  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
+  // (`_` stands for padding)
+  // Input = {{1,2,3,4,5,6,_,_,_,_,2,3}}
+  // Output = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
+  //
+  // pseudo code for padToStatic on a 2d array
+  // ```
+  // void padToStatic(int** input, int** output, int thread_per_block,
+  //                  int meta_data_offset, int max_num_element,
+  //                  int static_dim0_size, int static_dim1_size) {
+  //   int* source_array = input[0];
+  //   int* dest_array = output[0];
+  //
+  //   // extract the dynamic dimension from the source array's metadata
+  //   int* dyn_dim0_size = source_array + meta_data_offset;
+  //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
+  //
+  //   // only one thread needs to store the dynamic index
+  //   int thread_id = GetThreadId();
+  //   int block_id = GetBlockId();
+  //   if (thread_id == 0 && block_id == 0) {
+  //     *output[1] = *dyn_dim0_size;
+  //     *output[2] = *dyn_dim1_size;
+  //   }
+  //
+  //   int dyn_element_total = 1;
+  //   dyn_element_total *= *dyn_dim0_size;
+  //   dyn_element_total *= *dyn_dim1_size;
+  //   linear_index = block_id * thread_per_block + thread_id;
+  //   if (linear_index < max_num_element) {
+  //     Index static_index =
+  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
+  //     if (linearized_index < dyn_element_total) {
+  //       Index dyn_index =
+  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
+  //       dest_array[dyn_index.dim0][dyn_index.dim1] =
+  //           source_array[static_index.dim0][static_index.dim1];
+  //     }
+  //   }
+  //   return;
+  // }
+  // ```
+  Status HandlePadToStatic(HloInstruction* pad_to_static);
+
+  // Input = {static array, dynamic_dim0, dynamic_dim1}
+  // Output = {dynamic array(with dynamic dimension meta data at the end)}
+  // For a tensor with static dimension [2][<=5] and dynamic dimension [2][3]
+  // (`_` stands for padding)
+  // Input = {{1,2,3,_,_,4,5,6,_,_}, 2, 3}
+  // Output = {{1,2,3,4,5,6,_,_,_,_,2,3}}
+  //
+  // pseudo code for sliceToDynamic on a 2d array
+  // ```
+  // void sliceToDynamic(int** input, int** output, int thread_per_block,
+  //                     int meta_data_offset, int max_num_element,
+  //                     int static_dim0_size, int static_dim1_size) {
+  //   int* source_array = input[0];
+  //   int* dest_array = output[0];
+  //
+  //   // calculate the location where metadata needs to be inserted
+  //   int* dyn_dim0_size = dest_array + meta_data_offset;
+  //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
+  //
+  //   // only one thread needs to store the dynamic index
+  //   int thread_id = GetThreadId();
+  //   int block_id = GetBlockId();
+  //   if (thread_id == 0 && block_id == 0) {
+  //     *dyn_dim0_size = *output[1];
+  //     *dyn_dim1_size = *output[2];
+  //   }
+  //
+  //   int dyn_element_total = 1;
+  //   dyn_element_total *= *dyn_dim0_size;
+  //   dyn_element_total *= *dyn_dim1_size;
+  //   linear_index = block_id * thread_per_block + thread_id;
+  //   if (linear_index < max_num_element) {
+  //     Index static_index =
+  //         delinearized(linearized_index, static_dim0_size, static_dim1_size);
+  //     if (linearized_index < dyn_element_total) {
+  //       Index dyn_index =
+  //           delinearized(linearized_index, *dyn_dim0_size, *dyn_dim1_size);
+  //       dest_array[static_index.dim0][static_index.dim1] =
+  //           source_array[dyn_index.dim0][dyn_index.dim1];
+  //     }
+  //   }
+  //   return;
+  // }
+  // ```
+  Status HandleSliceToDynamic(HloInstruction* slice_to_dynamic);
+
   // A convenient helper for calling BufferAssignment::GetUniqueSlice.
   StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
       const HloInstruction& hlo, const ShapeIndex& index) const override {
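To make the sliceToDynamic pseudo code above concrete, here is a host-side analogue of the same index remapping for a row-major 2-D array; it reproduces the `[2][<=5]` example ({1,2,3,_,_,4,5,6,_,_} with sizes 2 and 3 becomes {1,2,3,4,5,6,_,_,_,_,2,3}) and is only a sketch, not the emitted GPU kernel. Names and types are illustrative.

```
#include <cstdint>
#include <vector>

// Host-side sketch of sliceToDynamic for a row-major 2-D array. `source` is
// the padded static data; `dest` receives the packed payload followed by the
// two dynamic sizes as trailing int32 metadata.
void SliceToDynamicR2(const std::vector<int32_t>& source,
                      int32_t dyn_dim0, int32_t dyn_dim1,
                      int64_t static_dim0, int64_t static_dim1,
                      std::vector<int32_t>* dest) {
  const int64_t static_elements = static_dim0 * static_dim1;
  dest->assign(static_elements + 2, 0);
  // Store the dynamic sizes right after the statically padded payload.
  (*dest)[static_elements + 0] = dyn_dim0;
  (*dest)[static_elements + 1] = dyn_dim1;

  const int64_t dyn_elements = int64_t{dyn_dim0} * dyn_dim1;
  for (int64_t linear = 0; linear < dyn_elements; ++linear) {
    // Delinearize the packed index against the dynamic sizes to locate the
    // element in the padded static layout, then copy it to position `linear`.
    const int64_t d0 = linear / dyn_dim1;
    const int64_t d1 = linear % dyn_dim1;
    (*dest)[linear] = source[d0 * static_dim1 + d1];
  }
}
```

padToStatic is the exact inverse: read the two sizes from the end of the buffer and copy each packed element back to its padded static position.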
@@ -249,6 +249,21 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "gpu_dyn_shape_test",
+    srcs = ["gpu_dyn_shape_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":gpu_codegen_test",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cc_test(
     name = "gpu_ftz_test",
     srcs = ["gpu_ftz_test.cc"],
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+namespace xla {
+namespace gpu {
+class GpuDynamicShapeTest : public GpuCodegenTest {};
+
+TEST_F(GpuDynamicShapeTest, DynamicShapeR2) {
+  HloComputation::Builder builder(TestName());
+
+  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
+  dyn_input_shape.set_dynamic_dimension(0, true);
+  HloInstruction* param_x = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, dyn_input_shape, "x"));
+
+  builder.AddInstruction(HloInstruction::CreateUnary(
+      dyn_input_shape, HloOpcode::kNegate, param_x));
+  auto hlo_module = CreateNewVerifiedModule();
+  hlo_module->AddEntryComputation(builder.Build());
+
+  CompileAndVerifyIr(std::move(hlo_module),
+                     R"(
+; CHECK-LABEL: is_thred_0-true
+; CHECK_LABEL: custom-call.in_dyn_bounds-true
+; CHECK_LABEL: custom-call.in_bounds-true
+; CHECK: %[[dyn_dim_size:.*]] = load i32, i32*
+; CHECK: %[[dyn_element_total:.*]] = mul i32 1, %[[dyn_dim_size:.*]]
+; CHECK: %[[linear_index:.*]] = add nuw nsw i32
+; CHECK: %[[linear_index_in_range:.*]] = icmp ult i32 %[[linear_index:.*]],
+; CHECK: store i32 %[[dyn_dim_size:.*]], i32*
+)",
+                     /*match_optimized_ir=*/false);
+}
+
+}  // namespace gpu
+}  // namespace xla
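Assuming a CUDA-enabled build, the new codegen test above can presumably be run on its own with the usual TensorFlow Bazel invocation (the `--config=cuda` flag is the standard TF GPU build config, not something introduced by this change):

```
bazel test --config=cuda //tensorflow/compiler/xla/service/gpu/tests:gpu_dyn_shape_test
```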
@@ -71,6 +71,32 @@ void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
   }
 }
 
+void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
+                                 llvm::Value* linear, const Shape& shape,
+                                 absl::Span<llvm::Value*> dynamic_dims,
+                                 llvm::IRBuilder<>* b) const {
+  CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
+  CHECK_EQ(multidim_.size(), shape.rank());
+  llvm::Value* divisor = GetConstantWithIndexType(1);
+  const Layout& layout = shape.layout();
+  for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
+    int64 dimension = layout.minor_to_major(i);
+
+    // If i is not the last dimension, compute
+    //   (linear_index / divisor) % current_dimension.
+    // If i is the last dimension, we can skip the mod, because we assume that
+    // linear is in bounds.
+    auto* quot = b->CreateUDiv(linear, divisor, "quot");
+    if (i < layout.minor_to_major_size() - 1) {
+      (*multidim)[dimension] =
+          b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
+      divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
+    } else {
+      (*multidim)[dimension] = quot;
+    }
+  }
+}
+
 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                       llvm::IRBuilder<>* b)
     : multidim_(shape.rank()),
|
|||||||
Delinearize(&multidim_, linear, shape, b);
|
Delinearize(&multidim_, linear, shape, b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
|
||||||
|
absl::Span<llvm::Value*> dynamic_dims,
|
||||||
|
llvm::IRBuilder<>* b)
|
||||||
|
: multidim_(shape.rank()),
|
||||||
|
linear_(linear),
|
||||||
|
layout_(shape.layout()),
|
||||||
|
dims_(shape.dimensions().begin(), shape.dimensions().end()) {
|
||||||
|
CHECK_NE(linear, nullptr);
|
||||||
|
index_type_ = linear->getType();
|
||||||
|
CHECK(LayoutUtil::HasLayout(shape))
|
||||||
|
<< "Shape " << ShapeUtil::HumanStringWithLayout(shape)
|
||||||
|
<< " should have a layout.";
|
||||||
|
Delinearize(&multidim_, linear, shape, dynamic_dims, b);
|
||||||
|
}
|
||||||
|
|
||||||
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
|
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
|
||||||
absl::Span<int64 const> dimensions,
|
absl::Span<int64 const> dimensions,
|
||||||
llvm::Type* index_type)
|
llvm::Type* index_type)
|
||||||
|
@ -66,6 +66,11 @@ class IrArray {
|
|||||||
// Precondition: "shape" has a layout.
|
// Precondition: "shape" has a layout.
|
||||||
Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
|
Index(llvm::Value* linear, const Shape& shape, llvm::IRBuilder<>* b);
|
||||||
|
|
||||||
|
// Similar to the above constructor except using "dynamic_dims" instead of
|
||||||
|
// shape's static dimension to constructs the index.
|
||||||
|
Index(llvm::Value* linear, const Shape& shape,
|
||||||
|
absl::Span<llvm::Value*> dynamic_dims, llvm::IRBuilder<>* b);
|
||||||
|
|
||||||
// Constructs an index from a multi-dimensional index. 'shape' is the shape
|
// Constructs an index from a multi-dimensional index. 'shape' is the shape
|
||||||
// for which the multi-dimensional index is used. 'index_type' is the type
|
// for which the multi-dimensional index is used. 'index_type' is the type
|
||||||
// of the index.
|
// of the index.
|
||||||
@ -180,6 +185,11 @@ class IrArray {
|
|||||||
void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
|
void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
|
||||||
const Shape& shape, llvm::IRBuilder<>* b) const;
|
const Shape& shape, llvm::IRBuilder<>* b) const;
|
||||||
|
|
||||||
|
// Delinearize the linear index with the dynamic dimensions.
|
||||||
|
void Delinearize(std::vector<llvm::Value*>* multidim, llvm::Value* linear,
|
||||||
|
const Shape& shape, absl::Span<llvm::Value*> dynamic_dims,
|
||||||
|
llvm::IRBuilder<>* b) const;
|
||||||
|
|
||||||
std::vector<llvm::Value*> multidim_;
|
std::vector<llvm::Value*> multidim_;
|
||||||
|
|
||||||
// These values are purely for efficiency; `multidim_` is enough to find the
|
// These values are purely for efficiency; `multidim_` is enough to find the
|
||||||
|
@ -146,11 +146,6 @@ class XrtClientSession : public ClientSession {
|
|||||||
string* xla_test_device_ptr; // initial value set in main()
|
string* xla_test_device_ptr; // initial value set in main()
|
||||||
string* xla_platform_ptr; // initial value set in main()
|
string* xla_platform_ptr; // initial value set in main()
|
||||||
|
|
||||||
bool SupportDynamicShapes() {
|
|
||||||
// TODO(jackcao): Support dynamic shapes on XLA GPU.
|
|
||||||
return *xla_test_device_ptr != "XLA_GPU";
|
|
||||||
}
|
|
||||||
|
|
||||||
string DeviceFromFlag() {
|
string DeviceFromFlag() {
|
||||||
string xla_test_device = *xla_test_device_ptr;
|
string xla_test_device = *xla_test_device_ptr;
|
||||||
return absl::StrCat("/device:", xla_test_device, ":0");
|
return absl::StrCat("/device:", xla_test_device, ":0");
|
||||||
@ -1126,10 +1121,6 @@ TEST(RawApiTest, CompileAndExecute) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RawApiTest, DynamicR1Test) {
|
TEST(RawApiTest, DynamicR1Test) {
|
||||||
if (!SupportDynamicShapes()) {
|
|
||||||
GTEST_SKIP()
|
|
||||||
<< "Skipping the test if backend doesn't support dynamic shapes";
|
|
||||||
}
|
|
||||||
xrt::XLAAllocation p0;
|
xrt::XLAAllocation p0;
|
||||||
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
|
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
|
||||||
xrt::XLAAllocation p1;
|
xrt::XLAAllocation p1;
|
||||||
@ -1182,10 +1173,6 @@ TEST(RawApiTest, DynamicR1Test) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RawApiTest, DynamicR2Test) {
|
TEST(RawApiTest, DynamicR2Test) {
|
||||||
if (!SupportDynamicShapes()) {
|
|
||||||
GTEST_SKIP()
|
|
||||||
<< "Skipping the test if backend doesn't support dynamic shapes";
|
|
||||||
}
|
|
||||||
xrt::XLAAllocation p0;
|
xrt::XLAAllocation p0;
|
||||||
*p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
|
*p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
|
||||||
{1.5f, 2.5f, 3.0f, -2.0f}})
|
{1.5f, 2.5f, 3.0f, -2.0f}})
|
||||||
@ -1243,10 +1230,6 @@ TEST(RawApiTest, DynamicR2Test) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RawApiTest, DynamicR1TupleTest) {
|
TEST(RawApiTest, DynamicR1TupleTest) {
|
||||||
if (!SupportDynamicShapes()) {
|
|
||||||
GTEST_SKIP()
|
|
||||||
<< "Skipping the test if backend doesn't support dynamic shapes";
|
|
||||||
}
|
|
||||||
xrt::XLAAllocation p0;
|
xrt::XLAAllocation p0;
|
||||||
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
|
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
|
||||||
xrt::XLAAllocation p1;
|
xrt::XLAAllocation p1;
|
||||||
@ -1307,10 +1290,6 @@ TEST(RawApiTest, DynamicR1TupleTest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RawApiTest, AcceptDynamicR1TupleTest) {
|
TEST(RawApiTest, AcceptDynamicR1TupleTest) {
|
||||||
if (!SupportDynamicShapes()) {
|
|
||||||
GTEST_SKIP()
|
|
||||||
<< "Skipping the test if backend doesn't support dynamic shapes";
|
|
||||||
}
|
|
||||||
xrt::XLAAllocation p0;
|
xrt::XLAAllocation p0;
|
||||||
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
|
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
|
||||||
xrt::XLAAllocation p1;
|
xrt::XLAAllocation p1;
|
||||||
@ -1373,10 +1352,6 @@ TEST(RawApiTest, AcceptDynamicR1TupleTest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RawApiTest, AcceptDynamicR1Test) {
|
TEST(RawApiTest, AcceptDynamicR1Test) {
|
||||||
if (!SupportDynamicShapes()) {
|
|
||||||
GTEST_SKIP()
|
|
||||||
<< "Skipping the test if backend doesn't support dynamic shapes";
|
|
||||||
}
|
|
||||||
xrt::XLAAllocation p0;
|
xrt::XLAAllocation p0;
|
||||||
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
|
*p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
|
||||||
xrt::XLAAllocation p1;
|
xrt::XLAAllocation p1;
|
||||||
@ -1424,13 +1399,9 @@ TEST(RawApiTest, AcceptDynamicR1Test) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(RawApiTest, AcceptDynamicR2Test) {
|
TEST(RawApiTest, AcceptDynamicR2Test) {
|
||||||
if (!SupportDynamicShapes()) {
|
|
||||||
GTEST_SKIP()
|
|
||||||
<< "Skipping the test if backend doesn't support dynamic shapes";
|
|
||||||
}
|
|
||||||
xrt::XLAAllocation p0;
|
xrt::XLAAllocation p0;
|
||||||
*p0.mutable_value() =
|
*p0.mutable_value() =
|
||||||
xla::LiteralUtil::CreateR2({{-1.0f, 3.0f, 1.0f}, {-2.0f, -1.0f, 3.0f}})
|
xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
|
||||||
.ToProto();
|
.ToProto();
|
||||||
|
|
||||||
xrt::XLAComputation c;
|
xrt::XLAComputation c;
|
||||||
@ -1468,7 +1439,7 @@ TEST(RawApiTest, AcceptDynamicR2Test) {
|
|||||||
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
|
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
|
||||||
|
|
||||||
auto expected = xla::LiteralUtil::CreateR2<float>(
|
auto expected = xla::LiteralUtil::CreateR2<float>(
|
||||||
{{1.0f, -3.0f, -1.0f}, {2.0f, 1.0f, -3.0f}});
|
{{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
|
||||||
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
|
||||||
}
|
}
|
||||||
|
|
||||||
|