[XLA:GPU] Fuse reverse HLOs
It turns out that reverse is not an elementwise operation, but it can still be fused.
PiperOrigin-RevId: 220524481
This commit is contained in:
parent
4da7c3bc5a
commit
100c6184f4
@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) {
|
||||
hlo.opcode() == HloOpcode::kReduce ||
|
||||
hlo.opcode() == HloOpcode::kReduceWindow ||
|
||||
hlo.opcode() == HloOpcode::kReshape ||
|
||||
hlo.opcode() == HloOpcode::kReverse ||
|
||||
hlo.opcode() == HloOpcode::kScatter ||
|
||||
hlo.opcode() == HloOpcode::kSlice ||
|
||||
hlo.opcode() == HloOpcode::kTranspose;
|
||||
|
@ -805,5 +805,26 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
|
||||
op::Reduce(op::Broadcast(op::Parameter()), op::Constant()));
|
||||
}
|
||||
|
||||
// Checks that GpuInstructionFusion fuses a kReverse instruction together with
// its producer: after the pass runs, the entry root must be a fusion whose
// fused expression is reverse(add(param, param)).
TEST_F(InstructionFusionTest, FuseReverse) {
|
||||
// Build the module from HLO text: an add of parameter 0 with itself, fed
// into a ROOT reverse over dimension 0.
auto module = ParseHloString(R"(
|
||||
HloModule test_module
|
||||
|
||||
ENTRY Reverse {
|
||||
p0 = f32[50,96,1024]{2,1,0} parameter(0)
|
||||
add = f32[50,96,1024]{2,1,0} add(p0, p0)
|
||||
ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0}
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
// Run the fusion pass with may_duplicate enabled and require that Run()
// yields true (presumably "module was changed" -- confirm against the pass
// interface).
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
|
||||
// Inspect the entry computation's root after the pass.
HloInstruction* root = module->entry_computation()->root_instruction();
|
||||
// The root itself must now be a fusion instruction...
EXPECT_THAT(root, op::Fusion());
|
||||
// ...whose fused expression root is the reverse applied to the add of two
// parameters, i.e. both the reverse and the add ended up inside the fusion.
EXPECT_THAT(root->fused_expression_root(),
|
||||
op::Reverse(op::Add(op::Parameter(), op::Parameter())));
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user