[XLA:GPU] Support fusion of unfused consumers in FusionMerger pass.
PiperOrigin-RevId: 329683158 Change-Id: I0d93e0441e823f5a13e18268939568ed31d44fdf
This commit is contained in:
parent
da2c5d2647
commit
ff1ecf14de
@@ -201,8 +201,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
|||||||
// Merging into all users enables the removal of 'fusion' from the
|
// Merging into all users enables the removal of 'fusion' from the
|
||||||
// computation.
|
// computation.
|
||||||
if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
|
if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
|
||||||
return user->opcode() == HloOpcode::kFusion &&
|
return IsProducerConsumerFusible(*fusion, *user);
|
||||||
IsProducerConsumerFusible(*fusion, *user);
|
|
||||||
})) {
|
})) {
|
||||||
VLOG(3) << "Not merging " << fusion->name()
|
VLOG(3) << "Not merging " << fusion->name()
|
||||||
<< ": Some of its users are not loop/input fusion kernels.";
|
<< ": Some of its users are not loop/input fusion kernels.";
|
||||||
@@ -283,7 +282,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
|||||||
// Merge fused instructions from 'fusion' into each user.
|
// Merge fused instructions from 'fusion' into each user.
|
||||||
std::vector<HloInstruction*> users = fusion->users();
|
std::vector<HloInstruction*> users = fusion->users();
|
||||||
for (HloInstruction* user : users) {
|
for (HloInstruction* user : users) {
|
||||||
user->MergeFusionInstruction(fusion);
|
if (user->opcode() == HloOpcode::kFusion) {
|
||||||
|
user->MergeFusionInstruction(fusion);
|
||||||
|
} else {
|
||||||
|
HloInstruction* fused_user =
|
||||||
|
computation_->AddInstruction(HloInstruction::CreateFusion(
|
||||||
|
user->shape(), ChooseFusionKind(*fusion, *user), user));
|
||||||
|
TF_CHECK_OK(computation_->ReplaceInstruction(user, fused_user));
|
||||||
|
fused_user->MergeFusionInstruction(fusion);
|
||||||
|
}
|
||||||
changed_ = true;
|
changed_ = true;
|
||||||
}
|
}
|
||||||
++total_merged_;
|
++total_merged_;
|
||||||
@@ -296,7 +303,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
|||||||
})
|
})
|
||||||
<< " }";
|
<< " }";
|
||||||
// Remove 'fusion' instruction.
|
// Remove 'fusion' instruction.
|
||||||
CHECK_EQ(0, fusion->user_count());
|
CHECK_EQ(0, fusion->user_count()) << fusion->ToString();
|
||||||
return computation_->RemoveInstruction(fusion);
|
return computation_->RemoveInstruction(fusion);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -234,6 +234,54 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
|
|||||||
op::Fusion(op::Parameter()));
|
op::Fusion(op::Parameter()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
|
||||||
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
|
HloModule jit_matmul.36
|
||||||
|
|
||||||
|
max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
|
||||||
|
parameter.13 = f32[] parameter(0)
|
||||||
|
parameter.14 = f32[] parameter(1)
|
||||||
|
ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
|
||||||
|
}
|
||||||
|
|
||||||
|
add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
|
||||||
|
parameter.29 = f32[] parameter(0)
|
||||||
|
parameter.30 = f32[] parameter(1)
|
||||||
|
ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
|
||||||
|
}
|
||||||
|
|
||||||
|
fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
|
||||||
|
param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
|
||||||
|
param_2.1 = f32[200,200]{1,0} parameter(1)
|
||||||
|
broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
|
||||||
|
subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
|
||||||
|
exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
|
||||||
|
constant.27 = f32[] constant(0)
|
||||||
|
ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
|
||||||
|
}
|
||||||
|
|
||||||
|
fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
|
||||||
|
param_1.9 = f32[200,200]{1,0} parameter(1)
|
||||||
|
broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
|
||||||
|
param_0.7 = f32[200,200]{1,0} parameter(0)
|
||||||
|
broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
|
||||||
|
ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
|
||||||
|
parameter.2 = f32[200,200]{1,0} parameter(1)
|
||||||
|
parameter.1 = f32[200,200]{1,0} parameter(0)
|
||||||
|
fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
|
||||||
|
constant.11 = f32[] constant(-inf)
|
||||||
|
reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
|
||||||
|
ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
|
||||||
|
})")
|
||||||
|
.ValueOrDie();
|
||||||
|
EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Fusion(op::Fusion(), op::Parameter(), op::Parameter()));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
|
TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
|
||||||
auto module = ParseAndReturnVerifiedModule(R"(
|
auto module = ParseAndReturnVerifiedModule(R"(
|
||||||
HloModule m
|
HloModule m
|
||||||
|
Loading…
x
Reference in New Issue
Block a user