Add a method FuseInstruction to InstructionFusion.

This provides a way for a subclass to add a hook to gather the newly created
producer node which is part of the fusion computation. This will be needed for
a future change.

PiperOrigin-RevId: 292342465
Change-Id: Ie2be4e942b3bab72bc82ed895277a1361eac6c66
This commit is contained in:
Adrian Kuegel 2020-01-30 07:21:08 -08:00 committed by TensorFlower Gardener
parent 80cf1fd66f
commit 569e31aa6b
2 changed files with 17 additions and 2 deletions

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <list> #include <list>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <utility>
#include <vector> #include <vector>
#include "absl/algorithm/container.h" #include "absl/algorithm/container.h"
@ -613,12 +614,17 @@ HloInstruction* InstructionFusion::AddFusionInstruction(
return fusion_instruction; return fusion_instruction;
} }
HloInstruction* InstructionFusion::FuseInstruction(
HloInstruction* fusion_instruction, HloInstruction* producer) {
return fusion_instruction->FuseInstruction(producer);
}
HloInstruction* InstructionFusion::Fuse(HloInstruction* producer, HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
HloInstruction* consumer) { HloInstruction* consumer) {
VLOG(2) << "Fusing " << producer->ToString() << " into " VLOG(2) << "Fusing " << producer->ToString() << " into "
<< consumer->ToString(); << consumer->ToString();
HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer); HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
fusion_instruction->FuseInstruction(producer); FuseInstruction(fusion_instruction, producer);
if (fusion_instruction != producer && fusion_instruction != consumer) { if (fusion_instruction != producer && fusion_instruction != consumer) {
VLOG(2) << " created new fusion: " << fusion_instruction->ToString(); VLOG(2) << " created new fusion: " << fusion_instruction->ToString();
} }

View File

@ -17,6 +17,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#include <functional>
#include <utility>
#include "tensorflow/compiler/xla/service/fusion_queue.h" #include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -87,7 +90,13 @@ class InstructionFusion : public HloModulePass {
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer, virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
const HloInstruction* consumer); const HloInstruction* consumer);
// Fuses producer into consumer. // Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to
// be a fusion instruction. Returns the newly created clone of 'producer'
// which is part of the fusion computation.
virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
HloInstruction* producer);
// Fuses producer into consumer. Returns the fusion instruction.
virtual HloInstruction* Fuse(HloInstruction* producer, virtual HloInstruction* Fuse(HloInstruction* producer,
HloInstruction* consumer); HloInstruction* consumer);