From 100c6184f4f1df82f0c465bcb07be34dc52da614 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 7 Nov 2018 14:00:26 -0800 Subject: [PATCH] [XLA:GPU] Fuse reverse HLOs Turns out reverse is not elementwise, but we can still fuse it. PiperOrigin-RevId: 220524481 --- .../xla/service/gpu/instruction_fusion.cc | 1 + .../service/gpu/instruction_fusion_test.cc | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc index 7f2b59810f0..43f43b50e4a 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc @@ -47,6 +47,7 @@ bool IsFusible(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kReduceWindow || hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kReverse || hlo.opcode() == HloOpcode::kScatter || hlo.opcode() == HloOpcode::kSlice || hlo.opcode() == HloOpcode::kTranspose; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 57e66f5a12c..3cd49131349 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -805,5 +805,26 @@ TEST_F(InstructionFusionTest, NonscalarConstantsNotFused) { op::Reduce(op::Broadcast(op::Parameter()), op::Constant())); } +TEST_F(InstructionFusionTest, FuseReverse) { /* Parses an HLO module in which add(p0, p0) feeds a reverse over dimension 0, runs GpuInstructionFusion on it with may_duplicate=true, then inspects the entry computation's root. */ + auto module = ParseHloString(R"( + HloModule test_module + + ENTRY Reverse { + p0 = f32[50,96,1024]{2,1,0} parameter(0) + add = f32[50,96,1024]{2,1,0} add(p0, p0) + ROOT reverse = f32[50,96,1024] reverse(add), dimensions={0} + })") + .ValueOrDie(); + + EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true) + .Run(module.get()) + .ValueOrDie()); /* Run() must yield true; NOTE(review): presumably true means "the module was changed" -- confirm against the pass interface. */ + + HloInstruction* root = module->entry_computation()->root_instruction(); 
EXPECT_THAT(root, op::Fusion()); /* After the pass, the entry root is expected to be a fusion instruction. */ + EXPECT_THAT(root->fused_expression_root(), + op::Reverse(op::Add(op::Parameter(), op::Parameter()))); /* Inside the fusion, the root must be the reverse, whose operand is the add of two parameters. */ +} + } // namespace gpu } // namespace xla