Evaluate the maximum code duplication in fusion nodes instead of average.
This makes it easier to explain why a certain op was not fused. Also, evaluating the max instead of the average is easier, so we can simplify the code a bit.

PiperOrigin-RevId: 329709163
Change-Id: I62fd7fba8c8f9db9124692c83981d5b9281f7f0d
This commit is contained in:
parent
5363cf3d1d
commit
ea462c5eac
@ -132,7 +132,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
fusion_node_evaluations_.emplace(consumer,
|
||||
FusionNodeIndexingEvaluation(consumer));
|
||||
}
|
||||
if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh(
|
||||
if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(
|
||||
producer)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -27,33 +27,25 @@ namespace xla {
|
||||
FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
|
||||
const HloInstruction* fusion)
|
||||
: fusion_(fusion) {
|
||||
total_emitted_instructions_ = 0;
|
||||
HloInstruction* root = fusion->fused_expression_root();
|
||||
indexing_users_[root].insert(fusion);
|
||||
index_usage_count_[fusion] = 1;
|
||||
RecomputeCache();
|
||||
}
|
||||
|
||||
bool FusionNodeIndexingEvaluation::AverageCodeDuplicationTooHigh(
|
||||
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
|
||||
const HloInstruction* producer) const {
|
||||
// This constant is arbitrarily chosen. Essentially we don't want to have too
|
||||
// much code duplication, because it slows down the compilation time. There is
|
||||
// a tradeoff between compilation time and runtime here.
|
||||
const int64 kAllowedCodeDuplication = 15;
|
||||
|
||||
// index_usage_count_ contains an entry for each instruction in the fusion
|
||||
// computation (except parameter instructions), plus an entry for the 'fusion'
|
||||
// instruction. So the size of this map is already one bigger than the number
|
||||
// of instructions in the fusion node that are emitted, thus accounting for
|
||||
// the number of instructions after 'producer' is fused.
|
||||
return EvaluateTotalEmittedInstructions(producer) /
|
||||
index_usage_count_.size() >
|
||||
kAllowedCodeDuplication;
|
||||
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
|
||||
}
|
||||
|
||||
int64 FusionNodeIndexingEvaluation::EvaluateTotalEmittedInstructions(
|
||||
int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
|
||||
const HloInstruction* producer) const {
|
||||
int64 total = total_emitted_instructions_;
|
||||
int64 total = 0;
|
||||
for (const auto* user : indexing_users_.at(producer)) {
|
||||
total += index_usage_count_.at(user);
|
||||
}
|
||||
@ -99,7 +91,6 @@ void FusionNodeIndexingEvaluation::UpdateIndexUsageCount(
|
||||
total += index_usage_count_.at(user);
|
||||
}
|
||||
CHECK(index_usage_count_.emplace(instruction, total).second);
|
||||
total_emitted_instructions_ += total;
|
||||
}
|
||||
|
||||
void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands(
|
||||
|
@ -26,17 +26,14 @@ class FusionNodeIndexingEvaluation {
|
||||
public:
|
||||
explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion);
|
||||
|
||||
// Evaluate the average number of times an instruction is emitted inside the
|
||||
// fusion node, if 'producer' is fused into 'fusion_'. If this average
|
||||
// duplication is "too high" (some arbitrarily chosen constant), returns
|
||||
// true.
|
||||
bool AverageCodeDuplicationTooHigh(const HloInstruction* producer) const;
|
||||
// Evaluate the number of times 'producer' would be emitted if it is fused
|
||||
// into 'fusion_'. If the duplication is "too high" (some arbitrarily chosen
|
||||
// constant), returns true.
|
||||
bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
|
||||
|
||||
// Evaluate the total number of times an instruction is emitted inside the
|
||||
// fusion node, if 'producer' is fused into 'fusion_'. An instruction may be
|
||||
// emitted several times, once for each different index value with which it is
|
||||
// indexed.
|
||||
int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer) const;
|
||||
// Evaluate the number of times 'producer' would be emitted if it is fused
|
||||
// into 'fusion_'.
|
||||
int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
|
||||
|
||||
// Update the evaluation cache after having fused 'producer' into 'fusion_'.
|
||||
// 'producer' is the cloned instruction which is now part of the fusion
|
||||
@ -84,9 +81,6 @@ class FusionNodeIndexingEvaluation {
|
||||
|
||||
// The fusion instruction.
|
||||
const HloInstruction* fusion_;
|
||||
|
||||
// The total number of emitted instructions.
|
||||
int64 total_emitted_instructions_;
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
|
@ -29,7 +29,7 @@ using FusionNodeIndexingEvaluationTest = HloTestBase;
|
||||
|
||||
// Subclass of InstructionFusion exposing the protected methods Fuse and
|
||||
// FuseInstruction for testing. Also adds the FusionNodeIndexingEvaluation to
|
||||
// track the average code duplication due to indexing HloInstructions with
|
||||
// track the code duplication due to indexing HloInstructions with
|
||||
// different index values.
|
||||
class InstructionFusionForTesting : public InstructionFusion {
|
||||
public:
|
||||
@ -61,8 +61,8 @@ class InstructionFusionForTesting : public InstructionFusion {
|
||||
return InstructionFusion::Fuse(producer, consumer);
|
||||
}
|
||||
|
||||
int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer,
|
||||
const HloInstruction* consumer) {
|
||||
int64 EvaluateEmittedInstructions(const HloInstruction* producer,
|
||||
const HloInstruction* consumer) {
|
||||
if (consumer->opcode() != HloOpcode::kFusion) {
|
||||
return 0;
|
||||
}
|
||||
@ -71,8 +71,8 @@ class InstructionFusionForTesting : public InstructionFusion {
|
||||
fusion_node_evaluations_.emplace(consumer,
|
||||
FusionNodeIndexingEvaluation(consumer));
|
||||
}
|
||||
return fusion_node_evaluations_.at(consumer)
|
||||
.EvaluateTotalEmittedInstructions(producer);
|
||||
return fusion_node_evaluations_.at(consumer).EvaluateEmittedInstructions(
|
||||
producer);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -109,8 +109,7 @@ TEST_F(FusionNodeIndexingEvaluationTest, FuseThreeInstructions) {
|
||||
HloInstruction* slice1 = sub->mutable_operand(0);
|
||||
HloInstruction* slice2 = sub->mutable_operand(1);
|
||||
auto fusion = instruction_fusion.Fuse(slice1, sub);
|
||||
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(slice2, fusion),
|
||||
3);
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2, fusion), 1);
|
||||
instruction_fusion.Fuse(slice2, fusion);
|
||||
}
|
||||
|
||||
@ -151,37 +150,31 @@ TEST_F(FusionNodeIndexingEvaluationTest, ExponentialDuplicationPattern) {
|
||||
HloInstruction* slice2_1 = add2->mutable_operand(1);
|
||||
auto fusion = instruction_fusion.Fuse(slice2_0, add2);
|
||||
// So far we have fused add2 and slice2.0. So when we also fuse slice2.1, we
|
||||
// expect to emit 3 instructions.
|
||||
EXPECT_EQ(
|
||||
instruction_fusion.EvaluateTotalEmittedInstructions(slice2_1, fusion), 3);
|
||||
// expect to emit it 1 time.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2_1, fusion),
|
||||
1);
|
||||
instruction_fusion.Fuse(slice2_1, fusion);
|
||||
HloInstruction* add1 = fusion->mutable_operand(0);
|
||||
EXPECT_EQ(add1->opcode(), HloOpcode::kAdd);
|
||||
// If we fuse add1 into 'fusion', it needs to be emitted twice, adding 2 to
|
||||
// the sum.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add1, fusion),
|
||||
5);
|
||||
// If we fuse add1 into 'fusion', it needs to be emitted twice.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add1, fusion), 2);
|
||||
instruction_fusion.Fuse(add1, fusion);
|
||||
HloInstruction* slice1_0 = fusion->mutable_operand(0);
|
||||
EXPECT_EQ(slice1_0->opcode(), HloOpcode::kSlice);
|
||||
// If we fuse slice1.0 into 'fusion', it needs to be emitted twice, adding 2
|
||||
// to the sum.
|
||||
EXPECT_EQ(
|
||||
instruction_fusion.EvaluateTotalEmittedInstructions(slice1_0, fusion), 7);
|
||||
// If we fuse slice1.0 into 'fusion', it needs to be emitted twice.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_0, fusion),
|
||||
2);
|
||||
instruction_fusion.Fuse(slice1_0, fusion);
|
||||
HloInstruction* slice1_1 = fusion->mutable_operand(0);
|
||||
EXPECT_EQ(slice1_1->opcode(), HloOpcode::kSlice);
|
||||
// If we fuse slice1.1 into 'fusion', it needs to be emitted twice, adding 2
|
||||
// to the sum.
|
||||
EXPECT_EQ(
|
||||
instruction_fusion.EvaluateTotalEmittedInstructions(slice1_1, fusion), 9);
|
||||
// If we fuse slice1.1 into 'fusion', it needs to be emitted twice.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_1, fusion),
|
||||
2);
|
||||
instruction_fusion.Fuse(slice1_1, fusion);
|
||||
HloInstruction* add0 = fusion->mutable_operand(0);
|
||||
EXPECT_EQ(add0->opcode(), HloOpcode::kAdd);
|
||||
// If we fuse add0 into 'fusion', it needs to be emitted twice, adding 4 to
|
||||
// the sum.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion),
|
||||
13);
|
||||
// If we fuse add0 into 'fusion', it needs to be emitted four times.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4);
|
||||
instruction_fusion.Fuse(add0, fusion);
|
||||
}
|
||||
|
||||
@ -212,10 +205,9 @@ ENTRY entry_computation {
|
||||
HloInstruction* add0 = fusion->mutable_operand(0);
|
||||
EXPECT_EQ(add0->opcode(), HloOpcode::kAdd);
|
||||
// Here, the cache for the fusion node needs to be recomputed. Make sure we
|
||||
// still get the same evaluation as before when we incrementally built the
|
||||
// still get the same evaluation as before when we incrementally build the
|
||||
// cache.
|
||||
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion),
|
||||
13);
|
||||
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -112,8 +112,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
fusion_node_evaluations_.emplace(consumer,
|
||||
FusionNodeIndexingEvaluation(consumer));
|
||||
}
|
||||
if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh(
|
||||
producer)) {
|
||||
if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
|
||||
VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name()
|
||||
<< " would result in overly large code duplication.";
|
||||
return false;
|
||||
|
@ -289,15 +289,9 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
|
||||
evaluate_fusion_computation(producer);
|
||||
}
|
||||
|
||||
// Sum up the total number of emitted ops.
|
||||
int64 total = 0;
|
||||
for (const auto& entry : index_usage_count) {
|
||||
total += entry.second;
|
||||
}
|
||||
|
||||
// Check that the code duplication has at most a factor of 15 (where 15 is an
|
||||
// arbitrary constant that seems to work).
|
||||
return total > 15 * index_usage_count.size();
|
||||
return index_usage_count[producer] > 15;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user