From 361470d24adc8b3cbc5b0e4af3a75f92392369c5 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Tue, 19 May 2020 20:15:55 -0700 Subject: [PATCH] [XLA] Strength reduce cvt(pred) / bcast(f32) to bcast(1 / f32) * cvt(pred) This allows us to reduce the number of redundant divides. PiperOrigin-RevId: 312407220 Change-Id: Id6ac5322d2eeecd1a40aee0e53b2c814220726d0 --- .../xla/service/algebraic_simplifier.cc | 16 +++++++++++++++ .../xla/service/algebraic_simplifier_test.cc | 20 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ecbf2075abe..2fbfd156844 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1488,6 +1488,22 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return ReplaceInstruction(divide, new_divide); } + // If X is a convert from pred, then + // X / broadcast(Y) => broadcast(1/Y) * X + if (Match(divide, + m::Divide( + m::Convert(&a, + m::Op().WithShape(m::Shape().WithElementType(PRED))), + m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { + TF_ASSIGN_OR_RETURN( + auto recip, MakeBinaryHlo(HloOpcode::kDivide, MakeScalarLike(b, 1), b)); + auto recip_bcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(divide->shape(), recip, {})); + TF_ASSIGN_OR_RETURN(auto mul, + MakeBinaryHlo(HloOpcode::kMultiply, recip_bcast, a)); + return ReplaceInstruction(divide, mul); + } + return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 08a004e39fe..0260a925b63 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6481,5 +6481,25 @@ TEST_F(AlgebraicSimplifierTest, SwapConvOperands) { EXPECT_EQ(conv->window().dimensions(1).padding_high(), 1); } +TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = pred[2] parameter(0) + cvt = f32[2] convert(p0) + p1 = f32[] parameter(1) + bcast = f32[2] broadcast(p1), dimensions={} + ROOT div = f32[2] divide(cvt, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Convert(m::Parameter(0)), + m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)))))); +} + } // namespace } // namespace xla