diff --git a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc index 25b9658ba98..17d3fb2b3d6 100644 --- a/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc +++ b/tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.cc @@ -38,60 +38,14 @@ FusionNodeIndexingEvaluation::FusionNodeIndexingEvaluation( // a tradeoff between compilation time and runtime here. const int64 FusionNodeIndexingEvaluation::kAllowedCodeDuplication = 15; -namespace { - -// Returns which ops invalidate the cache of emitted instructions by creating a -// new BasicBlock and setting the insertion point to the newly created -// BasicBlock. We can only reuse cached values if they were emitted in the same -// BasicBlock as the current BasicBlock. -bool OpInvalidatesCache(const HloInstruction* hlo) { - switch (hlo->opcode()) { - // This list of ops was created by inspecting the code. There is no - // guarantee that it is complete. - case HloOpcode::kConcatenate: - case HloOpcode::kDot: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kPad: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - return true; - default: - return false; - } -} - -// Counts the number of "real" users of 'hlo'. When 'hlo' has a fusion node as -// user, we consider the users of the fusion parameter corresponding to 'hlo' as -// the real users. -int64 UserCount(const HloInstruction* hlo) { - int64 cnt = 0; - for (HloInstruction* user : hlo->users()) { - if (user->opcode() == HloOpcode::kFusion) { - // Count the number of users of the parameter corresponding to the fusion - // operand. - int64 operand_index = user->operand_index(hlo); - cnt += user->fused_parameter(operand_index)->user_count(); - } else { - ++cnt; - } - } - return cnt; -} -} // namespace - bool FusionNodeIndexingEvaluation::CodeDuplicationTooHigh( const HloInstruction* producer) const { - int64 emitted_instructions = EvaluateEmittedInstructions(producer); - return emitted_instructions > kAllowedCodeDuplication || - (OpInvalidatesCache(producer) && - (emitted_instructions > 1 || UserCount(producer) > 1)); + return EvaluateEmittedInstructions(producer) > kAllowedCodeDuplication; } bool FusionNodeIndexingEvaluation::MaxCodeDuplicationTooHigh() const { for (const auto& entry : index_usage_count_) { - if (entry.second > kAllowedCodeDuplication || - (OpInvalidatesCache(entry.first) && - (entry.second > 1 || UserCount(entry.first) > 1))) { + if (entry.second > kAllowedCodeDuplication) { return true; } } diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc index fe27a8c6963..9df83e30ad4 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -521,7 +521,8 @@ XLA_TEST_F(ConcatTest, ConcatDeeplyNested) { ComputeAndCompareR1(&builder, expected, {a_data.get()}); } -XLA_TEST_F(ConcatTestHlo, ConcatWithBitcast) { +// TODO(b/169314478): Enable the test when the slow compilation is fixed. +XLA_TEST_F(ConcatTestHlo, DISABLED_ConcatWithBitcast) { auto module = ParseAndReturnVerifiedModule(R"( HloModule jit_broken.874 @@ -761,7 +762,7 @@ ENTRY jit_broken.874 { auto input_array = absl::make_unique>(4, 2); input_array->FillUnique(1.0f); auto input = LiteralUtil::CreateR2FromArray2D(*input_array); - EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, error_spec_)); + EXPECT_TRUE(RunAndCompare(std::move(module), {&input}, absl::nullopt)); } // Describes a binary rank-2 concatenation test.