[XLA] Simplify unsigned comparisons against 0
PiperOrigin-RevId: 357326640 Change-Id: Ic5d018758772ccb6cba6387e962d28d1217db7a7
This commit is contained in:
parent
5bfc37ef25
commit
6d597cd0ae
tensorflow/compiler/xla/service
@ -3195,6 +3195,30 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) {
|
||||
}
|
||||
}
|
||||
|
||||
if (Cast<HloCompareInstruction>(compare)->type() ==
|
||||
Comparison::Type::kUnsigned) {
|
||||
// X u< 0 -> false
|
||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||
IsAll(rhs, 0)) {
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||
}
|
||||
// X u>= 0 -> true
|
||||
if (compare->comparison_direction() == ComparisonDirection::kGe &&
|
||||
IsAll(rhs, 0)) {
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||
}
|
||||
// 0 u> X -> false
|
||||
if (compare->comparison_direction() == ComparisonDirection::kGt &&
|
||||
IsAll(lhs, 0)) {
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||
}
|
||||
// 0 u<= X -> true
|
||||
if (compare->comparison_direction() == ComparisonDirection::kLe &&
|
||||
IsAll(lhs, 0)) {
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, true));
|
||||
}
|
||||
}
|
||||
|
||||
if (compare->comparison_direction() == ComparisonDirection::kLt &&
|
||||
lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) {
|
||||
return ReplaceInstruction(compare, MakeScalarLike(compare, false));
|
||||
|
@ -6259,6 +6259,121 @@ TEST_F(AlgebraicSimplifierTest, CompareIota) {
|
||||
GmockMatch(m::Broadcast(m::ConstantScalar(false))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareLtZero) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(param, zero), direction=LT
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::ConstantScalar(false)));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareLeZero) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(param, zero), direction=LE
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Le(m::Parameter(0), m::ConstantEffectiveScalar(0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareGeZero) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(param, zero), direction=GE
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::ConstantScalar(true)));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareGtZero) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(param, zero), direction=GT
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Gt(m::Parameter(0), m::ConstantEffectiveScalar(0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareZeroGt) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(zero, param), direction=GT
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::ConstantScalar(false)));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareZeroGe) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(zero, param), direction=GE
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Ge(m::ConstantEffectiveScalar(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareZeroLe) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(zero, param), direction=LE
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::ConstantScalar(true)));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareZeroLt) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
zero = u32[] constant(0)
|
||||
param = u32[] parameter(0)
|
||||
ROOT compare = pred[] compare(zero, param), direction=LT
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(
|
||||
m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Lt(m::ConstantEffectiveScalar(0), m::Parameter(0))));
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, CompareSame) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
|
Loading…
Reference in New Issue
Block a user