diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c02100debc3..8eeb45731d1 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2819,6 +2819,8 @@ bool HloInstruction::IsFusible() const { case HloOpcode::kReduce: case HloOpcode::kReduceWindow: return true; + case HloOpcode::kRng: + return user_count() <= 1; // Side effecting instructions cannot be fused. default: return !HasSideEffect(); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 9c5a66f0040..50e6b6bd884 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1425,9 +1425,25 @@ void HloFusionInstruction::MergeFusionInstruction( unfused_instructions.empty()); // Replace instruction_to_merge use of 'this' with unfused_root. TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root)); - // Fuse 'unfused_instructions' into 'this'. + + // Build a dummy root for the cloned fusion as we may remove the original root + // in the fusion process. + if (!unfused_instructions.empty()) { + HloComputation* computation = unfused_root->parent(); + auto* dummy_root = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::Zero(U32))); + computation->set_root_instruction(dummy_root, + /*accept_different_shape=*/true); + } + + // Fuse 'unfused_instructions' into 'this'. Every time we fuse an instruction + // we remove it from the cloned fusion node. This is so that we don't add + // extra users to the producer of that instruction (we use user count to + // decide if a side-effectful instruction is fusible). 
for (auto& instruction : unfused_instructions) { - FuseInstruction(instruction); + auto* fused = FuseInstruction(instruction); + TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused)); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); } CHECK_EQ(0, cloned_fusion->user_count()); TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation( diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index f4309ea09ae..f3943c2c21e 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -184,6 +184,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) { abs1 = f32[] abs(add) rng = f32[] rng(p1, abs1), distribution=rng_uniform abs2 = f32[] abs(rng) + abs3 = f32[] abs(rng) ROOT root = f32[] subtract(abs2, add) })") .ValueOrDie(); @@ -375,6 +376,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) { abs1 = f32[] abs(add) rng = f32[] rng(p0, abs1), distribution=rng_uniform abs2 = f32[] abs(rng) + abs3 = f32[] abs(rng) ROOT root = f32[] subtract(abs2, add) })") .ValueOrDie();