Avoid fusing expensive ops with reusing ops.

We already avoid this for regular fusion, so we should check it in the
FusionMerger pass as well. E.g. we don't want to fuse a reduce with an
(indirect) broadcast user.

PiperOrigin-RevId: 353008990
Change-Id: Id95aed8b82059a2fc24dbff52ce350c9bce6c5a2
This commit is contained in:
Adrian Kuegel 2021-01-21 06:53:29 -08:00 committed by TensorFlower Gardener
parent 227c8cfe5e
commit ef4847e727
2 changed files with 32 additions and 2 deletions

View File

@ -232,6 +232,9 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// This is done to avoid the duplication of expensive instructions, which
// would occur if 'fusion' were merged into multiple users.
//
// Also, we don't want to fuse expensive instructions with instructions which
// reuse their operand values (e.g. Broadcast instructions).
//
// However, if we are going to save a "lot" in memory bandwidth then we
// ignore how expensive the fusion instructions are. The heuristic used to
// determine "a lot" is the following: merging must reduce memory traffic by a
@ -239,8 +242,12 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// trivial (above 1K). This likely has room for improvement in the future.
bool allow_expensive_ops =
fusion->user_count() == 1 ||
(merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024);
(fusion->user_count() == 1 || (merged_to_current_bytes_ratio < 0.3 &&
current_bytes_transferred > 1024)) &&
!absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) {
int64 operand_index = user->operand_index(fusion);
return user->ReusesOperandElements(operand_index);
});
if (!allow_expensive_ops &&
absl::c_any_of(fusion->fused_instructions(),

View File

@ -469,6 +469,29 @@ TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) {
EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
}
// Checks that FusionMerger reports no change for a module in which the single
// consumer of a producer fusion reuses the producer's elements.
//
// Module layout:
//   f1 = fusion calling %f_b, whose root is a tanh of its parameter (the
//        "expensive" producer, going by the test name).
//   f2 = fusion calling %f_c, whose root broadcasts its parameter into an
//        extra trailing dimension of size 2.
// f1 has exactly one user (f2), yet the merger is still expected to leave the
// module untouched.
TEST_F(FusionMergerTest, WillNotMergeExpensiveFusionsWithReusingConsumer) {
  // HLO text for the producer/consumer pair described above.
  constexpr char kHloText[] = R"(
HloModule m
%f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] {
%p = f32[1024,1024,1024] parameter(0)
ROOT %t = f32[1024,1024,1024] tanh(%p)
}
%f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024,2] {
%p = f32[1024,1024,1024] parameter(0)
ROOT %t = f32[1024,1024,1024,2] broadcast(%p), dimensions={0,1,2}
}
ENTRY entry {
p0 = f32[1024,1024,1024] parameter(0)
f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b
ROOT f2 = f32[1024,1024,1024,2] fusion(f1), kind=kLoop, calls=%f_c
})";

  auto module = ParseAndReturnVerifiedModule(kHloText).ValueOrDie();

  // Run() returning false means the pass made no change to the module.
  EXPECT_FALSE(FusionMerger().Run(module.get()).ValueOrDie());
}
TEST_F(FusionMergerTest, NoMergeBecauseCodeDuplication) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule module