From 8e661af54d9787b2a3a2371cc6efcfa1d8db6a34 Mon Sep 17 00:00:00 2001 From: George Karpenkov Date: Mon, 18 May 2020 13:03:24 -0700 Subject: [PATCH] [XLA] Simplify tautological compares (and (< x A) (< x B)) to (< x A) when `a <= B` holds. This is required for figuring out the trip count of loops whose condition contains the conjunction. Such conjunctions arise from TF when a for loop with `tf.range` is lowered, or when using `tf.while_loop` with `maximum_iterations` set. PiperOrigin-RevId: 312138518 Change-Id: I12c5c7d0aeedbf0d375f3cff1d23b39aea89f64a --- .../xla/service/algebraic_simplifier.cc | 65 +++++++++++++++++++ .../xla/service/algebraic_simplifier_test.cc | 19 ++++++ 2 files changed, 84 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 55af8726dc8..ecbf2075abe 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -508,6 +508,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Tries to convert slice(reshape(X)) into reshape(slice(X)) StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into + // `(< a N)`. This is crucial for being able to figure out the loop trip + // count. + // + // Assumes that the input is conjunction. + StatusOr TrySimplifyTautologicalCompare(HloInstruction* conjunction); + // Useful when we want to use the same visitor over multiple computations. void ResetState(HloComputation* computation); @@ -856,6 +863,57 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } +StatusOr AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare( + HloInstruction* conjunction) { + HloInstruction *lhs, *rhs; + if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) { + return false; + } + struct LessThanCompareInfo { // (LT var constant) + HloInstruction* var; + int64 constant; + }; + + auto get_compare_info_helper = + [&](HloInstruction* lhs, + HloInstruction* rhs) -> absl::optional { + if (!Match(rhs, m::Constant().WithShape( + m::Shape().IsEffectiveScalar().WithElementType( + PrimitiveType::S32)))) { + return absl::nullopt; + } + return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}}; + }; + + auto get_compare_info = + [&](HloInstruction* cmp) -> absl::optional { + HloInstruction *lhs, *rhs; + if (!Match(cmp, m::Compare(m::Op(&lhs), m::Op(&rhs)) + .WithComparisonDirection(ComparisonDirection::kLt))) { + return absl::nullopt; + } + if (auto match1 = get_compare_info_helper(lhs, rhs)) { + return match1; + } else if (auto match2 = get_compare_info_helper(rhs, lhs)) { + return match2; + } + return absl::nullopt; + }; + + absl::optional lhs_info = get_compare_info(lhs); + absl::optional rhs_info = get_compare_info(rhs); + if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) { + int64 new_bound = std::min(lhs_info->constant, rhs_info->constant); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + conjunction, + HloInstruction::CreateCompare(lhs->shape(), lhs_info->var, + MakeScalarLike(lhs_info->var, new_bound), + ComparisonDirection::kLt))); + return true; + } + return false; +} + Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { HloInstruction *lhs, *rhs; CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); @@ -890,6 +948,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { return Status::OK(); } + // Simplify tautological conjunctions. + TF_ASSIGN_OR_RETURN(bool found_tautological_compare, + TrySimplifyTautologicalCompare(logical_and)); + if (found_tautological_compare) { + return Status::OK(); + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 6c8e80aa963..08a004e39fe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5761,6 +5761,25 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) { GmockMatch(m::Broadcast(m::ConstantScalar(true)))); } +TEST_F(AlgebraicSimplifierTest, CompareSimplified) { + const char* kModuleStr = R"( + HloModule m + test { + param = s32[] parameter(0) + c1 = s32[] constant(10) + c2 = s32[] constant(100) + cmp1 = pred[] compare(param, c1), direction=LT + cmp2 = pred[] compare(param, c2), direction=LT + ROOT out = pred[] and(cmp1, cmp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10)) + .WithComparisonDirection(ComparisonDirection::kLt))); +} + TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { // Some backends may have better performance by treating an outer product as a // Dot, rather than a broadcast Multiply