[XLA:GPU] Reduce block sizes of multi-output row-reductions to reduce register pressure and avoid spilling.
PiperOrigin-RevId: 315749669 Change-Id: Iaae83aa020b0c2d8c341f00faaee5b117ea79ca4
This commit is contained in:
parent
3099f3a664
commit
c03e49d2c2
@ -3150,6 +3150,15 @@ bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
|
||||
return can_be_vectorized >= cannot_be_vectorized;
|
||||
}
|
||||
|
||||
int64 NearestPowerOfTwo(int64 v) {
|
||||
if (v < 0) {
|
||||
return 0;
|
||||
}
|
||||
int64 upper = tensorflow::NextPowerOfTwo64(v);
|
||||
int64 lower = upper >> 1;
|
||||
return upper - v < v - lower ? upper : lower;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
|
||||
@ -3179,8 +3188,16 @@ ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
|
||||
int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
|
||||
int64 num_threads_x = [&] {
|
||||
if (reduction_dimensions.is_row_reduction) {
|
||||
// Use 512 as default block size (threads per block) for row reductions.
|
||||
// For multi-output fusions, reduce the block size further to decrease
|
||||
// register pressure when multiple outputs are computed by each thread.
|
||||
int64 fan_out =
|
||||
unnested_hlo->IsMultiOutputFusion()
|
||||
? unnested_hlo->fused_expression_root()->operand_count()
|
||||
: 1;
|
||||
int64 max_block_size = std::max(16LL, 512LL / NearestPowerOfTwo(fan_out));
|
||||
return std::min(
|
||||
kWarpSize * kWarpSize,
|
||||
max_block_size,
|
||||
RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
|
||||
reduction_tiling[2]),
|
||||
kWarpSize));
|
||||
@ -3292,6 +3309,9 @@ Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
|
||||
reduction_info.GetKernelMappingScheme();
|
||||
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
|
||||
mapping_scheme.GetThreadsPerBlock());
|
||||
VLOG(3) << "Launch dimensions of " << unnested_hlo->name()
|
||||
<< ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
|
||||
<< " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
|
||||
llvm::Type* index_ty = GetIndexTypeForKernel(
|
||||
unnested_hlo, launch_dimensions.launch_bound(), &b_);
|
||||
EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions,
|
||||
|
Loading…
Reference in New Issue
Block a user