diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 55af8726dc8..ecbf2075abe 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -508,6 +508,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into + // `(< a N)`. This is crucial for being able to figure out the loop trip + // count. + // + // Assumes that the input is conjunction. + StatusOr TrySimplifyTautologicalCompare(HloInstruction* conjunction); + // Useful when we want to use the same visitor over multiple computations. void ResetState(HloComputation* computation); @@ -856,6 +863,57 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( + HloInstruction* conjunction) { + HloInstruction *lhs, *rhs; + if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { + return false; + } + struct LessThanCompareInfo { // (LT var constant) + HloInstruction* var; + int64 constant; + }; + + auto get_compare_info_helper = + [&](HloInstruction* lhs, + HloInstruction* rhs) -> absl::optional { + if (!Match(rhs, m::Constant().WithShape( + m::Shape().IsEffectiveScalar().WithElementType( + PrimitiveType::S32)))) { + return absl::nullopt; + } + return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}}; + }; + + auto get_compare_info = + [&](HloInstruction* cmp) -> absl::optional { + HloInstruction *lhs, *rhs; + if (!Match(cmp, m::Compare(m::Op(&lhs), m::Op(&rhs)) + .WithComparisonDirection(ComparisonDirection::kLt))) { + return absl::nullopt; + } + if (auto match1 = get_compare_info_helper(lhs, rhs)) { + return match1; + } else if (auto match2 = get_compare_info_helper(rhs, lhs)) { + return match2; + } + return absl::nullopt; + }; + + absl::optional lhs_info = get_compare_info(lhs); + absl::optional rhs_info = get_compare_info(rhs); + if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { + int64 new_bound = std::min(lhs_info->constant, rhs_info->constant); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + conjunction, + HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, + MakeScalarLike(lhs_info->var, new_bound), + ComparisonDirection::kLt))); + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -890,6 +948,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { return Status::OK(); } + // Simplify tautological conjunctions. + TF_ASSIGN_OR_RETURN(bool found_tautological_compare, + TrySimplifyTautologicalCompare(logical_and)); + if (found_tautological_compare) { + return Status::OK(); + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 6c8e80aa963..08a004e39fe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5761,6 +5761,25 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) { GmockMatch(m::Broadcast(m::ConstantScalar(true)))); } +TEST_F(AlgebraicSimplifierTest, CompareSimplified) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(param, c2), direction=LT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply