[XLA] Simplify unsigned comparisons against 0
PiperOrigin-RevId: 357326640 Change-Id: Ic5d018758772ccb6cba6387e962d28d1217db7a7
This commit is contained in:
parent
5bfc37ef25
commit
6d597cd0ae
@ -3195,6 +3195,30 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (Cast<HloCompareInstruction>(compare)->type() ==
|
||||||
|
Comparison::Type::kUnsigned) {
|
||||||
|
// X u< 0 -> false
|
||||||
|
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||||
|
IsAll(rhs, 0)) {
|
||||||
|
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||||
|
}
|
||||||
|
// X u>= 0 -> true
|
||||||
|
if (compare->comparison_direction() == ComparisonDirection::kGe &&
|
||||||
|
IsAll(rhs, 0)) {
|
||||||
|
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||||
|
}
|
||||||
|
// 0 u> X -> false
|
||||||
|
if (compare->comparison_direction() == ComparisonDirection::kGt &&
|
||||||
|
IsAll(lhs, 0)) {
|
||||||
|
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||||
|
}
|
||||||
|
// 0 u<= X -> true
|
||||||
|
if (compare->comparison_direction() == ComparisonDirection::kLe &&
|
||||||
|
IsAll(lhs, 0)) {
|
||||||
|
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||||
|
@ -6259,6 +6259,121 @@ TEST_F(AlgebraicSimplifierTest, CompareIota) {
|
|||||||
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
|
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(param, zero), direction=LT
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::ConstantScalar(false)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(param, zero), direction=LE
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Le(m::Parameter(0), m::ConstantEffectiveScalar(0))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(param, zero), direction=GE
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::ConstantScalar(true)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(param, zero), direction=GT
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(zero, param), direction=GT
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::ConstantScalar(false)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(zero, param), direction=GE
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Ge(m::ConstantEffectiveScalar(0), m::Parameter(0))));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(zero, param), direction=LE
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::ConstantScalar(true)));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
zero = u32[] constant(0)
|
||||||
|
param = u32[] parameter(0)
|
||||||
|
ROOT compare = pred[] compare(zero, param), direction=LT
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(
|
||||||
|
m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Lt(m::ConstantEffectiveScalar(0), m::Parameter(0))));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
||||||
const char* kModuleStr = R"(
|
const char* kModuleStr = R"(
|
||||||
HloModule m
|
HloModule m
|
||||||
|
Loading…
x
Reference in New Issue
Block a user