[XLA] Simplify tautological compares (and (< x A) (< x B)) to (< x A) when a <= B
holds.
This is required for figuring out the trip count of loops whose condition contains the conjunction. Such conjunctions arise from TF when a for loop with `tf.range` is lowered, or when using `tf.while_loop` with `maximum_iterations` set. PiperOrigin-RevId: 312138518 Change-Id: I12c5c7d0aeedbf0d375f3cff1d23b39aea89f64a
This commit is contained in:
parent
7023ce338b
commit
8e661af54d
@ -508,6 +508,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
||||
// Tries to convert slice(reshape(X)) into reshape(slice(X))
|
||||
StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);
|
||||
|
||||
// Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
|
||||
// `(< a N)`. This is crucial for being able to figure out the loop trip
|
||||
// count.
|
||||
//
|
||||
// Assumes that the input is conjunction.
|
||||
StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);
|
||||
|
||||
// Useful when we want to use the same visitor over multiple computations.
|
||||
void ResetState(HloComputation* computation);
|
||||
|
||||
@ -856,6 +863,57 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
|
||||
HloInstruction* conjunction) {
|
||||
HloInstruction *lhs, *rhs;
|
||||
if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
|
||||
return false;
|
||||
}
|
||||
struct LessThanCompareInfo { // (LT var constant)
|
||||
HloInstruction* var;
|
||||
int64 constant;
|
||||
};
|
||||
|
||||
auto get_compare_info_helper =
|
||||
[&](HloInstruction* lhs,
|
||||
HloInstruction* rhs) -> absl::optional<LessThanCompareInfo> {
|
||||
if (!Match(rhs, m::Constant().WithShape(
|
||||
m::Shape().IsEffectiveScalar().WithElementType(
|
||||
PrimitiveType::S32)))) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
|
||||
};
|
||||
|
||||
auto get_compare_info =
|
||||
[&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
|
||||
HloInstruction *lhs, *rhs;
|
||||
if (!Match(cmp, m::Compare(m::Op(&lhs), m::Op(&rhs))
|
||||
.WithComparisonDirection(ComparisonDirection::kLt))) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (auto match1 = get_compare_info_helper(lhs, rhs)) {
|
||||
return match1;
|
||||
} else if (auto match2 = get_compare_info_helper(rhs, lhs)) {
|
||||
return match2;
|
||||
}
|
||||
return absl::nullopt;
|
||||
};
|
||||
|
||||
absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
|
||||
absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
|
||||
if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
|
||||
int64 new_bound = std::min(lhs_info->constant, rhs_info->constant);
|
||||
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
|
||||
conjunction,
|
||||
HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
|
||||
MakeScalarLike(lhs_info->var, new_bound),
|
||||
ComparisonDirection::kLt)));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
|
||||
HloInstruction *lhs, *rhs;
|
||||
CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
|
||||
@ -890,6 +948,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Simplify tautological conjunctions.
|
||||
TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
|
||||
TrySimplifyTautologicalCompare(logical_and));
|
||||
if (found_tautological_compare) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -5761,6 +5761,25 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
||||
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
param = s32[] parameter(0)
|
||||
c1 = s32[] constant(10)
|
||||
c2 = s32[] constant(100)
|
||||
cmp1 = pred[] compare(param, c1), direction=LT
|
||||
cmp2 = pred[] compare(param, c2), direction=LT
|
||||
ROOT out = pred[] and(cmp1, cmp2)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
|
||||
.WithComparisonDirection(ComparisonDirection::kLt)));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
|
||||
// Some backends may have better performance by treating an outer product as a
|
||||
// Dot, rather than a broadcast Multiply
|
||||
|
Loading…
Reference in New Issue
Block a user