From fe270f8c0fde021dfec59af8441c531bee15b112 Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 17 May 2019 23:21:27 -0700 Subject: [PATCH] [XLA] Convert a broadcasted denominator of a divide into a broadcast of a reciprocal. PiperOrigin-RevId: 248839769 --- .../xla/service/algebraic_simplifier.cc | 14 +++++++++++ .../xla/service/algebraic_simplifier_test.cc | 23 +++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 53afc598813..ad904ef304c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1122,6 +1122,20 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return ReplaceInstruction(divide, new_divide); } + // A / Broadcast(B) => A * Broadcast(1/B) + if (Match(divide, m::Divide(m::Op(&a), m::Broadcast(&c, m::Op(&b))))) { + auto one = MakeBroadcastHlo( + computation_->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::One(b->shape().element_type()))), + {}, b->shape().dimensions()); + TF_ASSIGN_OR_RETURN(auto recip, MakeBinaryHlo(HloOpcode::kDivide, one, b)); + auto recip_broadcast = + MakeBroadcastHlo(recip, c->dimensions(), c->shape().dimensions()); + TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, + recip_broadcast)); + return ReplaceInstruction(divide, new_divide); + } + // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (Match(divide, m::Divide(m::Divide(m::Op(&a), m::Op(&b)), m::Divide(m::Op(&c), m::Op(&d))))) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e37b69c5cba..19a88fd7d40 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5061,6 +5061,29 @@ 
TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) { EXPECT_THAT(root, GmockMatch(m::Multiply())); } +TEST_F(AlgebraicSimplifierTest, DivOfBroadcast) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + p0 = f32[10] parameter(0) + b = f32[30,10] broadcast(f32[10] p0), dimensions={1} + p1 = f32[30,10] parameter(1) + ROOT d = f32[30,10] divide(p1,b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Multiply( + m::Parameter(1), + m::Broadcast(m::Divide(m::Broadcast(m::Constant()), + m::Parameter(0)))))); +} + // Test that 1/sqrt(X) is simplified to rsqrt(X). TEST_F(AlgebraicSimplifierTest, RecipSqrt) { const char* kModuleStr = R"(