[XLA:GPU] Fuse reverse HLOs

Turns out reverse is not elementwise, but we can still fuse it.

PiperOrigin-RevId: 220524481
This commit is contained in:
Benjamin Kramer 2018-11-07 14:00:26 -08:00 committed by TensorFlower Gardener
parent 4da7c3bc5a
commit 100c6184f4
2 changed files with 22 additions and 0 deletions

View File

@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kReduceWindow ||
hlo.opcode() == HloOpcode::kReshape ||
hlo.opcode() == HloOpcode::kReverse ||
hlo.opcode() == HloOpcode::kScatter ||
hlo.opcode() == HloOpcode::kSlice ||
hlo.opcode() == HloOpcode::kTranspose;

View File

@ -805,5 +805,26 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) {
op::Reduce(op::Broadcast(op::Parameter()), op::Constant()));
}
TEST_F(InstructionFusionTest, FuseReverse) {
auto module = ParseHloString(R"(
HloModule test_module
ENTRY Reverse {
p0 = f32[50,96,1024]{2,1,0} parameter(0)
add = f32[50,96,1024]{2,1,0} add(p0, p0)
ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0}
})")
.ValueOrDie();
EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
EXPECT_THAT(root->fused_expression_root(),
op::Reverse(op::Add(op::Parameter(), op::Parameter())));
}
} // namespace gpu
} // namespace xla