diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index dc20123ee5f..c5e5063f884 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -531,9 +531,17 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; HloInstruction *a, *c1, *c2; if (Match(add, m::Add(m::Add(m::NonConstant(&a), m::Constant(&c1)), - m::Constant(&c2)))) { + m::Constant(&c2))) || + Match(add, m::Add(m::Add(m::NonConstant(&a), + m::Broadcast(m::ConstantScalar(&c1))), + m::Broadcast(m::ConstantScalar(&c2))))) { TF_ASSIGN_OR_RETURN(auto* sum_of_constants, MakeBinaryHlo(HloOpcode::kAdd, c1, c2)); + if (ShapeUtil::IsScalar(sum_of_constants->shape()) && + !ShapeUtil::IsScalar(add->shape())) { + sum_of_constants = computation_->AddInstruction( + HloInstruction::CreateBroadcast(add->shape(), sum_of_constants, {})); + } return ReplaceWithNewInstruction( add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, a, sum_of_constants)); @@ -861,9 +869,17 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { // Canonicalize subtraction of a constant to addition. VLOG(10) << "trying transform [A - Const => A + (-Const)]"; - if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs)))) { + if (Match(sub, m::Subtract(m::NonConstant(&lhs), m::Constant(&rhs))) || + Match(sub, m::Subtract(m::NonConstant(&lhs), + m::Broadcast(m::Constant(&rhs))))) { HloInstruction* negative_const = computation_->AddInstruction( HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); + if (const HloInstruction* broadcast = + DynCast(sub->operand(1))) { + negative_const = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + broadcast->shape(), negative_const, broadcast->dimensions())); + } return ReplaceWithNewInstruction( sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, negative_const)); @@ -1883,6 +1899,29 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { return Status::OK(); } + VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]"; + HloInstruction *a, *c1, *c2; + if (Match(multiply, + m::Multiply(m::Multiply(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2))) || + Match(multiply, + m::Multiply(m::Multiply(m::NonConstant(&a), + m::Broadcast(m::ConstantScalar(&c1))), + m::Broadcast(m::ConstantScalar(&c2))))) { + TF_ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + if (ShapeUtil::IsScalar(product_of_constants->shape()) && + !ShapeUtil::IsScalar(multiply->shape())) { + product_of_constants = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + multiply->shape(), product_of_constants, {})); + } + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, a, + product_of_constants)); + } + // exp(A) * exp(B) => exp(A+B) if (Match(multiply, m::Multiply(m::Exp(m::Op(&lhs)), m::Exp(m::Op(&rhs))))) { auto add = computation_->AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index fed8d26e492..146267fc571 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -295,6 +295,47 @@ TEST_F(AlgebraicSimplifierTest, MulZero) { EXPECT_EQ(computation->root_instruction(), zero); } +TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeConstants) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + c0 = f32[] constant(2.0) + c1 = f32[] constant(3.0) + multiply0 = f32[] multiply(p0, c0) + ROOT multiply1 = f32[] multiply(multiply0, c1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply(m::Parameter(0), + m::Multiply(m::ConstantScalar(2.0), + m::ConstantScalar(3.0))))); +} + +TEST_F(AlgebraicSimplifierTest, MultiplyReassociateMergeBroadcastedConstants) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(2.0) + c1 = f32[] constant(3.0) + b0 = f32[4] broadcast(c0), dimensions={} + b1 = f32[4] broadcast(c1), dimensions={} + multiply0 = f32[4] multiply(p0, b0) + ROOT multiply1 = f32[4] multiply(multiply0, b1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Multiply( + m::Parameter(0), m::Broadcast(m::Multiply(m::ConstantScalar(2.0), + m::ConstantScalar(3.0)))))); +} + // Test that select(true, a, b) is simplified to a TEST_F(AlgebraicSimplifierTest, SelectTrue) { Shape r0s32 = ShapeUtil::MakeShape(S32, {}); @@ -446,6 +487,27 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { m::Add(m::Op().Is(constant1), m::Op().Is(constant2))))); } +TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c0 = f32[] constant(1.0) + c1 = f32[] constant(2.0) + b0 = f32[4] broadcast(c0), dimensions={} + b1 = f32[4] broadcast(c1), dimensions={} + add0 = f32[4] add(p0, b0) + ROOT add1 = f32[4] add(add0, b1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Add(m::ConstantScalar(1.0), + m::ConstantScalar(2.0)))))); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { auto m = CreateNewVerifiedModule(); Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -640,6 +702,25 @@ TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { m::Negate(m::Op().Is(constant))))); } +// Test that A - Broadcast(Const) is canonicalized to A + Broadcast(-Const). +TEST_F(AlgebraicSimplifierTest, SubBroadcastConstCanonicalization) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[4] parameter(0) + c = f32[] constant(0.125) + b = f32[4] broadcast(c), dimensions={} + ROOT sub = f32[4] subtract(p0, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Add(m::Parameter(0), + m::Broadcast(m::Negate(m::ConstantScalar(0.125)))))); +} + // Test that (A/B)/C is simplified to A/(B*C). TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { auto m = CreateNewVerifiedModule();