Add a method FuseInstruction to InstructionFusion.
This provides a way for a subclass to add a hook to gather the newly created producer node which is part of the fusion computation. This will be needed for a future change. PiperOrigin-RevId: 292342465 Change-Id: Ie2be4e942b3bab72bc82ed895277a1361eac6c66
This commit is contained in:
parent
80cf1fd66f
commit
569e31aa6b
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
@ -613,12 +614,17 @@ HloInstruction* InstructionFusion::AddFusionInstruction(
|
||||
return fusion_instruction;
|
||||
}
|
||||
|
||||
HloInstruction* InstructionFusion::FuseInstruction(
|
||||
HloInstruction* fusion_instruction, HloInstruction* producer) {
|
||||
return fusion_instruction->FuseInstruction(producer);
|
||||
}
|
||||
|
||||
HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
|
||||
HloInstruction* consumer) {
|
||||
VLOG(2) << "Fusing " << producer->ToString() << " into "
|
||||
<< consumer->ToString();
|
||||
HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
|
||||
fusion_instruction->FuseInstruction(producer);
|
||||
FuseInstruction(fusion_instruction, producer);
|
||||
if (fusion_instruction != producer && fusion_instruction != consumer) {
|
||||
VLOG(2) << " created new fusion: " << fusion_instruction->ToString();
|
||||
}
|
||||
|
||||
@ -17,6 +17,9 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
|
||||
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/fusion_queue.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -87,7 +90,13 @@ class InstructionFusion : public HloModulePass {
|
||||
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
|
||||
const HloInstruction* consumer);
|
||||
|
||||
// Fuses producer into consumer.
|
||||
// Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to
|
||||
// be a fusion instruction. Returns the newly created clone of 'producer'
|
||||
// which is part of the fusion computation.
|
||||
virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
|
||||
HloInstruction* producer);
|
||||
|
||||
// Fuses producer into consumer. Returns the fusion instruction.
|
||||
virtual HloInstruction* Fuse(HloInstruction* producer,
|
||||
HloInstruction* consumer);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user