[XLA] Simplify tautological compares (and (< x A) (< x B)) to (< x A) when a <= B
holds.
This is required for figuring out the trip count of loops whose condition contains the conjunction. Such conjunctions arise from TF when a for loop with `tf.range` is lowered, or when using `tf.while_loop` with `maximum_iterations` set. PiperOrigin-RevId: 312138518 Change-Id: I12c5c7d0aeedbf0d375f3cff1d23b39aea89f64a
This commit is contained in:
parent
7023ce338b
commit
8e661af54d
@ -508,6 +508,13 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
|
|||||||
// Tries to convert slice(reshape(X)) into reshape(slice(X))
|
// Tries to convert slice(reshape(X)) into reshape(slice(X))
|
||||||
StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);
|
StatusOr<bool> TryToReorderSliceAndReshape(HloInstruction* slice);
|
||||||
|
|
||||||
|
// Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into
|
||||||
|
// `(< a N)`. This is crucial for being able to figure out the loop trip
|
||||||
|
// count.
|
||||||
|
//
|
||||||
|
// Assumes that the input is conjunction.
|
||||||
|
StatusOr<bool> TrySimplifyTautologicalCompare(HloInstruction* conjunction);
|
||||||
|
|
||||||
// Useful when we want to use the same visitor over multiple computations.
|
// Useful when we want to use the same visitor over multiple computations.
|
||||||
void ResetState(HloComputation* computation);
|
void ResetState(HloComputation* computation);
|
||||||
|
|
||||||
@ -856,6 +863,57 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
StatusOr<bool> AlgebraicSimplifierVisitor::TrySimplifyTautologicalCompare(
|
||||||
|
HloInstruction* conjunction) {
|
||||||
|
HloInstruction *lhs, *rhs;
|
||||||
|
if (!Match(conjunction, m::And(m::Op(&lhs), m::Op(&rhs)))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
struct LessThanCompareInfo { // (LT var constant)
|
||||||
|
HloInstruction* var;
|
||||||
|
int64 constant;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto get_compare_info_helper =
|
||||||
|
[&](HloInstruction* lhs,
|
||||||
|
HloInstruction* rhs) -> absl::optional<LessThanCompareInfo> {
|
||||||
|
if (!Match(rhs, m::Constant().WithShape(
|
||||||
|
m::Shape().IsEffectiveScalar().WithElementType(
|
||||||
|
PrimitiveType::S32)))) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
return {LessThanCompareInfo{lhs, *rhs->literal().GetFirstInteger()}};
|
||||||
|
};
|
||||||
|
|
||||||
|
auto get_compare_info =
|
||||||
|
[&](HloInstruction* cmp) -> absl::optional<LessThanCompareInfo> {
|
||||||
|
HloInstruction *lhs, *rhs;
|
||||||
|
if (!Match(cmp, m::Compare(m::Op(&lhs), m::Op(&rhs))
|
||||||
|
.WithComparisonDirection(ComparisonDirection::kLt))) {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
if (auto match1 = get_compare_info_helper(lhs, rhs)) {
|
||||||
|
return match1;
|
||||||
|
} else if (auto match2 = get_compare_info_helper(rhs, lhs)) {
|
||||||
|
return match2;
|
||||||
|
}
|
||||||
|
return absl::nullopt;
|
||||||
|
};
|
||||||
|
|
||||||
|
absl::optional<LessThanCompareInfo> lhs_info = get_compare_info(lhs);
|
||||||
|
absl::optional<LessThanCompareInfo> rhs_info = get_compare_info(rhs);
|
||||||
|
if (lhs_info && rhs_info && lhs_info->var == rhs_info->var) {
|
||||||
|
int64 new_bound = std::min(lhs_info->constant, rhs_info->constant);
|
||||||
|
TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(
|
||||||
|
conjunction,
|
||||||
|
HloInstruction::CreateCompare(lhs->shape(), lhs_info->var,
|
||||||
|
MakeScalarLike(lhs_info->var, new_bound),
|
||||||
|
ComparisonDirection::kLt)));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
|
Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
|
||||||
HloInstruction *lhs, *rhs;
|
HloInstruction *lhs, *rhs;
|
||||||
CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
|
CHECK(Match(logical_and, m::And(m::Op(&lhs), m::Op(&rhs))));
|
||||||
@ -890,6 +948,13 @@ Status AlgebraicSimplifierVisitor::HandleAnd(HloInstruction* logical_and) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Simplify tautological conjunctions.
|
||||||
|
TF_ASSIGN_OR_RETURN(bool found_tautological_compare,
|
||||||
|
TrySimplifyTautologicalCompare(logical_and));
|
||||||
|
if (found_tautological_compare) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5761,6 +5761,25 @@ TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
|||||||
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
|
GmockMatch(m::Broadcast(m::ConstantScalar(true))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareSimplified) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
param = s32[] parameter(0)
|
||||||
|
c1 = s32[] constant(10)
|
||||||
|
c2 = s32[] constant(100)
|
||||||
|
cmp1 = pred[] compare(param, c1), direction=LT
|
||||||
|
cmp2 = pred[] compare(param, c2), direction=LT
|
||||||
|
ROOT out = pred[] and(cmp1, cmp2)
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Compare(m::Op(), m::Op().IsConstantScalar(10))
|
||||||
|
.WithComparisonDirection(ComparisonDirection::kLt)));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
|
TEST_F(AlgebraicSimplifierTest, CanDisableDotToMultiplyRewrite) {
|
||||||
// Some backends may have better performance by treating an outer product as a
|
// Some backends may have better performance by treating an outer product as a
|
||||||
// Dot, rather than a broadcast Multiply
|
// Dot, rather than a broadcast Multiply
|
||||||
|
Loading…
Reference in New Issue
Block a user