[XLA:GPU] Limit the use of the 0-2-1 shared memory transpose implementation.
For a parameter that is a 0-2-1 transpose of a result tensor, the shared memory transpose implementation uses the indices of the result tensor elements in a tile to pre-load the parameter elements into a shared memory buffer, and replaces later accesses to the parameter with accesses to the shared memory buffer. This relies on the assumption that computing the result tensor elements in a tile only needs the parameter elements in the pre-loaded tile. Any instruction inside a fused computation that can break this assumption may cause incorrect kernel code generation.

This change avoids the 0-2-1 shared memory transpose implementation for a fusion instruction that contains kGather or kReverse instructions. Also add a test case.

PiperOrigin-RevId: 221097814
parent 7e511f2db8
commit df41194786
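For context (this sketch is not part of the diff): the pattern EmitHlo021Tile generates is, in spirit, the classic tiled shared-memory transpose below. This is a hand-written CUDA sketch, not XLA's generated code; the kernel name, the TILE constant, and the row-major 2-D shapes are illustrative assumptions.

// Classic tiled 0-2-1 transpose through shared memory.
// `in` is n1 x n2 row-major; `out` is n2 x n1 row-major.
#define TILE 32

__global__ void Transpose021Sketch(const float* in, float* out,
                                   int n1, int n2) {
  // 32x33: the extra column avoids shared-memory bank conflicts on the
  // transposed reads (this is the 32*33 tile mentioned in the diff below).
  __shared__ float tile[TILE][TILE + 1];

  // Cooperative pre-load: each block stages one 32x32 tile of `in`.
  int x = blockIdx.x * TILE + threadIdx.x;
  int y = blockIdx.y * TILE + threadIdx.y;
  if (x < n2 && y < n1) {
    tile[threadIdx.y][threadIdx.x] = in[y * n2 + x];
  }
  __syncthreads();  // appears as llvm.nvvm.barrier0 in the generated IR

  // Transposed write-out. This is safe only because every output element in
  // the block's tile depends on exactly one element of the staged tile: the
  // assumption the commit message describes.
  int tx = blockIdx.y * TILE + threadIdx.x;
  int ty = blockIdx.x * TILE + threadIdx.y;
  if (tx < n1 && ty < n2) {
    out[ty * n1 + tx] = tile[threadIdx.x][threadIdx.y];
  }
}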
@@ -3528,6 +3528,29 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
   return launch_dimensions;
 }
 
+namespace {
+// Returns true to indicate it is safe to use the tile based shared memory
+// transpose implementation to implement the kernel for the instruction.
+//
+// An instruction is not safe for such an implementation if it can change the
+// element order of a tensor without changing the dimension of the tensor, and
+// the instruction has a corresponding elemental_ir_emitter.
+bool IsInstructionSafeForTileBasedTranspose(const HloInstruction* hlo) {
+  auto is_safe_for_tile_based_transpose = [&](const HloInstruction* instr) {
+    HloOpcode opcode = instr->opcode();
+    CHECK_NE(opcode, HloOpcode::kFusion);
+    return (opcode != HloOpcode::kReverse && opcode != HloOpcode::kGather);
+  };
+
+  if (hlo->opcode() == HloOpcode::kFusion) {
+    return absl::c_all_of(hlo->fused_instructions_computation()->instructions(),
+                          is_safe_for_tile_based_transpose);
+  }
+
+  return is_safe_for_tile_based_transpose(hlo);
+}
+}  // namespace
+
 bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
   HloOpcode opcode = hlo->opcode();
   CHECK(opcode == HloOpcode::kFusion || opcode == HloOpcode::kCopy);
@@ -3572,6 +3595,10 @@ bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
     return false;
   }
 
+  if (!IsInstructionSafeForTileBasedTranspose(hlo)) {
+    return false;
+  }
+
   // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
   // elements are of size 4 bytes), and CUDA has an architectural limit of 48kb
   // shared memory per SM. (This is increased to 96kb in Volta, but we don't
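To see why a fused kReverse invalidates the pre-load, consider fusing a reverse over dim 0 onto the transpose from the sketch above. Again a hedged CUDA sketch (reusing TILE from above), not generated code: each output element now depends on a parameter element from a column band the block never staged, so the staged tile cannot serve the read. A fused kGather breaks the assumption the same way, with data-dependent indices.

__global__ void TransposeThenReverseSketch(const float* in, float* out,
                                           int n1, int n2) {
  __shared__ float tile[TILE][TILE + 1];

  // Same cooperative pre-load as in the safe kernel above.
  int x = blockIdx.x * TILE + threadIdx.x;
  int y = blockIdx.y * TILE + threadIdx.y;
  if (x < n2 && y < n1) {
    tile[threadIdx.y][threadIdx.x] = in[y * n2 + x];
  }
  __syncthreads();

  int tx = blockIdx.y * TILE + threadIdx.x;
  int ty = blockIdx.x * TILE + threadIdx.y;
  if (tx < n1 && ty < n2) {
    // Fused reverse over dim 0 of the transpose: the needed element sits at
    // column n2-1-ty of `in`, outside the column band this block staged, so
    // correct code must fall back to global memory. Rewriting this read to
    // use `tile` (what the unguarded 0-2-1 path would do) is exactly the
    // miscompile the new IsInstructionSafeForTileBasedTranspose check
    // prevents.
    out[ty * n1 + tx] = in[tx * n2 + (n2 - 1 - ty)];
  }
}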
@@ -193,6 +193,33 @@ TEST_F(GpuKernelTilingTest,
                      /*match_optimized_ir=*/true);
 }
 
+TEST_F(GpuKernelTilingTest, FusionTransposeWithReverseNotTiled) {
+  const char *const kHloString = R"(
+    HloModule FusionTransposeWithReverseNotTiled
+    fused_computation.1 {
+      arg0 = f32[128,64]{1,0} parameter(0)
+      copy0 = f32[128,64]{0,1} copy(arg0)
+      ROOT reverse0 = f32[128,64]{0,1} reverse(copy0), dimensions={0}
+    }
+
+    ENTRY reverse_break_assumption {
+      param0 = f32[128,64]{1,0} parameter(0)
+      ROOT fusion0 = f32[128,64]{0,1} fusion(param0), kind=kLoop,
+        calls=fused_computation.1
+    })";
+
+  // Check that a call to llvm.nvvm.barrier0 is not generated.
+  auto hlo_module =
+      ParseHloString(kHloString, ConfigWithoutLayoutAssignment()).ValueOrDie();
+  CompileAndVerifyIr(std::move(hlo_module),
+                     R"(
+; CHECK-LABEL: define void @fusion
+; CHECK-NOT: tail call void @llvm.nvvm.barrier0()
+; CHECK: }
+)",
+                     /*match_optimized_ir=*/true);
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla