[XLA/GPU] Use a proper indexing type while doing reduction
PiperOrigin-RevId: 291946087 Change-Id: Ie070fc1827713a9150ce1f5807574cb6b7929363
This commit is contained in:
parent
7e124e7b66
commit
9bc9f8fcc8
@ -2191,14 +2191,14 @@ static llvm::Value* GetUntransposedOutputLinearAddress(
|
||||
}
|
||||
|
||||
void IrEmitterUnnested::EmitEpilogueForReduction(
|
||||
HloInstruction* unnested_hlo, const ReductionCodegenInfo& reduction_info,
|
||||
llvm::Type* index_ty, HloInstruction* unnested_hlo,
|
||||
const ReductionCodegenInfo& reduction_info,
|
||||
absl::Span<const HloInstruction* const> reduce_instructions,
|
||||
absl::Span<const ShapeIndex> reduction_output_shape_indices,
|
||||
absl::Span<HloComputation* const> reducers,
|
||||
const IrArray::Index& starting_tile) {
|
||||
const KernelMappingScheme& mapping_scheme =
|
||||
reduction_info.GetKernelMappingScheme();
|
||||
llvm::Type* index_ty = b_.getInt32Ty();
|
||||
auto constant = [&](uint64 c) -> llvm::Constant* {
|
||||
return llvm::ConstantInt::get(index_ty, c);
|
||||
};
|
||||
@ -3078,9 +3078,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
||||
&b_, y, x, tile_height, tile_width, emit_reduction_tile);
|
||||
});
|
||||
EmitEpilogueForReduction(unnested_hlo, reduction_info, reduce_instructions,
|
||||
reduction_output_shape_indices, reducers,
|
||||
starting_tile);
|
||||
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
||||
reduce_instructions, reduction_output_shape_indices,
|
||||
reducers, starting_tile);
|
||||
|
||||
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
|
||||
ir_emitter_context_->llvm_module());
|
||||
|
@ -270,7 +270,8 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// Wraps up the code generation for a tile block of a reduction kernel: write
|
||||
// the calculated output into the output tensor.
|
||||
void EmitEpilogueForReduction(
|
||||
HloInstruction* unnested_hlo, const ReductionCodegenInfo& reduction_info,
|
||||
llvm::Type* index_ty, HloInstruction* unnested_hlo,
|
||||
const ReductionCodegenInfo& reduction_info,
|
||||
absl::Span<const HloInstruction* const> reduce_instructions,
|
||||
absl::Span<const ShapeIndex> reduction_output_shape_indices,
|
||||
absl::Span<HloComputation* const> reducers,
|
||||
|
@ -742,6 +742,30 @@ TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) {
|
||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||
}
|
||||
|
||||
TEST_F(GpuKernelTilingTest, RowReductionRequiring64BitIndex) {
|
||||
const char *const kHloString = R"(
|
||||
HloModule LargeReduction
|
||||
|
||||
Sum {
|
||||
x.1 = f32[] parameter(0)
|
||||
y.1 = f32[] parameter(1)
|
||||
ROOT add.1 = f32[] add(x.1, y.1)
|
||||
}
|
||||
|
||||
ENTRY reduce.1 {
|
||||
parameter = f32[3048576000] parameter(0)
|
||||
init_value = f32[] constant(0)
|
||||
ROOT out = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum
|
||||
}
|
||||
)";
|
||||
auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
|
||||
auto expected_ir = R"(
|
||||
; CHECK: i64
|
||||
)";
|
||||
CompileAndVerifyIr(std::move(hlo_module), expected_ir,
|
||||
/*match_optimized_ir=*/true);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user