rng loop fusion
PiperOrigin-RevId: 314951531 Change-Id: If2842a456471eaabd1fbab8c792fd84d95b4f792
This commit is contained in:
parent
78eca642cf
commit
2178657934
@ -2819,6 +2819,8 @@ bool HloInstruction::IsFusible() const {
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReduceWindow:
|
||||
return true;
|
||||
case HloOpcode::kRng:
|
||||
return user_count() <= 1;
|
||||
// Side effecting instructions cannot be fused.
|
||||
default:
|
||||
return !HasSideEffect();
|
||||
|
||||
@ -1425,9 +1425,25 @@ void HloFusionInstruction::MergeFusionInstruction(
|
||||
unfused_instructions.empty());
|
||||
// Replace instruction_to_merge use of 'this' with unfused_root.
|
||||
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
|
||||
// Fuse 'unfused_instructions' into 'this'.
|
||||
|
||||
// Build a dummy root for the cloned fusion as we may remove the original root
|
||||
// in the fusion process.
|
||||
if (!unfused_instructions.empty()) {
|
||||
HloComputation* computation = unfused_root->parent();
|
||||
auto* dummy_root = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
|
||||
computation->set_root_instruction(dummy_root,
|
||||
/*accept_different_shape=*/true);
|
||||
}
|
||||
|
||||
// Fuse 'unfused_instructions' into 'this'. Everytime we fuse an instruction
|
||||
// we remove it from the closed fusion node. This is so that we don't add
|
||||
// extra users to the producer of that instruction (we use user count to
|
||||
// decide if a side-effectful instruction is fusible).
|
||||
for (auto& instruction : unfused_instructions) {
|
||||
FuseInstruction(instruction);
|
||||
auto* fused = FuseInstruction(instruction);
|
||||
TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
|
||||
TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
|
||||
}
|
||||
CHECK_EQ(0, cloned_fusion->user_count());
|
||||
TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
|
||||
|
||||
@ -184,6 +184,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
||||
abs1 = f32[] abs(add)
|
||||
rng = f32[] rng(p1, abs1), distribution=rng_uniform
|
||||
abs2 = f32[] abs(rng)
|
||||
abs3 = f32[] abs(rng)
|
||||
ROOT root = f32[] subtract(abs2, add)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
@ -375,6 +376,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
|
||||
abs1 = f32[] abs(add)
|
||||
rng = f32[] rng(p0, abs1), distribution=rng_uniform
|
||||
abs2 = f32[] abs(rng)
|
||||
abs3 = f32[] abs(rng)
|
||||
ROOT root = f32[] subtract(abs2, add)
|
||||
})")
|
||||
.ValueOrDie();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user