From 569e31aa6bd895a12caeb95ecafa7bc93544d0f1 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 30 Jan 2020 07:21:08 -0800 Subject: [PATCH] Add a method FuseInstruction to InstructionFusion. This provides a way for a subclass to add a hook to gather the newly created producer node which is part of the fusion computation. This will be needed for a future change. PiperOrigin-RevId: 292342465 Change-Id: Ie2be4e942b3bab72bc82ed895277a1361eac6c66 --- tensorflow/compiler/xla/service/instruction_fusion.cc | 8 +++++++- tensorflow/compiler/xla/service/instruction_fusion.h | 11 ++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index da25d5d928b..daf84dc39fc 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -613,12 +614,17 @@ HloInstruction* InstructionFusion::AddFusionInstruction( return fusion_instruction; } +HloInstruction* InstructionFusion::FuseInstruction( + HloInstruction* fusion_instruction, HloInstruction* producer) { + return fusion_instruction->FuseInstruction(producer); +} + HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction* consumer) { VLOG(2) << "Fusing " << producer->ToString() << " into " << consumer->ToString(); HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); - fusion_instruction->FuseInstruction(producer); + FuseInstruction(fusion_instruction, producer); if (fusion_instruction != producer && fusion_instruction != consumer) { VLOG(2) << " created new fusion: " << fusion_instruction->ToString(); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h index 3c39284a80a..90d9da48e33 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.h +++ b/tensorflow/compiler/xla/service/instruction_fusion.h @@ -17,6 +17,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ +#include +#include + #include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -87,7 +90,13 @@ class InstructionFusion : public HloModulePass { virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, const HloInstruction* consumer); - // Fuses producer into consumer. + // Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to + // be a fusion instruction. Returns the newly created clone of 'producer' + // which is part of the fusion computation. + virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction, + HloInstruction* producer); + + // Fuses producer into consumer. Returns the fusion instruction. virtual HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);