rng loop fusion
PiperOrigin-RevId: 314951531 Change-Id: If2842a456471eaabd1fbab8c792fd84d95b4f792
This commit is contained in:
parent
78eca642cf
commit
2178657934
@ -2819,6 +2819,8 @@ bool HloInstruction::IsFusible() const {
|
|||||||
case HloOpcode::kReduce:
|
case HloOpcode::kReduce:
|
||||||
case HloOpcode::kReduceWindow:
|
case HloOpcode::kReduceWindow:
|
||||||
return true;
|
return true;
|
||||||
|
case HloOpcode::kRng:
|
||||||
|
return user_count() <= 1;
|
||||||
// Side effecting instructions cannot be fused.
|
// Side effecting instructions cannot be fused.
|
||||||
default:
|
default:
|
||||||
return !HasSideEffect();
|
return !HasSideEffect();
|
||||||
|
|||||||
@ -1425,9 +1425,25 @@ void HloFusionInstruction::MergeFusionInstruction(
|
|||||||
unfused_instructions.empty());
|
unfused_instructions.empty());
|
||||||
// Replace instruction_to_merge use of 'this' with unfused_root.
|
// Replace instruction_to_merge use of 'this' with unfused_root.
|
||||||
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
|
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
|
||||||
// Fuse 'unfused_instructions' into 'this'.
|
|
||||||
|
// Build a dummy root for the cloned fusion as we may remove the original root
|
||||||
|
// in the fusion process.
|
||||||
|
if (!unfused_instructions.empty()) {
|
||||||
|
HloComputation* computation = unfused_root->parent();
|
||||||
|
auto* dummy_root = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::Zero(U32)));
|
||||||
|
computation->set_root_instruction(dummy_root,
|
||||||
|
/*accept_different_shape=*/true);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fuse 'unfused_instructions' into 'this'. Every time we fuse an instruction
|
||||||
|
// we remove it from the cloned fusion node. This is so that we don't add
|
||||||
|
// extra users to the producer of that instruction (we use user count to
|
||||||
|
// decide if a side-effectful instruction is fusible).
|
||||||
for (auto& instruction : unfused_instructions) {
|
for (auto& instruction : unfused_instructions) {
|
||||||
FuseInstruction(instruction);
|
auto* fused = FuseInstruction(instruction);
|
||||||
|
TF_CHECK_OK(instruction->ReplaceAllUsesWith(fused));
|
||||||
|
TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction));
|
||||||
}
|
}
|
||||||
CHECK_EQ(0, cloned_fusion->user_count());
|
CHECK_EQ(0, cloned_fusion->user_count());
|
||||||
TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
|
TF_CHECK_OK(parent()->parent()->RemoveEmbeddedComputation(
|
||||||
|
|||||||
@ -184,6 +184,7 @@ TEST_F(InstructionFusionTest, AvoidDuplicationIfNotAllFusibleRecursively) {
|
|||||||
abs1 = f32[] abs(add)
|
abs1 = f32[] abs(add)
|
||||||
rng = f32[] rng(p1, abs1), distribution=rng_uniform
|
rng = f32[] rng(p1, abs1), distribution=rng_uniform
|
||||||
abs2 = f32[] abs(rng)
|
abs2 = f32[] abs(rng)
|
||||||
|
abs3 = f32[] abs(rng)
|
||||||
ROOT root = f32[] subtract(abs2, add)
|
ROOT root = f32[] subtract(abs2, add)
|
||||||
})")
|
})")
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
@ -375,6 +376,7 @@ TEST_F(InstructionFusionTest, AllowBinarySameValueOperandsDuplication) {
|
|||||||
abs1 = f32[] abs(add)
|
abs1 = f32[] abs(add)
|
||||||
rng = f32[] rng(p0, abs1), distribution=rng_uniform
|
rng = f32[] rng(p0, abs1), distribution=rng_uniform
|
||||||
abs2 = f32[] abs(rng)
|
abs2 = f32[] abs(rng)
|
||||||
|
abs3 = f32[] abs(rng)
|
||||||
ROOT root = f32[] subtract(abs2, add)
|
ROOT root = f32[] subtract(abs2, add)
|
||||||
})")
|
})")
|
||||||
.ValueOrDie();
|
.ValueOrDie();
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user