[XLA] Simplify (A % B) % B => A % B.
PiperOrigin-RevId: 257212321
This commit is contained in:
parent
9c7ddffd97
commit
6df6fd298c
@ -2660,6 +2660,11 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
|||||||
HloInstruction *a, *b;
|
HloInstruction *a, *b;
|
||||||
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
|
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
|
||||||
|
|
||||||
|
// (A % B) % B == A % B.
|
||||||
|
if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
|
||||||
|
return ReplaceInstruction(remainder, a);
|
||||||
|
}
|
||||||
|
|
||||||
// A % B => A & (B - 1) if B is a power of 2.
|
// A % B => A & (B - 1) if B is a power of 2.
|
||||||
switch (remainder->shape().element_type()) {
|
switch (remainder->shape().element_type()) {
|
||||||
case S8:
|
case S8:
|
||||||
|
@ -5503,5 +5503,20 @@ TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
|
|||||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p = s32[1000] parameter(0)
|
||||||
|
q = s32[1000] parameter(1)
|
||||||
|
r = s32[1000] remainder(p, q)
|
||||||
|
ROOT rr = s32[1000] remainder(r, q)
|
||||||
|
})";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||||
|
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||||
|
GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user