[XLA:GPU] Support fusion of unfused consumers in FusionMerger pass.

PiperOrigin-RevId: 329683158
Change-Id: I0d93e0441e823f5a13e18268939568ed31d44fdf
This commit is contained in:
Thomas Joerg 2020-09-02 03:54:09 -07:00 committed by TensorFlower Gardener
parent da2c5d2647
commit ff1ecf14de
2 changed files with 59 additions and 4 deletions

View File

@ -201,8 +201,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Merging into all users enables the removal of 'fusion' from the
// computation.
if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
return user->opcode() == HloOpcode::kFusion &&
IsProducerConsumerFusible(*fusion, *user);
return IsProducerConsumerFusible(*fusion, *user);
})) {
VLOG(3) << "Not merging " << fusion->name()
<< ": Some of its users are not loop/input fusion kernels.";
@ -283,7 +282,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
// Merge fused instructions from 'fusion' into each user.
std::vector<HloInstruction*> users = fusion->users();
for (HloInstruction* user : users) {
user->MergeFusionInstruction(fusion);
if (user->opcode() == HloOpcode::kFusion) {
user->MergeFusionInstruction(fusion);
} else {
HloInstruction* fused_user =
computation_->AddInstruction(HloInstruction::CreateFusion(
user->shape(), ChooseFusionKind(*fusion, *user), user));
TF_CHECK_OK(computation_->ReplaceInstruction(user, fused_user));
fused_user->MergeFusionInstruction(fusion);
}
changed_ = true;
}
++total_merged_;
@ -296,7 +303,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
})
<< " }";
// Remove 'fusion' instruction.
CHECK_EQ(0, fusion->user_count());
CHECK_EQ(0, fusion->user_count()) << fusion->ToString();
return computation_->RemoveInstruction(fusion);
}

View File

@ -234,6 +234,54 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
op::Fusion(op::Parameter()));
}
TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule jit_matmul.36
max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
parameter.13 = f32[] parameter(0)
parameter.14 = f32[] parameter(1)
ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
}
add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
parameter.29 = f32[] parameter(0)
parameter.30 = f32[] parameter(1)
ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
}
fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
param_2.1 = f32[200,200]{1,0} parameter(1)
broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
constant.27 = f32[] constant(0)
ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
}
fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
param_1.9 = f32[200,200]{1,0} parameter(1)
broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
param_0.7 = f32[200,200]{1,0} parameter(0)
broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
}
ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
parameter.2 = f32[200,200]{1,0} parameter(1)
parameter.1 = f32[200,200]{1,0} parameter(0)
fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
constant.11 = f32[] constant(-inf)
reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
})")
.ValueOrDie();
EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Fusion(op::Fusion(), op::Parameter(), op::Parameter()));
}
TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule m