[XLA] Support fusion of trivial fusion nodes that only contain a parameter

instruction.

PiperOrigin-RevId: 257245425
This commit is contained in:
Blake Hechtman 2019-07-09 12:16:18 -07:00 committed by TensorFlower Gardener
parent 1059951b8e
commit 50c6158ee8

View File

@ -1308,10 +1308,21 @@ void HloFusionInstruction::MergeFusionInstruction(
unfused_instructions.push_back(fused_instruction);
}
}
CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root());
// If there are no unfused instructions, the fused computation must consist
// only of kParameter instructions. Make the operand of the corresponding
// parameter number the new root.
HloInstruction* unfused_root =
unfused_instructions.empty()
? instruction_to_merge->mutable_operand(
instruction_to_merge->fused_instructions_computation()
->root_instruction()
->parameter_number())
: unfused_instructions.front();
CHECK(unfused_root == cloned_fusion->fused_expression_root() ||
unfused_instructions.empty());
// Replace instruction_to_merge use of 'this' with unfused_root.
TF_CHECK_OK(
instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front()));
TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root));
// Fuse 'unfused_instructions' into 'this'.
for (auto& instruction : unfused_instructions) {
FuseInstruction(instruction);
@ -1359,7 +1370,16 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
}
}
HloInstruction* unfused_root = unfused_instructions.front();
// If there are no unfused instructions, the fused computation must consist
// only of kParameter instructions. Make the operand of the corresponding
// parameter number the new root.
HloInstruction* unfused_root =
unfused_instructions.empty()
? instruction_to_merge->mutable_operand(
instruction_to_merge->fused_instructions_computation()
->root_instruction()
->parameter_number())
: unfused_instructions.front();
TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root));
TF_CHECK_OK(
@ -1369,6 +1389,9 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
}
// Fuse the root instruction and generate multiple outputs.
if (unfused_instructions.empty()) {
return;
}
FuseInstructionIntoMultiOutput(unfused_root);
TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root));
// The rest instructions are of normal fusing.