diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 0290d0cf26b..10030ce6491 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -2660,6 +2660,11 @@ Status AlgebraicSimplifierVisitor::HandleRemainder(HloInstruction* remainder) { HloInstruction *a, *b; CHECK(Match(remainder, m::Remainder(m::Op(&a), m::Op(&b)))); + // (A % B) % B == A % B. + if (Match(a, m::Remainder(m::Op(), m::Op().Is(b)))) { + return ReplaceInstruction(remainder, a); + } + // A % B => A & (B - 1) if B is a power of 2. switch (remainder->shape().element_type()) { case S8: diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 31b0e16c2a0..05d57cf3ba2 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5503,5 +5503,20 @@ TEST_F(AlgebraicSimplifierTest, RemainderOfNPlusIotaOverflow) { ASSERT_FALSE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); } +TEST_F(AlgebraicSimplifierTest, RepeatedRemainder) { + const char* kModuleStr = R"( + HloModule m + test { + p = s32[1000] parameter(0) + q = s32[1000] parameter(1) + r = s32[1000] remainder(p, q) + ROOT rr = s32[1000] remainder(r, q) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Remainder(m::Parameter(), m::Parameter()))); +} + } // namespace } // namespace xla