diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ecbf2075abe..2fbfd156844 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1488,6 +1488,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return ReplaceInstruction(divide, new_divide); } + // If X is a convert from pred, then + // X / broadcast(Y) => broadcast(1/Y) * X + if (Match(divide, + m::Divide( + m::Convert(&a, + m::Op().WithShape(m::Shape().WithElementType(PRED))), + m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { + TF_ASSIGN_OR_RETURN( + auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); + auto recip_bcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(divide->shape(), recip, {})); + TF_ASSIGN_OR_RETURN(auto mul, + MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); + return ReplaceInstruction(divide, mul); + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 08a004e39fe..0260a925b63 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6481,5 +6481,25 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) { EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1); } +TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[2] parameter(0) + cvt = f32[2] convert(p0) + p1 = f32[] parameter(1) + bcast = f32[2] broadcast(p1), dimensions={} + ROOT div = f32[2] divide(cvt, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Convert(m::Parameter(0)), + m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)))))); +} + } // namespace } // namespace xla