From 50c6158ee8de0bad7d818a0cdcbf0c55aa67cf4c Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Tue, 9 Jul 2019 12:16:18 -0700 Subject: [PATCH] [XLA] Support fusion of trivial fusion nodes that only contain a parameter instruction. PiperOrigin-RevId: 257245425 --- .../compiler/xla/service/hlo_instructions.cc | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e2f4c30610a..754ccc1ff9f 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1308,10 +1308,21 @@ void HloFusionInstruction::MergeFusionInstruction( unfused_instructions.push_back(fused_instruction); } } - CHECK(unfused_instructions.front() == cloned_fusion->fused_expression_root()); + + // If there are no unfused instructions, the fused computation must consist + // only of kParameter instructions. Make the operand of the corresponding + // parameter number the new root. + HloInstruction* unfused_root = + unfused_instructions.empty() + ? instruction_to_merge->mutable_operand( + instruction_to_merge->fused_instructions_computation() + ->root_instruction() + ->parameter_number()) + : unfused_instructions.front(); + CHECK(unfused_root == cloned_fusion->fused_expression_root() || + unfused_instructions.empty()); // Replace instruction_to_merge use of 'this' with unfused_root. - TF_CHECK_OK( - instruction_to_merge->ReplaceUseWith(this, unfused_instructions.front())); + TF_CHECK_OK(instruction_to_merge->ReplaceUseWith(this, unfused_root)); // Fuse 'unfused_instructions' into 'this'. for (auto& instruction : unfused_instructions) { FuseInstruction(instruction); @@ -1359,7 +1370,16 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( } } - HloInstruction* unfused_root = unfused_instructions.front(); + // If there are no unfused instructions, the fused computation must consist + // only of kParameter instructions. Make the operand of the corresponding + // parameter number the new root. + HloInstruction* unfused_root = + unfused_instructions.empty() + ? instruction_to_merge->mutable_operand( + instruction_to_merge->fused_instructions_computation() + ->root_instruction() + ->parameter_number()) + : unfused_instructions.front(); TF_CHECK_OK(instruction_to_merge->ReplaceAllUsesWith(unfused_root)); TF_CHECK_OK( @@ -1369,6 +1389,9 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( } // Fuse the root instruction and generate multiple outputs. + if (unfused_instructions.empty()) { + return; + } FuseInstructionIntoMultiOutput(unfused_root); TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); // The rest instructions are of normal fusing.