rng loop fusion

PiperOrigin-RevId: 314951531
Change-Id: If2842a456471eaabd1fbab8c792fd84d95b4f792
This commit is contained in:
Yunxing Dai 2020-06-05 10:25:45 -07:00 committed by TensorFlower Gardener
parent 78eca642cf
commit 2178657934
3 changed files with 22 additions and 2 deletions

View File

@ -2819,6 +2819,8 @@ bool HloInstruction::IsFusible() const {
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
return true;
case HloOpcode::kRng:
return user_count() <= 1;
// Side effecting instructions cannot be fused.
default:
return !HasSideEffect();

View File

@ -1425,9 +1425,25 @@ void HloFusionInstruction::MergeFusionInstruction(
unfused_instructions.empty());
// Replace instruction_to_merge use of 'this' with unfused_root.
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
// Fuse 'unfused_instructions' into 'this'.
// Build a dummy root for the cloned fusion as we may remove the original root
// in the fusion process.
if (!unfused_instructions.empty()) {
HloComputation* computation = unfused_root->parent();
auto* dummy_root = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
computation->set_root_instruction(dummy_root,
/*accept_different_shape=*/true);
}
// Fuse 'unfused_instructions' into 'this'. Everytime we fuse an instruction
// we remove it from the closed fusion node. This is so that we don't add
// extra users to the producer of that instruction (we use user count to
// decide if a side-effectful instruction is fusible).
for (auto& instruction : unfused_instructions) {
FuseInstruction(instruction);
auto* fused = FuseInstruction(instruction);
TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
}
CHECK_EQ(0, cloned_fusion->user_count());
TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(

View File

@ -184,6 +184,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
abs1 = f32[] abs(add)
rng = f32[] rng(p1, abs1), distribution=rng_uniform
abs2 = f32[] abs(rng)
abs3 = f32[] abs(rng)
ROOT root = f32[] subtract(abs2, add)
})")
.ValueOrDie();
@ -375,6 +376,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
abs1 = f32[] abs(add)
rng = f32[] rng(p0, abs1), distribution=rng_uniform
abs2 = f32[] abs(rng)
abs3 = f32[] abs(rng)
ROOT root = f32[] subtract(abs2, add)
})")
.ValueOrDie();