[XLA/GPU] Use a proper indexing type while doing reduction
PiperOrigin-RevId: 291946087
Change-Id: Ie070fc1827713a9150ce1f5807574cb6b7929363
This commit is contained in:
parent
7e124e7b66
commit
9bc9f8fcc8
@@ -2191,14 +2191,14 @@ static llvm::Value* GetUntransposedOutputLinearAddress(
|
|||||||
}
|
}
|
||||||
|
|
||||||
void IrEmitterUnnested::EmitEpilogueForReduction(
|
void IrEmitterUnnested::EmitEpilogueForReduction(
|
||||||
HloInstruction* unnested_hlo, const ReductionCodegenInfo& reduction_info,
|
llvm::Type* index_ty, HloInstruction* unnested_hlo,
|
||||||
|
const ReductionCodegenInfo& reduction_info,
|
||||||
absl::Span<const HloInstruction* const> reduce_instructions,
|
absl::Span<const HloInstruction* const> reduce_instructions,
|
||||||
absl::Span<const ShapeIndex> reduction_output_shape_indices,
|
absl::Span<const ShapeIndex> reduction_output_shape_indices,
|
||||||
absl::Span<HloComputation* const> reducers,
|
absl::Span<HloComputation* const> reducers,
|
||||||
const IrArray::Index& starting_tile) {
|
const IrArray::Index& starting_tile) {
|
||||||
const KernelMappingScheme& mapping_scheme =
|
const KernelMappingScheme& mapping_scheme =
|
||||||
reduction_info.GetKernelMappingScheme();
|
reduction_info.GetKernelMappingScheme();
|
||||||
llvm::Type* index_ty = b_.getInt32Ty();
|
|
||||||
auto constant = [&](uint64 c) -> llvm::Constant* {
|
auto constant = [&](uint64 c) -> llvm::Constant* {
|
||||||
return llvm::ConstantInt::get(index_ty, c);
|
return llvm::ConstantInt::get(index_ty, c);
|
||||||
};
|
};
|
||||||
@@ -3078,9 +3078,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
|||||||
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
|
||||||
&b_, y, x, tile_height, tile_width, emit_reduction_tile);
|
&b_, y, x, tile_height, tile_width, emit_reduction_tile);
|
||||||
});
|
});
|
||||||
EmitEpilogueForReduction(unnested_hlo, reduction_info, reduce_instructions,
|
EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
|
||||||
reduction_output_shape_indices, reducers,
|
reduce_instructions, reduction_output_shape_indices,
|
||||||
starting_tile);
|
reducers, starting_tile);
|
||||||
|
|
||||||
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
|
UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
|
||||||
ir_emitter_context_->llvm_module());
|
ir_emitter_context_->llvm_module());
|
||||||
|
@@ -270,7 +270,8 @@ class IrEmitterUnnested : public IrEmitter,
|
|||||||
// Wraps up the code generation for a tile block of a reduction kernel: write
|
// Wraps up the code generation for a tile block of a reduction kernel: write
|
||||||
// the calculated output into the output tensor.
|
// the calculated output into the output tensor.
|
||||||
void EmitEpilogueForReduction(
|
void EmitEpilogueForReduction(
|
||||||
HloInstruction* unnested_hlo, const ReductionCodegenInfo& reduction_info,
|
llvm::Type* index_ty, HloInstruction* unnested_hlo,
|
||||||
|
const ReductionCodegenInfo& reduction_info,
|
||||||
absl::Span<const HloInstruction* const> reduce_instructions,
|
absl::Span<const HloInstruction* const> reduce_instructions,
|
||||||
absl::Span<const ShapeIndex> reduction_output_shape_indices,
|
absl::Span<const ShapeIndex> reduction_output_shape_indices,
|
||||||
absl::Span<HloComputation* const> reducers,
|
absl::Span<HloComputation* const> reducers,
|
||||||
|
@@ -742,6 +742,30 @@ TEST_F(GpuKernelTilingTest, RowReductionWithSmallDimensionNotTiled) {
|
|||||||
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GpuKernelTilingTest, RowReductionRequiring64BitIndex) {
|
||||||
|
const char *const kHloString = R"(
|
||||||
|
HloModule LargeReduction
|
||||||
|
|
||||||
|
Sum {
|
||||||
|
x.1 = f32[] parameter(0)
|
||||||
|
y.1 = f32[] parameter(1)
|
||||||
|
ROOT add.1 = f32[] add(x.1, y.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY reduce.1 {
|
||||||
|
parameter = f32[3048576000] parameter(0)
|
||||||
|
init_value = f32[] constant(0)
|
||||||
|
ROOT out = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
auto hlo_module = ParseAndReturnVerifiedModule(kHloString).ValueOrDie();
|
||||||
|
auto expected_ir = R"(
|
||||||
|
; CHECK: i64
|
||||||
|
)";
|
||||||
|
CompileAndVerifyIr(std::move(hlo_module), expected_ir,
|
||||||
|
/*match_optimized_ir=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user