[XLA:GPU] Support fusion of unfused consumers in FusionMerger pass.

PiperOrigin-RevId: 329683158
Change-Id: I0d93e0441e823f5a13e18268939568ed31d44fdf
This commit is contained in:
parent
da2c5d2647
commit
ff1ecf14de
@ -201,8 +201,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
// Merging into all users enables the removal of 'fusion' from the
|
||||
// computation.
|
||||
if (!absl::c_all_of(fusion->users(), [&](const HloInstruction* user) {
|
||||
return user->opcode() == HloOpcode::kFusion &&
|
||||
IsProducerConsumerFusible(*fusion, *user);
|
||||
return IsProducerConsumerFusible(*fusion, *user);
|
||||
})) {
|
||||
VLOG(3) << "Not merging " << fusion->name()
|
||||
<< ": Some of its users are not loop/input fusion kernels.";
|
||||
@ -283,7 +282,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
// Merge fused instructions from 'fusion' into each user.
|
||||
std::vector<HloInstruction*> users = fusion->users();
|
||||
for (HloInstruction* user : users) {
|
||||
user->MergeFusionInstruction(fusion);
|
||||
if (user->opcode() == HloOpcode::kFusion) {
|
||||
user->MergeFusionInstruction(fusion);
|
||||
} else {
|
||||
HloInstruction* fused_user =
|
||||
computation_->AddInstruction(HloInstruction::CreateFusion(
|
||||
user->shape(), ChooseFusionKind(*fusion, *user), user));
|
||||
TF_CHECK_OK(computation_->ReplaceInstruction(user, fused_user));
|
||||
fused_user->MergeFusionInstruction(fusion);
|
||||
}
|
||||
changed_ = true;
|
||||
}
|
||||
++total_merged_;
|
||||
@ -296,7 +303,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
})
|
||||
<< " }";
|
||||
// Remove 'fusion' instruction.
|
||||
CHECK_EQ(0, fusion->user_count());
|
||||
CHECK_EQ(0, fusion->user_count()) << fusion->ToString();
|
||||
return computation_->RemoveInstruction(fusion);
|
||||
}
|
||||
|
||||
|
@ -234,6 +234,54 @@ TEST_F(FusionMergerTest, WillMergeIntoInputFusion) {
|
||||
op::Fusion(op::Parameter()));
|
||||
}
|
||||
|
||||
// Checks that FusionMerger handles a producer fusion that has a consumer which
// is not itself a fusion instruction.
//
// In the entry computation below, the kLoop fusion `fusion.3` has two users:
// the plain (unfused) `reduce.16` and the kInput fusion `fusion.1`, which also
// takes `reduce.16` as an operand.
TEST_F(FusionMergerTest, WillMergeIntoUnfusedConsumer) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule jit_matmul.36
|
||||
|
||||
max (parameter.13: f32[], parameter.14: f32[]) -> f32[] {
|
||||
parameter.13 = f32[] parameter(0)
|
||||
parameter.14 = f32[] parameter(1)
|
||||
ROOT maximum.15 = f32[] maximum(f32[] parameter.13, f32[] parameter.14)
|
||||
}
|
||||
|
||||
add (parameter.29: f32[], parameter.30: f32[]) -> f32[] {
|
||||
parameter.29 = f32[] parameter(0)
|
||||
parameter.30 = f32[] parameter(1)
|
||||
ROOT add.31 = f32[] add(f32[] parameter.29, f32[] parameter.30)
|
||||
}
|
||||
|
||||
fused_computation.1 (param_1.4: f32[200,200,200], param_2.1: f32[200,200]) -> f32[200,200] {
|
||||
param_1.4 = f32[200,200,200]{2,1,0} parameter(0)
|
||||
param_2.1 = f32[200,200]{1,0} parameter(1)
|
||||
broadcast.3 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_2.1), dimensions={0,2}
|
||||
subtract.0 = f32[200,200,200]{2,1,0} subtract(f32[200,200,200]{2,1,0} param_1.4, f32[200,200,200]{2,1,0} broadcast.3)
|
||||
exponential.0 = f32[200,200,200]{2,1,0} exponential(f32[200,200,200]{2,1,0} subtract.0)
|
||||
constant.27 = f32[] constant(0)
|
||||
ROOT reduce.0 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} exponential.0, f32[] constant.27), dimensions={1}, to_apply=add
|
||||
}
|
||||
|
||||
fused_computation.3 (param_0.7: f32[200,200], param_1.9: f32[200,200]) -> f32[200,200,200] {
|
||||
param_1.9 = f32[200,200]{1,0} parameter(1)
|
||||
broadcast.10 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_1.9), dimensions={0,1}
|
||||
param_0.7 = f32[200,200]{1,0} parameter(0)
|
||||
broadcast.8 = f32[200,200,200]{2,1,0} broadcast(f32[200,200]{1,0} param_0.7), dimensions={1,2}
|
||||
ROOT add.1 = f32[200,200,200]{2,1,0} add(f32[200,200,200]{2,1,0} broadcast.10, f32[200,200,200]{2,1,0} broadcast.8)
|
||||
}
|
||||
|
||||
ENTRY entry (parameter.1: f32[200,200], parameter.2: f32[200,200]) -> f32[200,200] {
|
||||
parameter.2 = f32[200,200]{1,0} parameter(1)
|
||||
parameter.1 = f32[200,200]{1,0} parameter(0)
|
||||
fusion.3 = f32[200,200,200]{2,1,0} fusion(f32[200,200]{1,0} parameter.2, f32[200,200]{1,0} parameter.1), kind=kLoop, calls=fused_computation.3
|
||||
constant.11 = f32[] constant(-inf)
|
||||
reduce.16 = f32[200,200]{1,0} reduce(f32[200,200,200]{2,1,0} fusion.3, f32[] constant.11), dimensions={1}, to_apply=max
|
||||
ROOT fusion.1 = f32[200,200]{1,0} fusion(f32[200,200,200]{2,1,0} fusion.3, f32[200,200]{1,0} reduce.16), kind=kInput, calls=fused_computation.1
|
||||
})")
|
||||
.ValueOrDie();
|
||||
// The pass must return true for this module (presumably "module changed" --
// confirm against the FusionMerger::Run contract).
EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie());
|
||||
// After the pass the root must be a fusion with three operands: another
// fusion followed by two parameters. NOTE(review): this presumably reflects
// `reduce.16` having been wrapped in a new fusion and `fusion.3` having been
// merged into both users -- verify against HandleFusion.
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Fusion(op::Fusion(), op::Parameter(), op::Parameter()));
|
||||
}
|
||||
|
||||
TEST_F(FusionMergerTest, WillNotMergeReduceUnfriendlyLayouts) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule m
|
||||
|
Loading…
x
Reference in New Issue
Block a user