From 77b06b0577dd9f6da817dc86fd55507b2de7ae22 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Sun, 31 Mar 2019 21:44:54 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 241258289 --- .../xla/service/algebraic_simplifier.cc | 23 +++++++++++-------- .../xla/service/algebraic_simplifier_test.cc | 20 ++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index b3044504312..35ca86d5d7f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1076,33 +1076,38 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { // // (Backends can do this transformation, but generally only if the constant is // a scalar.) - if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) { - Shape result_shape = b->literal().shape(); + if (Match(divide, m::Divide(m::NonConstant(&a), m::Op(&b))) && + (Match(b, m::Constant(&c)) || Match(b, m::Broadcast(m::Constant(&c))))) { + Shape result_shape = c->literal().shape(); Literal new_literal(result_shape); switch (result_shape.element_type()) { case F16: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case F32: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case BF16: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case F64: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case C64: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; case C128: - TF_RETURN_IF_ERROR(InvertConstant(*b, &new_literal)); + TF_RETURN_IF_ERROR(InvertConstant(*c, &new_literal)); break; default: return Status::OK(); } auto inverse = computation_->AddInstruction( - simplifier_->CreateConstantWithLayoutUpdated((new_literal.Clone()))); + simplifier_->CreateConstantWithLayoutUpdated(new_literal.Clone())); + if (b != c) { + inverse = computation_->AddInstruction(HloInstruction::CreateBroadcast( + b->shape(), inverse, b->dimensions())); + } TF_ASSIGN_OR_RETURN(auto new_divide, MakeBinaryHlo(HloOpcode::kMultiply, a, inverse)); return ReplaceInstruction(divide, new_divide); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 7f399ce0f11..098fefa235c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -853,6 +853,26 @@ TEST_F(AlgebraicSimplifierTest, DivideByConstant) { GmockMatch(m::Multiply(m::Parameter(0), m::Constant()))); } +// A / Broadcast(Const) => A * Broadcast(InvertedConst) +TEST_F(AlgebraicSimplifierTest, DivideByBroadcastedConstant) { + const char* kModuleStr = R"( + HloModule m + test { + p = f32[4] parameter(0) + c = f32[] constant(256.0) + b = f32[4] broadcast(c), dimensions={} + ROOT d = f32[4] divide(p, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + m::Parameter(0), + m::Broadcast(m::Op().IsConstantScalar(1.0f / 256.0f))))); +} + // pow(pow(A, X), Y) => pow(A, X*Y) TEST_F(AlgebraicSimplifierTest, PowerOfPower) { auto m = CreateNewVerifiedModule();