Evaluate the maximum code duplication in fusion nodes instead of the average.

This makes it easier to explain why a certain op was not fused. Also,
evaluating the max instead of the average is easier, so we can simplify the
code a bit.

PiperOrigin-RevId: 329709163
Change-Id: I62fd7fba8c8f9db9124692c83981d5b9281f7f0d
This commit is contained in:
Adrian Kuegel 2020-09-02 07:31:31 -07:00 committed by TensorFlower Gardener
parent 5363cf3d1d
commit ea462c5eac
6 changed files with 35 additions and 65 deletions

View File

@ -132,7 +132,7 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
fusion_node_evaluations_.emplace(consumer,
FusionNodeIndexingEvaluation(consumer));
}
if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh(
if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(
producer)) {
return false;
}

View File

@ -27,33 +27,25 @@ namespace xla {
FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
const HloInstruction* fusion)
: fusion_(fusion) {
total_emitted_instructions_ = 0;
HloInstruction* root = fusion->fused_expression_root();
indexing_users_[root].insert(fusion);
index_usage_count_[fusion] = 1;
RecomputeCache();
}
bool FusionNodeIndexingEvaluation::AverageCodeDuplicationTooHigh(
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
const HloInstruction* producer) const {
// This constant is arbitrarily chosen. Essentially we don't want to have too
// much code duplication, because it slows down the compilation time. There is
// a tradeoff between compilation time and runtime here.
const int64 kAllowedCodeDuplication = 15;
// index_usage_count_ contains an entry for each instruction in the fusion
// computation (except parameter instructions), plus an entry for the 'fusion'
// instruction. So the size of this map is already one bigger than the number
// of instructions in the fusion node that are emitted, thus accounting for
// the number of instructions after 'producer' is fused.
return EvaluateTotalEmittedInstructions(producer) /
index_usage_count_.size() >
kAllowedCodeDuplication;
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
}
int64 FusionNodeIndexingEvaluation::EvaluateTotalEmittedInstructions(
int64 FusionNodeIndexingEvaluation::EvaluateEmittedInstructions(
const HloInstruction* producer) const {
int64 total = total_emitted_instructions_;
int64 total = 0;
for (const auto* user : indexing_users_.at(producer)) {
total += index_usage_count_.at(user);
}
@ -99,7 +91,6 @@ void FusionNodeIndexingEvaluation::UpdateIndexUsageCount(
total += index_usage_count_.at(user);
}
CHECK(index_usage_count_.emplace(instruction, total).second);
total_emitted_instructions_ += total;
}
void FusionNodeIndexingEvaluation::UpdateIndexingUsersOfOperands(

View File

@ -26,17 +26,14 @@ class FusionNodeIndexingEvaluation {
public:
explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion);
// Evaluate the average number of times an instruction is emitted inside the
// fusion node, if 'producer' is fused into 'fusion_'. If this average
// duplication is "too high" (some arbitrarily chosen constant), returns
// true.
bool AverageCodeDuplicationTooHigh(const HloInstruction* producer) const;
// Evaluate the number of times 'producer' would be emitted if it is fused
// into 'fusion_'. If the duplication is "too high" (some arbitrarily chosen
// constant), returns true.
bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
// Evaluate the total number of times an instruction is emitted inside the
// fusion node, if 'producer' is fused into 'fusion_'. An instruction may be
// emitted several times, once for each different index value with which it is
// indexed.
int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer) const;
// Evaluate the number of times 'producer' would be emitted if it is fused
// into 'fusion_'.
int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
// Update the evaluation cache after having fused 'producer' into 'fusion_'.
// 'producer' is the cloned instruction which is now part of the fusion
@ -84,9 +81,6 @@ class FusionNodeIndexingEvaluation {
// The fusion instruction.
const HloInstruction* fusion_;
// The total number of emitted instructions.
int64 total_emitted_instructions_;
};
} // namespace xla

View File

@ -29,7 +29,7 @@ using FusionNodeIndexingEvaluationTest = HloTestBase;
// Subclass of InstructionFusion exposing the protected methods Fuse and
// FuseInstruction for testing. Also adds the FusionNodeIndexingEvaluation to
// track the average code duplication due to indexing HloInstructions with
// track the code duplication due to indexing HloInstructions with
// different index values.
class InstructionFusionForTesting : public InstructionFusion {
public:
@ -61,8 +61,8 @@ class InstructionFusionForTesting : public InstructionFusion {
return InstructionFusion::Fuse(producer, consumer);
}
int64 EvaluateTotalEmittedInstructions(const HloInstruction* producer,
const HloInstruction* consumer) {
int64 EvaluateEmittedInstructions(const HloInstruction* producer,
const HloInstruction* consumer) {
if (consumer->opcode() != HloOpcode::kFusion) {
return 0;
}
@ -71,8 +71,8 @@ class InstructionFusionForTesting : public InstructionFusion {
fusion_node_evaluations_.emplace(consumer,
FusionNodeIndexingEvaluation(consumer));
}
return fusion_node_evaluations_.at(consumer)
.EvaluateTotalEmittedInstructions(producer);
return fusion_node_evaluations_.at(consumer).EvaluateEmittedInstructions(
producer);
}
private:
@ -109,8 +109,7 @@ TEST_F(FusionNodeIndexingEvaluationTest, FuseThreeInstructions) {
HloInstruction* slice1 = sub->mutable_operand(0);
HloInstruction* slice2 = sub->mutable_operand(1);
auto fusion = instruction_fusion.Fuse(slice1, sub);
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(slice2, fusion),
3);
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2, fusion), 1);
instruction_fusion.Fuse(slice2, fusion);
}
@ -151,37 +150,31 @@ TEST_F(FusionNodeIndexingEvaluationTest, ExponentialDuplicationPattern) {
HloInstruction* slice2_1 = add2->mutable_operand(1);
auto fusion = instruction_fusion.Fuse(slice2_0, add2);
// So far we have fused add2 and slice2.0. So when we also fuse slice2.1, we
// expect to emit 3 instructions.
EXPECT_EQ(
instruction_fusion.EvaluateTotalEmittedInstructions(slice2_1, fusion), 3);
// expect to emit it 1 time.
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice2_1, fusion),
1);
instruction_fusion.Fuse(slice2_1, fusion);
HloInstruction* add1 = fusion->mutable_operand(0);
EXPECT_EQ(add1->opcode(), HloOpcode::kAdd);
// If we fuse add1 into 'fusion', it needs to be emitted twice, adding 2 to
// the sum.
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add1, fusion),
5);
// If we fuse add1 into 'fusion', it needs to be emitted twice.
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add1, fusion), 2);
instruction_fusion.Fuse(add1, fusion);
HloInstruction* slice1_0 = fusion->mutable_operand(0);
EXPECT_EQ(slice1_0->opcode(), HloOpcode::kSlice);
// If we fuse slice1.0 into 'fusion', it needs to be emitted twice, adding 2
// to the sum.
EXPECT_EQ(
instruction_fusion.EvaluateTotalEmittedInstructions(slice1_0, fusion), 7);
// If we fuse slice1.0 into 'fusion', it needs to be emitted twice.
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_0, fusion),
2);
instruction_fusion.Fuse(slice1_0, fusion);
HloInstruction* slice1_1 = fusion->mutable_operand(0);
EXPECT_EQ(slice1_1->opcode(), HloOpcode::kSlice);
// If we fuse slice1.1 into 'fusion', it needs to be emitted twice, adding 2
// to the sum.
EXPECT_EQ(
instruction_fusion.EvaluateTotalEmittedInstructions(slice1_1, fusion), 9);
// If we fuse slice1.1 into 'fusion', it needs to be emitted twice.
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(slice1_1, fusion),
2);
instruction_fusion.Fuse(slice1_1, fusion);
HloInstruction* add0 = fusion->mutable_operand(0);
EXPECT_EQ(add0->opcode(), HloOpcode::kAdd);
// If we fuse add0 into 'fusion', it needs to be emitted twice, adding 4 to
// the sum.
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion),
13);
// If we fuse add0 into 'fusion', it needs to be emitted four times.
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4);
instruction_fusion.Fuse(add0, fusion);
}
@ -212,10 +205,9 @@ ENTRY entry_computation {
HloInstruction* add0 = fusion->mutable_operand(0);
EXPECT_EQ(add0->opcode(), HloOpcode::kAdd);
// Here, the cache for the fusion node needs to be recomputed. Make sure we
// still get the same evaluation as before when we incrementally built the
// still get the same evaluation as before when we incrementally build the
// cache.
EXPECT_EQ(instruction_fusion.EvaluateTotalEmittedInstructions(add0, fusion),
13);
EXPECT_EQ(instruction_fusion.EvaluateEmittedInstructions(add0, fusion), 4);
}
} // namespace xla

View File

@ -112,8 +112,7 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
fusion_node_evaluations_.emplace(consumer,
FusionNodeIndexingEvaluation(consumer));
}
if (fusion_node_evaluations_.at(consumer).AverageCodeDuplicationTooHigh(
producer)) {
if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
VLOG(5) << "Fusion of " << producer->name() << " into " << consumer->name()
<< " would result in overly large code duplication.";
return false;

View File

@ -289,15 +289,9 @@ bool FusedIrEmitter::IsFusedIrEmitterInefficient(
evaluate_fusion_computation(producer);
}
// Sum up the total number of emitted ops.
int64 total = 0;
for (const auto& entry : index_usage_count) {
total += entry.second;
}
// Check that the code duplication is at most a factor of 15 (where 15 is an
// arbitrary constant that seems to work).
return total > 15 * index_usage_count.size();
return index_usage_count[producer] > 15;
}
} // namespace xla