Add caching for reused parameters.
Since a fusion node is never split, we can assume that if a parameter was reused, it will stay that way even if additional instructions are fused in. PiperOrigin-RevId: 329497992 Change-Id: Ie0cdf7161d01704139783c329f88f7312bf7edf8
This commit is contained in:
parent
6293abccc7
commit
adcee3ccc4
@ -95,7 +95,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
// Cost condition: not fuse (simple, expensive producers) and (consumers who
|
||||
// reuse operand elements).
|
||||
if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
|
||||
consumer->ReusesOperandElements(operand_index)) {
|
||||
ReusesOperandElements(consumer, operand_index)) {
|
||||
VLOG(2) << "Fusion is not profitable.";
|
||||
return false;
|
||||
}
|
||||
|
@ -65,9 +65,8 @@ bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
|
||||
}
|
||||
// Cost condition: not fuse (simple, expensive producers) and (consumers who
|
||||
// reuse operand elements).
|
||||
if (producer->opcode() != HloOpcode::kFusion &&
|
||||
consumer->ReusesOperandElements(operand_index) &&
|
||||
is_expensive(*producer)) {
|
||||
if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
|
||||
ReusesOperandElements(consumer, operand_index)) {
|
||||
VLOG(4) << "Do not fuse simple, expensive producer " << producer->name()
|
||||
<< " and consumer which reuses operand elements.";
|
||||
return false;
|
||||
|
@ -711,4 +711,23 @@ HloInstruction::FusionKind InstructionFusion::ChooseKind(
|
||||
return HloInstruction::FusionKind::kLoop;
|
||||
}
|
||||
|
||||
// Returns whether `consumer` reuses elements of the operand at
// `operand_index`. Positive answers are remembered in
// `reused_fusion_operands_` (keyed by consumer, then by operand), so a later
// query for the same pair returns true without asking the instruction again.
// Negative answers are never cached and are recomputed on every call.
bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer,
|
||||
int64 operand_index) {
|
||||
auto operand = consumer->operand(operand_index);
|
||||
// Fast path: this (consumer, operand) pair was already recorded as reused.
auto it = reused_fusion_operands_.find(consumer);
|
||||
if (it != reused_fusion_operands_.end() && it->second.contains(operand)) {
|
||||
return true;
|
||||
}
|
||||
// Not cached: ask the consumer instruction directly.
bool reuses = consumer->ReusesOperandElements(operand_index);
|
||||
// If a parameter was reused, we can cache this information. Fusion
|
||||
// computations only ever grow, so it becomes more likely that a parameter is
|
||||
// reused, but a reused parameter will never become *not* reused.
|
||||
if (reuses) {
|
||||
// We cache the operand corresponding to the fusion parameter, because the
|
||||
// parameter pointers would be invalidated after the next fusion.
|
||||
reused_fusion_operands_[consumer].insert(operand);
|
||||
}
|
||||
return reuses;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -1,4 +1,3 @@
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -20,6 +19,8 @@ limitations under the License.
|
||||
#include <functional>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/fusion_queue.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
@ -138,6 +139,11 @@ class InstructionFusion : public HloModulePass {
|
||||
return config_collection_mode_;
|
||||
}
|
||||
|
||||
// Returns whether 'consumer' may reuse elements of its `operand_index`th
|
||||
// operand.
|
||||
bool ReusesOperandElements(const HloInstruction* consumer,
|
||||
int64 operand_index);
|
||||
|
||||
private:
|
||||
// The set of producers whose consumers we cannot fuse into.
|
||||
using HloInstructionSet = std::unordered_set<HloInstruction*>;
|
||||
@ -172,6 +178,11 @@ class InstructionFusion : public HloModulePass {
|
||||
// Configuration mode.
|
||||
FusionConfigCollection config_collection_mode_;
|
||||
|
||||
// Caches which operands are reused inside fusion computations.
|
||||
absl::flat_hash_map<const HloInstruction*,
|
||||
absl::flat_hash_set<const HloInstruction*>>
|
||||
reused_fusion_operands_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion);
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user