[XLA] Support fusion of trivial fusion nodes that only contain a parameter
instruction. PiperOrigin-RevId: 257245425
This commit is contained in:
parent
1059951b8e
commit
50c6158ee8
@ -1308,10 +1308,21 @@ void HloFusionInstruction::MergeFusionInstruction(
|
|||||||
unfused_instructions.push_back(fused_instruction);
|
unfused_instructions.push_back(fused_instruction);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root());
|
|
||||||
|
// If there are no unfused instructions, the fused computation must consist
|
||||||
|
// only of kParameter instructions. Make the operand of the corresponding
|
||||||
|
// parameter number the new root.
|
||||||
|
HloInstruction* unfused_root =
|
||||||
|
unfused_instructions.empty()
|
||||||
|
? instruction_to_merge->mutable_operand(
|
||||||
|
instruction_to_merge->fused_instructions_computation()
|
||||||
|
->root_instruction()
|
||||||
|
->parameter_number())
|
||||||
|
: unfused_instructions.front();
|
||||||
|
CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
|
||||||
|
unfused_instructions.empty());
|
||||||
// Replace instruction_to_merge use of 'this' with unfused_root.
|
// Replace instruction_to_merge use of 'this' with unfused_root.
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
|
||||||
instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
|
|
||||||
// Fuse 'unfused_instructions' into 'this'.
|
// Fuse 'unfused_instructions' into 'this'.
|
||||||
for (auto& instruction : unfused_instructions) {
|
for (auto& instruction : unfused_instructions) {
|
||||||
FuseInstruction(instruction);
|
FuseInstruction(instruction);
|
||||||
@ -1359,7 +1370,16 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
HloInstruction* unfused_root = unfused_instructions.front();
|
// If there are no unfused instructions, the fused computation must consist
|
||||||
|
// only of kParameter instructions. Make the operand of the corresponding
|
||||||
|
// parameter number the new root.
|
||||||
|
HloInstruction* unfused_root =
|
||||||
|
unfused_instructions.empty()
|
||||||
|
? instruction_to_merge->mutable_operand(
|
||||||
|
instruction_to_merge->fused_instructions_computation()
|
||||||
|
->root_instruction()
|
||||||
|
->parameter_number())
|
||||||
|
: unfused_instructions.front();
|
||||||
TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
|
TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
|
||||||
|
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
@ -1369,6 +1389,9 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Fuse the root instruction and generate multiple outputs.
|
// Fuse the root instruction and generate multiple outputs.
|
||||||
|
if (unfused_instructions.empty()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
FuseInstructionIntoMultiOutput(unfused_root);
|
FuseInstructionIntoMultiOutput(unfused_root);
|
||||||
TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
|
TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
|
||||||
// The rest instructions are of normal fusing.
|
// The rest instructions are of normal fusing.
|
||||||
|
Loading…
Reference in New Issue
Block a user