Prevent creating fusions which could cause problems.
In fusion nodes we need to make sure that we can cache the emitted ops. However there are ops which create new BasicBlocks and set the insertion point to the newly created basic block. This invalidates all cache entries, because we can only reuse them if we are still inside the same basic block. This becomes an issue if such an op is emitted more than once. PiperOrigin-RevId: 339679716 Change-Id: I0087235dd0aef4f15b0ace51f6f650a5c6a027ed
This commit is contained in:
parent
79cdd9533a
commit
5fb34be1be
@ -38,14 +38,60 @@ FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation(
|
||||
// a tradeoff between compilation time and runtime here.
|
||||
const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15;
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns which ops invalidate the cache of emitted instructions by creating a
|
||||
// new BasicBlock and setting the insertion point to the newly created
|
||||
// BasicBlock. We can only reuse cached values if they were emitted in the same
|
||||
// BasicBlock as the current BasicBlock.
|
||||
bool OpInvalidatesCache(const HloInstruction* hlo) {
|
||||
switch (hlo->opcode()) {
|
||||
// This list of ops was created by inspecting the code. There is no
|
||||
// guarantee that it is complete.
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kDot:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReduceWindow:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as
|
||||
// user, we consider the users of the fusion parameter corresponding to 'hlo' as
|
||||
// the real users.
|
||||
int64 UserCount(const HloInstruction* hlo) {
|
||||
int64 cnt = 0;
|
||||
for (HloInstruction* user : hlo->users()) {
|
||||
if (user->opcode() == HloOpcode::kFusion) {
|
||||
// Count the number of users of the parameter corresponding to the fusion
|
||||
// operand.
|
||||
int64 operand_index = user->operand_index(hlo);
|
||||
cnt += user->fused_parameter(operand_index)->user_count();
|
||||
} else {
|
||||
++cnt;
|
||||
}
|
||||
}
|
||||
return cnt;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh(
|
||||
const HloInstruction* producer) const {
|
||||
return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication;
|
||||
int64 emitted_instructions = EvaluateEmittedInstructions(producer);
|
||||
return emitted_instructions > kAllowedCodeDuplication ||
|
||||
(OpInvalidatesCache(producer) &&
|
||||
(emitted_instructions > 1 || UserCount(producer) > 1));
|
||||
}
|
||||
|
||||
bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const {
|
||||
for (const auto& entry : index_usage_count_) {
|
||||
if (entry.second > kAllowedCodeDuplication) {
|
||||
if (entry.second > kAllowedCodeDuplication ||
|
||||
(OpInvalidatesCache(entry.first) &&
|
||||
(entry.second > 1 || UserCount(entry.first) > 1))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -521,8 +521,7 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) {
|
||||
ComputeAndCompareR1<float>(&builder, expected, {a_data.get()});
|
||||
}
|
||||
|
||||
// TODO(b/169314478): Enable the test when the slow compilation is fixed.
|
||||
XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) {
|
||||
XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) {
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule jit_broken.874
|
||||
|
||||
@ -762,7 +761,7 @@ ENTRY jit_broken.874 {
|
||||
auto input_array = absl::make_unique<Array2D<float>>(4, 2);
|
||||
input_array->FillUnique(1.0f);
|
||||
auto input = LiteralUtil::CreateR2FromArray2D<float>(*input_array);
|
||||
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt));
|
||||
EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_));
|
||||
}
|
||||
|
||||
// Describes a binary rank-2 concatenation test.
|
||||
|
Loading…
x
Reference in New Issue
Block a user