[XLA] Simplify (A % B) % B => A % B.
PiperOrigin-RevId: 257212321
This commit is contained in:
parent
9c7ddffd97
commit
6df6fd298c
@ -2660,6 +2660,11 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
|
||||
HloInstruction *a, *b;
|
||||
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
|
||||
|
||||
// (A % B) % B == A % B.
|
||||
if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
|
||||
return ReplaceInstruction(remainder, a);
|
||||
}
|
||||
|
||||
// A % B => A & (B - 1) if B is a power of 2.
|
||||
switch (remainder->shape().element_type()) {
|
||||
case S8:
|
||||
|
@ -5503,5 +5503,20 @@ TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
|
||||
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
|
||||
const char* kModuleStr = R"(
|
||||
HloModule m
|
||||
test {
|
||||
p = s32[1000] parameter(0)
|
||||
q = s32[1000] parameter(1)
|
||||
r = s32[1000] remainder(p, q)
|
||||
ROOT rr = s32[1000] remainder(r, q)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
|
||||
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
|
||||
EXPECT_THAT(m->entry_computation()->root_instruction(),
|
||||
GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user