diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 440e04c9205..e0a8b87c83b 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { HloInstruction* dot); HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { - if (scalar_add_computation_) { - return scalar_add_computation_; + HloComputation*& scalar_add_computation = scalar_add_computations_[type]; + if (scalar_add_computation) { + return scalar_add_computation; } HloComputation::Builder b("scalar_add_computation"); @@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { HloInstruction::CreateParameter(1, shape, "scalar_rhs")); auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - scalar_add_computation_ = + scalar_add_computation = computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_add_computation_; + return scalar_add_computation; } // Tries to fold a kPad in the input or filter into the convolution @@ -528,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { // Whether algebraic simplification has occurred. bool changed_ = false; - // Cached computation for adding two scalar F32. - HloComputation* scalar_add_computation_ = nullptr; + // Cached computations for adding two scalars, one per PrimitiveType.
+ absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_; AlgebraicSimplifier* simplifier_ = nullptr; }; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 9f823c76d80..3ac47821654 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -6520,5 +6520,23 @@ TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) { m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1)))))); } +TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) { + constexpr char kModuleStr[] = R"( + HloModule test + ENTRY test { + a = c64[2,2] parameter(0) + b = c64[2] parameter(1) + cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0} + c = f64[2,2] parameter(2) + d = f64[2] parameter(3) + dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT tuple = (c64[2], f64[2]) tuple(cd, dd) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_EQ(3, m->computation_count()); +} + } // namespace } // namespace xla