[XLA] Simplify unsigned comparisons against 0

PiperOrigin-RevId: 357326640
Change-Id: Ic5d018758772ccb6cba6387e962d28d1217db7a7
This commit is contained in:
David Majnemer 2021-02-12 23:04:14 -08:00 committed by TensorFlower Gardener
parent 5bfc37ef25
commit 6d597cd0ae
2 changed files with 139 additions and 0 deletions

View File

@ -3195,6 +3195,30 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
}
}
if (Cast<HloCompareInstruction>(compare)->type() ==
Comparison::Type::kUnsigned) {
// X u< 0 -> false
if (compare->comparison_direction() == ComparisonDirection::kLt &&
IsAll(rhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
}
// X u>= 0 -> true
if (compare->comparison_direction() == ComparisonDirection::kGe &&
IsAll(rhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
}
// 0 u> X -> false
if (compare->comparison_direction() == ComparisonDirection::kGt &&
IsAll(lhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
}
// 0 u<= X -> true
if (compare->comparison_direction() == ComparisonDirection::kLe &&
IsAll(lhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
}
}
if (compare->comparison_direction() == ComparisonDirection::kLt &&
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
return ReplaceInstruction(compare, MakeScalarLike(compare, false));

View File

@ -6259,6 +6259,121 @@ TEST_F(AlgebraicSimplifierTest, CompareIota) {
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
}
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=LT
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::ConstantScalar(false)));
}
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=LE
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Le(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=GE
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::ConstantScalar(true)));
}
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(param, zero), direction=GT
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
}
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=GT
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::ConstantScalar(false)));
}
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=GE
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Ge(m::ConstantEffectiveScalar(0), m::Parameter(0))));
}
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=LE
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::ConstantScalar(true)));
}
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
const char* kModuleStr = R"(
HloModule m
test {
zero = u32[] constant(0)
param = u32[] parameter(0)
ROOT compare = pred[] compare(zero, param), direction=LT
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(
m->entry_computation()->root_instruction(),
GmockMatch(m::Lt(m::ConstantEffectiveScalar(0), m::Parameter(0))));
}
TEST_F(AlgebraicSimplifierTest, CompareSame) {
const char* kModuleStr = R"(
HloModule m