Add a method FuseInstruction to InstructionFusion.

This provides a way for a subclass to add a hook to gather the newly created
producer node which is part of the fusion computation. This will be needed for
a future change.

PiperOrigin-RevId: 292342465
Change-Id: Ie2be4e942b3bab72bc82ed895277a1361eac6c66
This commit is contained in:
Adrian Kuegel 2020-01-30 07:21:08 -08:00 committed by TensorFlower Gardener
parent 80cf1fd66f
commit 569e31aa6b
2 changed files with 17 additions and 2 deletions

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <list>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
#include "absl/algorithm/container.h"
@ -613,12 +614,17 @@ HloInstruction* InstructionFusion::AddFusionInstruction(
return fusion_instruction;
}
HloInstruction* InstructionFusion::FuseInstruction(
HloInstruction* fusion_instruction, HloInstruction* producer) {
return fusion_instruction->FuseInstruction(producer);
}
HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
HloInstruction* consumer) {
VLOG(2) << "Fusing " << producer->ToString() << " into "
<< consumer->ToString();
HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
fusion_instruction->FuseInstruction(producer);
FuseInstruction(fusion_instruction, producer);
if (fusion_instruction != producer && fusion_instruction != consumer) {
VLOG(2) << " created new fusion: " << fusion_instruction->ToString();
}

View File

@ -17,6 +17,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#include <functional>
#include <utility>
#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -87,7 +90,13 @@ class InstructionFusion : public HloModulePass {
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
const HloInstruction* consumer);
// Fuses producer into consumer.
// Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to
// be a fusion instruction. Returns the newly created clone of 'producer'
// which is part of the fusion computation.
virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
HloInstruction* producer);
// Fuses producer into consumer. Returns the fusion instruction.
virtual HloInstruction* Fuse(HloInstruction* producer,
HloInstruction* consumer);