diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 440e04c9205..e0a8b87c83b 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -472,8 +472,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
       HloInstruction* dot);
 
   HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) {
-    if (scalar_add_computation_) {
-      return scalar_add_computation_;
+    HloComputation*& scalar_add_computation = scalar_add_computations_[type];
+    if (scalar_add_computation) {
+      return scalar_add_computation;
     }
 
     HloComputation::Builder b("scalar_add_computation");
@@ -485,9 +486,9 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
         HloInstruction::CreateParameter(1, shape, "scalar_rhs"));
     auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary(
         shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs));
-    scalar_add_computation_ =
+    scalar_add_computation =
         computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
-    return scalar_add_computation_;
+    return scalar_add_computation;
   }
 
   // Tries to fold a kPad in the input or filter into the convolution
@@ -528,8 +529,8 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor {
   // Whether algebraic simplification has occurred.
   bool changed_ = false;
 
-  // Cached computation for adding two scalar F32.
-  HloComputation* scalar_add_computation_ = nullptr;
+  // Cached computation for adding two scalars of a given type.
+  absl::flat_hash_map<PrimitiveType, HloComputation*> scalar_add_computations_;
 
   AlgebraicSimplifier* simplifier_ = nullptr;
 };
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 9f823c76d80..3ac47821654 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -6520,5 +6520,23 @@ TEST_F(AlgebraicSimplifierTest, ScalarDividePredicate) {
           m::Broadcast(m::Divide(m::ConstantScalar(1), m::Parameter(1))))));
 }
 
+TEST_F(AlgebraicSimplifierTest, MultipleDotStrengthReductions) {
+  constexpr char kModuleStr[] = R"(
+    HloModule test
+    ENTRY test {
+      a = c64[2,2] parameter(0)
+      b = c64[2] parameter(1)
+      cd = c64[2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      c = f64[2,2] parameter(2)
+      d = f64[2] parameter(3)
+      dd = f64[2] dot(c, d), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+      ROOT tuple = (c64[2], f64[2]) tuple(cd, dd)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr));
+  ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie());
+  EXPECT_EQ(3, m->computation_count());
+}
+
 }  // namespace
 }  // namespace xla