[XLA] Simplify tautological compares (and (< x A) (< x B)) to (< x A) when a <= B holds.

This is required for figuring out the trip count of loops whose condition
contains the conjunction.  Such conjunctions arise from TF when a for loop with
`tf.range` is lowered, or when using `tf.while_loop` with `maximum_iterations`
set.

PiperOrigin-RevId: 312138518
Change-Id: I12c5c7d0aeedbf0d375f3cff1d23b39aea89f64a
This commit is contained in:
George Karpenkov 2020-05-18 13:03:24 -07:00 committed by TensorFlower Gardener
parent 7023ce338b
commit 8e661af54d
2 changed files with 84 additions and 0 deletions

View File

@ -508,6 +508,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
// Tries to convert slice(reshape(X)) into reshape(slice(X)) // Tries to convert slice(reshape(X)) into reshape(slice(X))
StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice); StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);
// Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
// `(< a N)`. This is crucial for being able to figure out the loop trip
// count.
//
// Assumes that the input is conjunction.
StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);
// Useful when we want to use the same visitor over multiple computations. // Useful when we want to use the same visitor over multiple computations.
void ResetState(HloComputation* computation); void ResetState(HloComputation* computation);
@ -856,6 +863,57 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
return Status::OK(); return Status::OK();
} }
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
HloInstruction* conjunction) {
HloInstruction *lhs, *rhs;
if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
return false;
}
struct LessThanCompareInfo { // (LT var constant)
HloInstruction* var;
int64 constant;
};
auto get_compare_info_helper =
[&](HloInstruction* lhs,
HloInstruction* rhs) -> absl::optional<LessThanCompareInfo> {
if (!Match(rhs, m::Constant().WithShape(
m::Shape().IsEffectiveScalar().WithElementType(
PrimitiveType::S32)))) {
return absl::nullopt;
}
return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
};
auto get_compare_info =
[&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
HloInstruction *lhs, *rhs;
if (!Match(cmp, m::Compare(m::Op(&lhs), m::Op(&rhs))
.WithComparisonDirection(ComparisonDirection::kLt))) {
return absl::nullopt;
}
if (auto match1 = get_compare_info_helper(lhs, rhs)) {
return match1;
} else if (auto match2 = get_compare_info_helper(rhs, lhs)) {
return match2;
}
return absl::nullopt;
};
absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
int64 new_bound = std::min(lhs_info->constant, rhs_info->constant);
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
conjunction,
HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
MakeScalarLike(lhs_info->var, new_bound),
ComparisonDirection::kLt)));
return true;
}
return false;
}
Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) { Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
HloInstruction *lhs, *rhs; HloInstruction *lhs, *rhs;
CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs)))); CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
@ -890,6 +948,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
return Status::OK(); return Status::OK();
} }
// Simplify tautological conjunctions.
TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
TrySimplifyTautologicalCompare(logical_and));
if (found_tautological_compare) {
return Status::OK();
}
return Status::OK(); return Status::OK();
} }

View File

@ -5761,6 +5761,25 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
GmockMatch(m::Broadcast(m::ConstantScalar(true)))); GmockMatch(m::Broadcast(m::ConstantScalar(true))));
} }
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
const char* kModuleStr = R"(
HloModule m
test {
param = s32[] parameter(0)
c1 = s32[] constant(10)
c2 = s32[] constant(100)
cmp1 = pred[] compare(param, c1), direction=LT
cmp2 = pred[] compare(param, c2), direction=LT
ROOT out = pred[] and(cmp1, cmp2)
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
.WithComparisonDirection(ComparisonDirection::kLt)));
}
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) { TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
// Some backends may have better performance by treating an outer product as a // Some backends may have better performance by treating an outer product as a
// Dot, rather than a broadcast Multiply // Dot, rather than a broadcast Multiply