[XLA] Simplify (A % B) % B => A % B.

PiperOrigin-RevId: 257212321
This commit is contained in:
Justin Lebar 2019-07-09 09:44:18 -07:00 committed by TensorFlower Gardener
parent 9c7ddffd97
commit 6df6fd298c
2 changed files with 20 additions and 0 deletions

View File

@ -2660,6 +2660,11 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) {
HloInstruction *a, *b; HloInstruction *a, *b;
CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b))));
// (A % B) % B == A % B.
if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) {
return ReplaceInstruction(remainder, a);
}
// A % B => A & (B - 1) if B is a power of 2. // A % B => A & (B - 1) if B is a power of 2.
switch (remainder->shape().element_type()) { switch (remainder->shape().element_type()) {
case S8: case S8:

View File

@ -5503,5 +5503,20 @@ TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) {
ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
} }
TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) {
const char* kModuleStr = R"(
HloModule m
test {
p = s32[1000] parameter(0)
q = s32[1000] parameter(1)
r = s32[1000] remainder(p, q)
ROOT rr = s32[1000] remainder(r, q)
})";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
EXPECT_THAT(m->entry_computation()->root_instruction(),
GmockMatch(m::Remainder(m::Parameter(), m::Parameter())));
}
} // namespace } // namespace
} // namespace xla } // namespace xla