diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index da25d5d928b..daf84dc39fc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -613,12 +614,17 @@ HloInstruction* InstructionFusion::AddFusionInstruction( return fusion_instruction; } +HloInstruction* InstructionFusion::FuseInstruction( + HloInstruction* fusion_instruction, HloInstruction* producer) { + return fusion_instruction->FuseInstruction(producer); +} + HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction* consumer) { VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); - fusion_instruction->FuseInstruction(producer); + FuseInstruction(fusion_instruction, producer); if (fusion_instruction != producer && fusion_instruction != consumer) { VLOG(2) << " created new fusion: " << fusion_instruction->ToString(); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 3c39284a80a..90d9da48e33 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -17,6 +17,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include +#include + #include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -87,7 +90,13 @@ class InstructionFusion : public HloModulePass { virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); - // Fuses producer into consumer. + // Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to + // be a fusion instruction. Returns the newly created clone of 'producer' + // which is part of the fusion computation. + virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, + HloInstruction* producer); + + // Fuses producer into consumer. Returns the fusion instruction. virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);