[XLA:GPU] Enhancement to the implementation of reduction to vector.
When the kept elements of a reduction are contiguous in the input tensor, we use the implementation of reduction to vector to emit high performance code. Previously, we require the reduction result have the same layout as the kept kept elements in the input tensor. This change removes such a constraint. Add test cases. PiperOrigin-RevId: 239204904
This commit is contained in:
parent
f9fbff63fb
commit
d5236e610b
@ -169,12 +169,7 @@ bool IsReductionToVector(const HloInstruction& reduce) {
|
||||
}
|
||||
}
|
||||
return LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
|
||||
dims_to_keep) &&
|
||||
ShapeUtil::Equal(
|
||||
reduce.shape(),
|
||||
ShapeUtil::FilterDimensions(
|
||||
[&](int64 dim) { return absl::c_count(dims_to_keep, dim); },
|
||||
input->shape()));
|
||||
dims_to_keep);
|
||||
}
|
||||
|
||||
// This emits a device-side call to
|
||||
|
@ -2608,6 +2608,9 @@ class ReductionCodegenInfo : public IrEmitterUnnested::KernelCodegenInfo {
|
||||
AddressVector reduction_input_addresses_;
|
||||
InlinedVector<HloComputation*, 1> reducers_;
|
||||
InlinedVector<ShapeIndex, 1> reduction_output_shape_indices_;
|
||||
// The address of the memory that stores the linear index of the current
|
||||
// output, assuming that the output doesn't change the layout of the kept
|
||||
// elements in the reduction input.
|
||||
llvm::AllocaInst* current_output_linear_index_address_;
|
||||
llvm::AllocaInst* current_output_inbound_address_;
|
||||
bool is_row_reduction_;
|
||||
@ -2806,21 +2809,49 @@ void IrEmitterUnnested::EmitEpilogueForReduction(
|
||||
llvm_ir::SetToFirstInsertPoint(if_output_inbound_data.true_block, &b_);
|
||||
}
|
||||
|
||||
HloInstruction* reduce_or_tuple = unnested_hlo->opcode() == HloOpcode::kFusion
|
||||
? unnested_hlo->fused_expression_root()
|
||||
: unnested_hlo;
|
||||
std::vector<const HloInstruction*> reduce_instructions;
|
||||
absl::c_for_each(GetOutputInstructions(&reduce_or_tuple),
|
||||
[&](const HloInstruction* instr) {
|
||||
if (instr->opcode() == HloOpcode::kReduce) {
|
||||
reduce_instructions.push_back(instr);
|
||||
}
|
||||
});
|
||||
int num_partial_results = reduction_info->GetNumberOfPartialResults();
|
||||
|
||||
// Emit an atomic operation that accumulates the partial reduction to the
|
||||
// output element. For row reduction, this is only for lane 0 due to the
|
||||
// if-statement emitted above.
|
||||
for (int i = 0; i != num_reduces; ++i) {
|
||||
const HloInstruction* reduce_hlo = reduce_instructions[i];
|
||||
Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
|
||||
[&](int64 dim) {
|
||||
return !absl::c_linear_search(reduce_hlo->dimensions(), dim);
|
||||
},
|
||||
reduce_hlo->operand(0)->shape());
|
||||
for (int j = 0; j < num_partial_results; ++j) {
|
||||
// A reduction is allowed to transpose its output. For example, suppose
|
||||
// we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are
|
||||
// allowed to produce as output either f32[10,30]{1,0} (no transpose) or
|
||||
// f32[10,30]{0,1} (transposing the two output dims).
|
||||
//
|
||||
// At this point in the function we have a "partial sum" of input elements
|
||||
// (stored in partial_result_addresses), and we need to accumulate it into
|
||||
// the correct output element.
|
||||
//
|
||||
// *reduction_info->GetCurrentOutputLinearIndexAddress() stores the linear
|
||||
// index in the output into which we would need to accumulate *if the
|
||||
// output layout matched the input layout*. This is why we use
|
||||
// `reduction_kept_element_shape` rather than `unnested_hlo->shape()` when
|
||||
// computing `element_index` below.
|
||||
IrArray::Index element_index(
|
||||
/*linear=*/Load(
|
||||
InBoundsGEP(reduction_info->GetCurrentOutputLinearIndexAddress(),
|
||||
{b_.getInt32(j)}),
|
||||
"output_linear_addr"),
|
||||
ShapeUtil::GetSubshape(unnested_hlo->shape(),
|
||||
reduction_output_shape_indices[i]),
|
||||
&b_);
|
||||
"untransposed_output_linear_addr"),
|
||||
reduction_kept_element_shape, &b_);
|
||||
llvm::Value* output_address =
|
||||
GetIrArray(*unnested_hlo, *unnested_hlo,
|
||||
reduction_output_shape_indices[i])
|
||||
|
@ -411,6 +411,69 @@ TEST_F(GpuKernelTilingTest, ColumnReductionMOFUnrolled) {
|
||||
// Check that the kernel runs correctly.
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1.0e-5, 1.0e-5}));
|
||||
}
|
||||
|
||||
TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) {
|
||||
const char *const kHloString = R"(
|
||||
HloModule reduce_with_layout_change
|
||||
reduction0 {
|
||||
x0 = f32[] parameter(0)
|
||||
y0 = f32[] parameter(1)
|
||||
ROOT add0 = f32[] add(x0, y0)
|
||||
}
|
||||
|
||||
ENTRY kernel_entry {
|
||||
arg0 = f32[4,32,32,16,12,12,3,3]{2,3,5,4,0,7,6,1} parameter(0)
|
||||
constant0 = f32[] constant(0)
|
||||
ROOT reduce0 = f32[4,32,16,12,12]{4,3,2,1,0} reduce(arg0, constant0),
|
||||
dimensions={1,6,7}, to_apply=reduction0
|
||||
})";
|
||||
|
||||
// Check that the kernel is tiled by looking for llvm.nvvm.atomic.
|
||||
auto hlo_module =
|
||||
ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
|
||||
CompileAndVerifyIr(std::move(hlo_module),
|
||||
R"(
|
||||
; CHECK-LABEL: define void @reduce
|
||||
; CHECK: call float @llvm.nvvm.atomic.load.add.f32.p0f32
|
||||
; CHECK: }
|
||||
)",
|
||||
/*match_optimized_ir=*/true);
|
||||
|
||||
// Check that the kernel runs correctly.
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||
}
|
||||
|
||||
TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) {
|
||||
const char *const kHloString = R"(
|
||||
HloModule reduce_with_layout_change
|
||||
reduction0 {
|
||||
x0 = f32[] parameter(0)
|
||||
y0 = f32[] parameter(1)
|
||||
ROOT add0 = f32[] add(x0, y0)
|
||||
}
|
||||
|
||||
ENTRY kernel_entry {
|
||||
arg0 = f32[8,6,64]{2,1,0} parameter(0)
|
||||
constant0 = f32[] constant(0)
|
||||
ROOT reduce0 = f32[8,6]{0,1} reduce(arg0, constant0), dimensions={2},
|
||||
to_apply=reduction0
|
||||
})";
|
||||
|
||||
// Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down.
|
||||
auto hlo_module =
|
||||
ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
|
||||
CompileAndVerifyIr(std::move(hlo_module),
|
||||
R"(
|
||||
; CHECK-LABEL: define void @reduce
|
||||
; CHECK: call float @llvm.nvvm.shfl.sync.down.f32
|
||||
; CHECK: }
|
||||
)",
|
||||
/*match_optimized_ir=*/true);
|
||||
|
||||
// Check that the kernel runs correctly.
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user