diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
index f9cd61bea3d..6f79ad7c146 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -48,10 +48,15 @@ bool IsMatrixVectorDot(const HloInstruction* hlo) {
          (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1);
 }
 
+bool HasExactlyOneUse(const HloInstruction& hlo_instr) {
+  return hlo_instr.user_count() == 1 &&
+         absl::c_count(hlo_instr.users().front()->operands(), &hlo_instr) == 1;
+}
+
 bool CanBeOutputFused(const HloInstruction* producer,
                       const HloInstruction* consumer) {
   return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) &&
-         producer->user_count() == 1;
+         HasExactlyOneUse(*producer);
 }
 
 bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index c77d5988ba3..527df0bd1c2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -763,6 +763,28 @@ TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) {
               Not(op::Fusion()));
 }
 
+TEST_F(InstructionFusionTest,
+       DotOperationFusion_DontOutputFuseDuplicateOperands) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[60,1]{1,0} parameter(1)
+  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool fused_something,
+                          CpuInstructionFusion().Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
 struct GatherLoopFusionTestSpec {
   string test_name;
   string hlo_computation_text;
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 985c0d411ee..6151dd8ff4c 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -181,7 +181,8 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
       return true;
     }
   } else if (consumer->operand_count() == 2 &&
-             consumer->opcode() == HloOpcode::kAdd) {
+             consumer->opcode() == HloOpcode::kAdd &&
+             consumer->operand(other_operand_index) != producer) {
     // Fuse a bias add into the output of the dot.
     return true;
   }
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 2b060b03cea..688604cd36e 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -358,6 +358,29 @@ TEST_F(InstructionFusionTest, DotOutputFusionBiasAdd) {
                           op::Parameter()));
 }
 
+TEST_F(InstructionFusionTest,
+       DotOperationFusion_DontOutputFuseDuplicateOperands) {
+  absl::string_view module_string = R"(
+HloModule module
+
+ENTRY main {
+  a = f32[50,60]{1,0} parameter(0)
+  b = f32[60,1]{1,0} parameter(1)
+  c = f32[50,1]{1,0} dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+  ROOT d = f32[50,1]{1,0} add(c, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
+                          ParseAndReturnVerifiedModule(module_string));
+  TF_ASSERT_OK_AND_ASSIGN(
+      bool fused_something,
+      GpuInstructionFusion(/*may_duplicate=*/false).Run(module.get()));
+  EXPECT_FALSE(fused_something);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              Not(op::Fusion()));
+}
+
 // Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
 // duplicated and fused into both reduces.
 TEST_F(InstructionFusionTest, FloatingPointDivIsCheap) {
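
The sketch below is not part of the patch and does not use XLA's real classes; `Instr`, its `operands`/`users` fields, and the stand-alone `HasExactlyOneUse` are a minimal model, assumed here only to illustrate why checking `user_count() == 1` is not sufficient. In `ROOT d = add(c, c)` the dot has a single user, yet it fills two operand slots of that user, so output fusion must also count occurrences in the user's operand list, as the new helper does.

#include <algorithm>
#include <iostream>
#include <vector>

// Simplified stand-in for an HLO instruction: just operand and user edges.
struct Instr {
  std::vector<const Instr*> operands;
  std::vector<const Instr*> users;
};

// Mirrors the intent of the patch's HasExactlyOneUse: exactly one user, and
// the instruction appears exactly once in that user's operand list.
bool HasExactlyOneUse(const Instr& instr) {
  return instr.users.size() == 1 &&
         std::count(instr.users.front()->operands.begin(),
                    instr.users.front()->operands.end(), &instr) == 1;
}

int main() {
  Instr dot;  // c = dot(a, b)
  Instr add;  // ROOT d = add(c, c): one user of `dot`, but two operand slots
  add.operands = {&dot, &dot};
  dot.users = {&add};

  // The old check (a single user) passes; the stricter check rejects fusion.
  std::cout << "user_count == 1:  " << (dot.users.size() == 1) << "\n";  // 1
  std::cout << "HasExactlyOneUse: " << HasExactlyOneUse(dot) << "\n";    // 0
}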