[XLA] Support fusion of trivial fusion nodes that only contain a parameter
instruction. PiperOrigin-RevId: 257245425
This commit is contained in:
parent
1059951b8e
commit
50c6158ee8
@ -1308,10 +1308,21 @@ void HloFusionInstruction::MergeFusionInstruction(
|
||||
unfused_instructions.push_back(fused_instruction);
|
||||
}
|
||||
}
|
||||
CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root());
|
||||
|
||||
// If there are no unfused instructions, the fused computation must consist
|
||||
// only of kParameter instructions. Make the operand of the corresponding
|
||||
// parameter number the new root.
|
||||
HloInstruction* unfused_root =
|
||||
unfused_instructions.empty()
|
||||
? instruction_to_merge->mutable_operand(
|
||||
instruction_to_merge->fused_instructions_computation()
|
||||
->root_instruction()
|
||||
->parameter_number())
|
||||
: unfused_instructions.front();
|
||||
CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
|
||||
unfused_instructions.empty());
|
||||
// Replace instruction_to_merge use of 'this' with unfused_root.
|
||||
TF_CHECK_OK(
|
||||
instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
|
||||
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
|
||||
// Fuse 'unfused_instructions' into 'this'.
|
||||
for (auto& instruction : unfused_instructions) {
|
||||
FuseInstruction(instruction);
|
||||
@ -1359,7 +1370,16 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
|
||||
}
|
||||
}
|
||||
|
||||
HloInstruction* unfused_root = unfused_instructions.front();
|
||||
// If there are no unfused instructions, the fused computation must consist
|
||||
// only of kParameter instructions. Make the operand of the corresponding
|
||||
// parameter number the new root.
|
||||
HloInstruction* unfused_root =
|
||||
unfused_instructions.empty()
|
||||
? instruction_to_merge->mutable_operand(
|
||||
instruction_to_merge->fused_instructions_computation()
|
||||
->root_instruction()
|
||||
->parameter_number())
|
||||
: unfused_instructions.front();
|
||||
TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
|
||||
|
||||
TF_CHECK_OK(
|
||||
@ -1369,6 +1389,9 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
|
||||
}
|
||||
|
||||
// Fuse the root instruction and generate multiple outputs.
|
||||
if (unfused_instructions.empty()) {
|
||||
return;
|
||||
}
|
||||
FuseInstructionIntoMultiOutput(unfused_root);
|
||||
TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
|
||||
// The rest instructions are of normal fusing.
|
||||
|
Loading…
Reference in New Issue
Block a user