diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 72407213f97..cd6163415c3 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -962,6 +962,7 @@ tf_cuda_library(
         "util/guarded_philox_random.h",
         "util/mirror_pad_mode.h",
         "util/padding.h",
+        "util/einsum_op_util.h",
         "util/port.h",
         "util/ptr_util.h",
         "util/reffed_status_callback.h",
diff --git a/tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt b/tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt
new file mode 100644
index 00000000000..f84fd23e5e2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Einsum.pbtxt
@@ -0,0 +1,100 @@
+op {
+  graph_op_name: "Einsum"
+  in_arg {
+    name: "inputs"
+    description: <<END
+List of 1 or 2 Tensors.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+Output Tensor with shape depending upon `equation`.
+END
+  }
+  attr {
+    name: "equation"
+    description: <<END
+String describing the Einstein Summation operation; in the format of np.einsum.
+END
+  }
+  summary: "Tensor contraction according to Einstein summation convention."
+  description: <<END
+Implements generalized Tensor contraction and reduction. Each input Tensor must
+have a corresponding input subscript appearing in the comma-separated left-hand
+side of the equation. The right-hand side of the equation consists of the
+output subscript. The input subscripts and the output subscript should consist
+of zero or more named axis labels and at most one ellipsis (`...`).
+
+The named axis labels may be any single character other than those having
+special meaning, namely `,.->`. The behavior of this Op is undefined if it
+receives an ill-formatted equation; since the validation is done at
+graph-building time, we omit format validation checks at runtime.
+
+Note: This Op is *not* intended to be called by the user directly; it is a
+hidden Op used internally by `tf.einsum`.
+
+Operations are applied to the input(s) according to the following rules:
+
+ (a) Generalized Diagonals: For input dimensions corresponding to axis labels
+     appearing more than once in the same input subscript, we take the
+     generalized (`k`-dimensional) diagonal.
+     For example, in the equation `iii->i` with input shape `[3, 3, 3]`, the
+     generalized diagonal would consist of `3` elements at indices `(0, 0, 0)`,
+     `(1, 1, 1)` and `(2, 2, 2)` to create a Tensor of shape `[3]`.
+
+ (b) Reduction: Axes corresponding to labels appearing only in one input
+     subscript but not in the output subscript are summed over prior to Tensor
+     contraction.
+     For example, in the equation `ab,bc->b`, the axis labels `a` and `c` are
+     the reduction axis labels.
+
+ (c) Batch Dimensions: Axes corresponding to labels appearing in each of the
+     input subscripts and also in the output subscript make up the batch
+     dimensions in Tensor contraction. Unnamed axis labels corresponding to
+     ellipsis (`...`) also correspond to batch dimensions.
+     For example, for the equation denoting batch matrix multiplication,
+     `bij,bjk->bik`, the axis label `b` corresponds to a batch dimension.
+
+ (d) Contraction: In case of binary einsum, axes corresponding to labels
+     appearing in two different inputs (and not in the output) are contracted
+     against each other.
+     Considering the batch matrix multiplication equation again
+     (`bij,bjk->bik`), the contracted axis label is `j`.
+
+ (e) Expand Diagonal: If the output subscripts contain repeated (explicit) axis
+     labels, the opposite operation of (a) is applied. For example, in the
+     equation `i->iii` with input shape `[3]`, the output of shape `[3, 3, 3]`
+     is all zeros, except for the (generalized) diagonal, which is populated
+     with values from the input.
+     Note: This operation is not supported by `np.einsum` or `tf.einsum`; it is
+     provided to enable computing the symbolic gradient of `tf.einsum`.
+
+The output subscripts must contain only labels appearing in at least one of the
+input subscripts. Furthermore, all dimensions mapping to the same axis label
+must be equal.
+
+Each of the input and output subscripts may contain at most one ellipsis
+(`...`). The ellipsis maps to the dimensions not covered by any named axis
+label. If two inputs contain an ellipsis, then they are broadcasted according
+to standard NumPy broadcasting
+[rules](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+
+The broadcasted dimensions are placed in the corresponding location of the
+ellipsis in the output subscript. If the broadcasted dimensions are non-empty
+and the output subscript does not contain an ellipsis, then an InvalidArgument
+error is raised.
+
+@compatibility(numpy)
+Similar to [`numpy.einsum`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.einsum.html).
+
+Comparison with `numpy.einsum`:
+
+ * This Op only supports unary and binary forms of `numpy.einsum`.
+ * This Op does not support the implicit form (i.e. equations without `->`).
+ * This Op also supports repeated indices in the output subscript, which is not
+   supported by `numpy.einsum`.
+@end_compatibility
+
+END
+}
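For reference, rules (a) through (d) above correspond to standard `np.einsum`
behavior, while rule (e) has no NumPy equivalent. A minimal NumPy sketch
(illustrative only, not part of this change; all array values are arbitrary):

    import numpy as np

    x = np.arange(27.0).reshape(3, 3, 3)
    np.einsum('iii->i', x)                   # (a) generalized diagonal, shape [3]

    a, b = np.ones((2, 3)), np.ones((3, 4))
    np.einsum('ab,bc->b', a, b)              # (b) 'a' and 'c' are reduction labels

    p, q = np.ones((5, 2, 3)), np.ones((5, 3, 4))
    np.einsum('bij,bjk->bik', p, q)          # (c) batch label 'b', (d) contract 'j'

    # (e) The expand-diagonal direction (e.g. 'i->iii') is only provided by this
    # Op for gradient computation; an equivalent NumPy construction fills the
    # generalized diagonal of a zero tensor:
    v = np.arange(3.0)
    out = np.zeros((3, 3, 3))
    out[np.arange(3), np.arange(3), np.arange(3)] = v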
diff --git a/tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt b/tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt
new file mode 100644
index 00000000000..5178c80df4b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Einsum.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Einsum"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 8691716ef6a..fb98a880aeb 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -12,12 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-
 #include <unordered_set>
 
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/match.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/shape_inference.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/util/einsum_op_util.h"
 
 namespace tensorflow {
 
@@ -236,6 +242,178 @@ Status MatMulShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+namespace {
+
+// Validates that an Einsum subscript contains at most one ellipsis, and that
+// periods (.) occur only within an ellipsis (...).
+Status ValidateEinsumEllipsis(absl::string_view subscript,
+                              bool* found_ellipsis) {
+  const int num_periods = absl::c_count(subscript, '.');
+  if (num_periods != 0 && num_periods != 3) {
+    return errors::InvalidArgument(
+        "Expected at most one ellipsis (...), but found ", num_periods,
+        " periods (.) in the input subscript: ", subscript);
+  }
+  if (num_periods == 3 && !absl::StrContains(subscript, "...")) {
+    return errors::InvalidArgument(
+        "Periods found outside of ellipsis in subscript: ", subscript);
+  }
+  *found_ellipsis = num_periods > 0;
+  return Status::OK();
+}
+
+}  // namespace
+
+Status EinsumShape(shape_inference::InferenceContext* c) {
+  // We assume that the equation has a valid format: either (x),(y)->(z)
+  // or (x)->(z), where each of (x), (y) and (z) is a concatenation of zero or
+  // more Latin letters and contains at most one ellipsis ('...').
+  string equation;
+  TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation));
+  gtl::InlinedVector<string, 2> input_labels;
+  string output_labels;
+  TF_RETURN_IF_ERROR(
+      ParseEinsumEquation(equation, &input_labels, &output_labels));
+
+  if (c->num_inputs() == 0 || c->num_inputs() > 2) {
+    return errors::InvalidArgument("Expected either 1 or 2 inputs but got: ",
+                                   c->num_inputs());
+  }
+  if (c->num_inputs() != input_labels.size()) {
+    return errors::InvalidArgument("Expected ", input_labels.size(),
+                                   " inputs for equation ", equation,
+                                   " but got: ", c->num_inputs());
+  }
+
+  // Validate input subscripts, build the label to dimension mapping and obtain
+  // the broadcast shapes that map to ellipsis.
+  absl::flat_hash_map<char, DimensionHandle> label_to_dimension;
+  gtl::InlinedVector<ShapeHandle, 2> input_bcast_shapes(c->num_inputs());
+  for (int i = 0; i < c->num_inputs(); ++i) {
+    bool has_ellipsis = false;
+    TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis));
+    ShapeHandle input_shape = c->input(i);
+    // Validate that the input rank is sufficient for the given number of named
+    // labels.
+    if (c->RankKnown(input_shape)) {
+      if (has_ellipsis) {
+        const int num_named_labels =
+            static_cast<int>(input_labels[i].size()) - 3;
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(
+            c->WithRankAtLeast(input_shape, num_named_labels, &input_shape),
+            " for ", i, "th input and equation: ", equation);
+      } else {
+        const int num_named_labels = static_cast<int>(input_labels[i].size());
+        TF_RETURN_WITH_CONTEXT_IF_ERROR(
+            c->WithRank(input_shape, num_named_labels, &input_shape), " for ",
+            i, "th input and equation: ", equation);
+      }
+    }
+
+    bool seen_ellipsis = false;
+    input_bcast_shapes[i] = c->Scalar();
+    // Run through the input labels; populate label_to_dimension mapping and
+    // compute the broadcast shapes corresponding to the ellipsis (if present).
+    for (int label_idx = 0; label_idx < input_labels[i].size(); ++label_idx) {
+      const char label = input_labels[i][label_idx];
+      // Calculate the input axis that the current label is referring to. After
+      // the ellipsis, the axis may be found using negative indices; i.e. the
+      // (rank - k)th dimension corresponds to the (num_labels - k)th label.
+      const int64 axis_before_ellipsis = label_idx;
+      const int64 axis_after_ellipsis =
+          c->RankKnown(input_shape)
+              ? label_idx + c->Rank(input_shape) - input_labels[i].size()
+              : -1;
+
+      // Populate the input broadcast shape when we encounter an ellipsis (...).
+      if (label == '.') {
+        if (!c->RankKnown(input_shape)) {
+          input_bcast_shapes[i] = c->UnknownShape();
+        } else {
+          // The broadcast shape runs till the named label right after the
+          // ellipsis, the label with index (label_idx + 3).
+          TF_RETURN_IF_ERROR(c->Subshape(input_shape, axis_before_ellipsis,
+                                         axis_after_ellipsis + 3,
+                                         &input_bcast_shapes[i]));
+        }
+        label_idx += 2;  // Skip the rest of the ellipsis.
+        seen_ellipsis = true;
+        continue;
+      }
+      // Obtain the dimension that the current label corresponds to.
+      int64 axis = seen_ellipsis ? axis_after_ellipsis : axis_before_ellipsis;
+      DimensionHandle new_dim = c->RankKnown(input_shape)
+                                    ? c->Dim(input_shape, axis)
+                                    : c->UnknownDim();
+      // If we've seen this label before, make sure previous and current
+      // dimensions are compatible.
+      if (label_to_dimension.contains(label)) {
+        DimensionHandle merged;
+        TF_RETURN_IF_ERROR(
+            c->Merge(label_to_dimension[label], new_dim, &merged));
+        label_to_dimension[label] = merged;
+      } else {
+        label_to_dimension[label] = new_dim;
+      }
+    }
+  }
+
+  // For two inputs, broadcast the two input broadcast shapes to create the
+  // output broadcast shape. For one input, just copy the single broadcast
+  // shape.
+  ShapeHandle output_bcast_shape;
+  if (input_bcast_shapes.size() == 1) {
+    output_bcast_shape = input_bcast_shapes[0];
+  } else if (input_bcast_shapes.size() == 2) {
+    TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(
+        c, input_bcast_shapes[0], input_bcast_shapes[1], &output_bcast_shape));
+  }
+
+  bool output_has_ellipsis = false;
+  TF_RETURN_IF_ERROR(
+      ValidateEinsumEllipsis(output_labels, &output_has_ellipsis));
+  if (output_has_ellipsis) {
+    // If the output subscript has an ellipsis and the output broadcast rank is
+    // unknown, then the output shape should have unknown rank.
+    if (!c->RankKnown(output_bcast_shape)) {
+      c->set_output(0, c->UnknownShape());
+      return Status::OK();
+    }
+  } else {
+    // If the output subscript doesn't have an ellipsis, then make sure the
+    // output broadcasting shape is empty.
+    TF_RETURN_WITH_CONTEXT_IF_ERROR(
+        c->WithRankAtMost(output_bcast_shape, 0, &output_bcast_shape),
+        " for einsum equation '", equation,
+        "' without ellipsis (...) in the output subscripts where input(s) have "
+        "non-empty broadcasting shape");
+    output_bcast_shape = c->Scalar();
+  }
+
+  // Create the output shape from output labels and label_to_dimension mapping.
+  std::vector<DimensionHandle> output_dims;
+  for (int label_idx = 0; label_idx < output_labels.size(); ++label_idx) {
+    const char label = output_labels[label_idx];
+    // Append the output_bcast_shape when the ellipsis is encountered.
+    if (label == '.') {
+      for (int k = 0; k < c->Rank(output_bcast_shape); ++k) {
+        output_dims.push_back(c->Dim(output_bcast_shape, k));
+      }
+      label_idx += 2;  // Skip the rest of the ellipsis.
+      continue;
+    }
+    auto dimension_it = label_to_dimension.find(label);
+    if (dimension_it == label_to_dimension.end()) {
+      return errors::InvalidArgument(
+          "Einsum output subscripts for equation '", equation, "' has label '",
+          label, "' which is not present in the input subscripts");
+    }
+    output_dims.push_back(dimension_it->second);
+  }
+  c->set_output(0, c->MakeShape(output_dims));
+  return Status::OK();
+}
+
 Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) {
   ShapeHandle a_shape;
   ShapeHandle b_shape;
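The output shapes that `EinsumShape` infers for equations with an ellipsis
should agree with what `np.einsum` produces at runtime. A hedged sketch using
NumPy purely to illustrate the expected shapes (concrete sizes are arbitrary):

    import numpy as np

    # '...ij,...jk->...ik': the ellipsis dims of the two inputs are broadcast
    # against each other ((5, 1) vs (4,) -> (5, 4)); 'j' is contracted away.
    x = np.ones((5, 1, 2, 3))
    y = np.ones((4, 3, 7))
    print(np.einsum('...ij,...jk->...ik', x, y).shape)  # (5, 4, 2, 7)

    # 'a...bc->c...': named labels are matched from both ends of the input
    # rank; whatever remains in the middle maps to the ellipsis.
    z = np.ones((2, 5, 6, 3, 4))
    print(np.einsum('a...bc->c...', z).shape)  # (4, 5, 6)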
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 5050270a2f2..e11a8557f60 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -230,6 +230,9 @@ Status MatMulShape(shape_inference::InferenceContext* c);
 // batch dimensions.
 Status BatchMatMulV2Shape(shape_inference::InferenceContext* c);
 
+// Shape function for Einsum.
+Status EinsumShape(shape_inference::InferenceContext* c);
+
 // Shape function for BiasAdd-like operations.
 Status BiasAddShape(shape_inference::InferenceContext* c);
 
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index f30fd2aea3a..40ca891f929 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -213,6 +213,129 @@ TEST(CommonShapeFnsTest, MatMulShapeTest) {
   }
 }
 
+TEST(CommonShapeFnsTest, Einsum_ShapeFn) {
+  ShapeInferenceTestOp op("Einsum");
+  auto set_equation = [&op](int n, string equation) {
+    std::vector<NodeDefBuilder::NodeOut> input_list;
+    input_list.reserve(n);
+    for (int i = 0; i < n; ++i) {
+      input_list.emplace_back("a", 0, DT_FLOAT);
+    }
+    TF_ASSERT_OK(NodeDefBuilder("test", "Einsum")
+                     .Input(input_list)
+                     .Attr("equation", equation)
+                     .Finalize(&op.node_def));
+  };
+
+  // Unary cases.
+  set_equation(1, "abc->c");
+  INFER_OK(op, "[?,?,?]", "[d0_2]");
+  set_equation(1, "abc->aabbcc");
+  INFER_OK(op, "[?,?,?]", "[d0_0,d0_0,d0_1,d0_1,d0_2,d0_2]");
+  set_equation(1, "abc->");
+  INFER_OK(op, "[?,?,?]", "[]");
+  set_equation(1, "->");
+  INFER_OK(op, "[]", "[]");
+
+  // Binary cases.
+  set_equation(2, "ij,jk->ik");
+  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
+  set_equation(2, "ij,jk->ik");
+  INFER_OK(op, "[?,?];[?,?]", "[d0_0,d1_1]");
+  set_equation(2, "ab,ab->");
+  INFER_OK(op, "[?,?];[?,?]", "[]");
+  set_equation(2, "ab,->b");
+  INFER_OK(op, "[?,?];[]", "[d0_1]");
+  set_equation(2, ",->");
+  INFER_OK(op, "[];[]", "[]");
+  set_equation(2, "aaa,b->abbb");
+  INFER_OK(op, "[?,?,?];[?]", "[d0_0,d1_0,d1_0,d1_0]");
+  set_equation(2, ",abcd->badc");
+  INFER_OK(op, "[];[?,?,?,?]", "[d1_1,d1_0,d1_3,d1_2]");
+
+  // Ellipsis cases.
+  set_equation(1, "a...bc->c...");
+  INFER_OK(op, "[?,?,?,?,?]", "[d0_4,d0_1,d0_2]");
+  set_equation(2, "...ij,...jk->...ik");
+  INFER_OK(op, "[?,?,?,?,?];[1,?,?]", "[d0_0,d0_1,d0_2,d0_3,d1_2]");
+  INFER_OK(op, "[1,?,?];[?,?,?,?,?]", "[d1_0,d1_1,d1_2,d0_1,d1_4]");
+
+  // Unknown rank.
+  set_equation(1, "abc->c");
+  INFER_OK(op, "?", "[?]");
+  set_equation(1, "a...bc->c");
+  INFER_OK(op, "?", "[?]");
+  set_equation(1, "a...bc->c...");
+  INFER_OK(op, "?", "?");
+
+  set_equation(2, "...ij,...jk->...ik");
+  INFER_OK(op, "?;?", "?");
+  INFER_OK(op, "[?,?,?];?", "?");
+  INFER_OK(op, "?;[?,?,?]", "?");
+  set_equation(2, "...ij,...jk->ik");
+  INFER_OK(op, "?;?", "[?,?]");
+  set_equation(2, "abd,b...c->...cad");
+  INFER_OK(op, "[?,?,?];[?,?,?,?]", "[d1_1,d1_2,d1_3,d0_0,d0_2]");
+  set_equation(2, "...ab,b...c->ac...");
+  INFER_OK(op, "[?,1,?,?];[?,?,?]", "[d0_2,d1_2,d0_0,d1_1]");
+
+  // Wrong number of inputs.
+  set_equation(2, "ab->b");
+  INFER_ERROR("got: 2", op, "[?,?];[?,?]");
+  set_equation(1, "ab,a->b");
+  INFER_ERROR("got: 1", op, "[?,?]");
+
+  // Invalid format. Implicit form is not supported.
+  set_equation(1, "a");
+  INFER_ERROR("equation", op, "[2]");
+  set_equation(2, "ab,bc");
+  INFER_ERROR("equation", op, "[2,2];[2,2]");
+
+  // Wrong number of ellipsis or periods outside of ellipsis.
+  set_equation(1, "..a.->a...");
+  INFER_ERROR("ellipsis", op, "[1,1,2,1]");
+  set_equation(1, "...a->.a..");
+  INFER_ERROR("ellipsis", op, "[1,1,1,2]");
+  set_equation(1, "...a...->...a");
+  INFER_ERROR("ellipsis", op, "[1,1,1,2]");
+  set_equation(1, "..a..b..->...ab");
+  INFER_ERROR("ellipsis", op, "[1,1,2,1]");
+  set_equation(2, "...a...,ab->a");
+  INFER_ERROR("ellipsis", op, "[1,2,1];[2,1]");
+  set_equation(2, "a,...ab...->a");
+  INFER_ERROR("ellipsis", op, "[2];[1,2,1,1]");
+  set_equation(2, "a,ab->a......");
+  INFER_ERROR("ellipsis", op, "[2];[2,1]");
+
+  // Output label doesn't appear in input.
+  set_equation(1, "abc->d");
+  INFER_ERROR("'d'", op, "[?,?,?]");
+
+  // Mismatch in input rank.
+  set_equation(1, "abc->c");
+  INFER_ERROR("4", op, "[?,?,?,?]");
+  INFER_ERROR("2", op, "[?,?]");
+  set_equation(1, "...abc->...c");
+  INFER_ERROR("2", op, "[?,?]");
+
+  // Input dimensions are not consistent.
+  set_equation(2, "ab,ab->a");
+  INFER_ERROR("are 1 and 2", op, "[1,2];[2,1]");
+  set_equation(2, "aa,bb->a");
+  INFER_ERROR("are 1 and 2", op, "[1,2];[2,2]");
+
+  // Invalid broadcasting dimensions.
+  set_equation(2, "...ij,...jk->...ik");
+  INFER_ERROR("are 2 and 3", op, "[2,?,?];[3,?,?]");
+  set_equation(2, "i...j,jk...->...ik");
+  INFER_ERROR("are 2 and 3", op, "[?,2,?];[?,?,3]");
+  set_equation(2, "...ij,...jk->ik");
+  set_equation(2, "i...j,jk...->ik");
+  INFER_ERROR("non-empty broadcasting", op, "[?,2,?];[?,?]");
+  set_equation(2, "...ab,b...c->ac...");
+  INFER_OK(op, "?;[4,5,3]", "?");
+}
+
 TEST(CommonShapeFnsTest, BatchMatMulV2_ShapeFn) {
   ShapeInferenceTestOp op("BatchMatMulV2");
   auto set_adj = [&op](bool adj_x, bool adj_y) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 644e591100c..bd330539c7f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3216,6 +3216,7 @@ cc_library(
         ":cholesky_grad",
         ":cholesky_op",
         ":determinant_op",
+        ":einsum_op",
         ":lu_op",
         ":matrix_exponential_op",
         ":matrix_inverse_op",
@@ -3414,6 +3415,21 @@ tf_kernel_library(
     ],
 )
 
+tf_kernel_library(
+    name = "einsum_op",
+    prefix = "einsum_op",
+    deps = [
+        ":batch_matmul_op",
+        ":reduction_ops",
+        ":transpose_functor",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//third_party/eigen3",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "linalg_ops_common",
     srcs = ["linalg_ops_common.cc"],
diff --git a/tensorflow/core/kernels/einsum_op.cc b/tensorflow/core/kernels/einsum_op.cc
new file mode 100644
index 00000000000..ae5733c19d6
--- /dev/null
+++ b/tensorflow/core/kernels/einsum_op.cc
@@ -0,0 +1,712 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_split.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/batch_matmul_op_impl.h"
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+#include "tensorflow/core/kernels/transpose_functor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/einsum_op_util.h"
+
+namespace tensorflow {
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+
+namespace {
+
+using ShapeVec = gtl::InlinedVector<int64, 8>;
+using Labels = gtl::InlinedVector<int, 8>;
+using OperandLabels = gtl::InlinedVector<Labels, 2>;
+using LabelCounts = gtl::InlinedVector<int, 8>;
+using OperandLabelCounts = gtl::InlinedVector<LabelCounts, 2>;
+using LabelToDimSizes = gtl::InlinedVector<int64, 8>;
+
+// Dummy axis label used to denote an ellipsis in an input or output subscript.
+constexpr int kEllipsisLabel = -1;
+
+// Each dimension is categorized into exactly one of five types based on whether
+// its corresponding label is present in the input and/or the output subscripts.
+enum DimensionType {
+  // Batch dimensions are those present in two inputs as well as the output.
+  // They are part of the batch dimensions during Tensor contraction.
+  // Such dimensions may be broadcasting dimensions (those mapping to ellipsis)
+  // or explicit batch dimensions corresponding to named axis labels.
+  kBroadcasting = 0,
+  kBatch = 1,
+  // Free dimensions are present in exactly one of the inputs, and also the
+  // output. These are non-contracted axes in the Tensor contraction.
+  kFree = 2,
+  // Contract dimensions are present in two inputs, but not the output. These
+  // dimensions are contracted in Tensor contraction.
+  kContract = 3,
+  // Reduce dimensions are present in exactly one input, but not in the output,
+  // and are summed over prior to Tensor contraction.
+  kReduce = 4,
+};
+
+// Returns the DimensionType given whether the corresponding label is present in
+// exactly one input subscript (is_unique) and whether it is absent from the
+// output subscripts (is_removed). Does not handle broadcasting dimensions.
+DimensionType GetDimensionType(bool is_removed, bool is_unique) {
+  if (!is_removed && !is_unique)
+    return kBatch;
+  else if (!is_removed && is_unique)
+    return kFree;
+  else if (is_removed && !is_unique)
+    return kContract;
+  else  // is_removed && is_unique
+    return kReduce;
+}
+
+// Maps the character labels to consecutive integers.
+void MapToLabels(const string& subscript, Labels* labels,
+                 absl::flat_hash_map<char, int>* label_mapping) {
+  for (int i = 0; i < subscript.size(); ++i) {
+    const char label_char = subscript[i];
+    if (label_char == '.') {
+      labels->push_back(kEllipsisLabel);
+      i += 2;  // Skip next 2 characters as well.
+      continue;
+    }
+    if (!label_mapping->contains(label_char)) {
+      const int next_label = label_mapping->size();
+      (*label_mapping)[label_char] = next_label;
+    }
+    const int mapped_label = (*label_mapping)[label_char];
+    labels->push_back(mapped_label);
+  }
+}
+
+// Parses and validates the equation and the input shapes. Single character
+// labels are integerized and we populate input and output label subscripts and
+// corresponding counts. Also creates the mapping from (named) labels to their
+// DimensionType.
+Status ParseEquation(const string& equation, OperandLabels* input_labels,
+                     Labels* output_labels,
+                     std::vector<DimensionType>* label_types,
+                     OperandLabelCounts* input_label_counts,
+                     LabelCounts* output_label_counts,
+                     gtl::InlinedVector<bool, 2>* input_has_ellipsis,
+                     bool* output_has_ellipsis) {
+  gtl::InlinedVector<string, 2> input_str;
+  string output_str;
+  TF_RETURN_IF_ERROR(ParseEinsumEquation(equation, &input_str, &output_str));
+
+  // Temporary map from single character labels to (consecutive) integer labels.
+  absl::flat_hash_map<char, int> label_mapping;
+  int num_inputs = input_str.size();
+  input_labels->resize(num_inputs);
+
+  // Map from single characters to integer labels.
+  for (int i = 0; i < num_inputs; ++i) {
+    MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
+  }
+  MapToLabels(output_str, output_labels, &label_mapping);
+
+  // Compute counts for input and output labels.
+  int num_labels = label_mapping.size();
+  input_label_counts->resize(num_inputs);
+  input_has_ellipsis->resize(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    input_label_counts->at(i).resize(num_labels);
+    for (const int label : input_labels->at(i)) {
+      if (label != kEllipsisLabel)
+        input_label_counts->at(i)[label] += 1;
+      else
+        input_has_ellipsis->at(i) = true;
+    }
+  }
+  output_label_counts->resize(num_labels);
+  for (const int label : *output_labels) {
+    if (label != kEllipsisLabel)
+      output_label_counts->at(label) += 1;
+    else
+      *output_has_ellipsis = true;
+  }
+
+  // Map each label to a unique DimensionType.
+  label_types->resize(num_labels);
+  for (int label = 0; label < num_labels; ++label) {
+    if (label == kEllipsisLabel) continue;
+    bool removed = (*output_label_counts)[label] == 0;
+    bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
+                  (*input_label_counts)[1][label] == 0;
+    (*label_types)[label] = GetDimensionType(removed, unique);
+  }
+  return Status::OK();
+}
+
+// Insert new (unnamed) broadcasting labels at the location of ellipsis.
+void InsertBroadcastLabels(int num_bcast_dims, int num_named_labels,
+                           int ellipsis_axis, Labels* labels,
+                           LabelCounts* label_counts) {
+  labels->erase(labels->begin() + ellipsis_axis);
+  labels->insert(labels->begin() + ellipsis_axis, num_bcast_dims, 0);
+  std::iota(labels->begin() + ellipsis_axis,
+            labels->begin() + ellipsis_axis + num_bcast_dims, num_named_labels);
+  // Increment label counts. Since these are new labels, the count is set to 1.
+  label_counts->resize(num_named_labels + num_bcast_dims, 1);
+}
+
+// Record and validate the label to dimension mapping. Must be a named
+// (non-broadcasting) label as broadcasting labels don't have a fixed dimension.
+Status RecordLabelToDimension(const int label, const int axis,
+                              const Tensor& input,
+                              LabelToDimSizes* label_to_dim_sizes) {
+  const int64 input_dim = input.dim_size(axis);
+  // We know that label_to_dim_sizes has the size to accommodate named labels.
+  if (label_to_dim_sizes->at(label) != 0 &&
+      label_to_dim_sizes->at(label) != input_dim) {
+    return errors::InvalidArgument(
+        "Expected dimension ", label_to_dim_sizes->at(label), " at axis ", axis,
+        " of the input shaped ", input.shape().DebugString(),
+        " but got dimension ", input_dim);
+  }
+  (*label_to_dim_sizes)[label] = input_dim;
+  return Status::OK();
+}
+
+// Validate input dimensions and populate unnamed labels and their label counts.
+Status ProcessDimensions(const OpInputList& inputs,
+                         const gtl::InlinedVector<bool, 2>& input_has_ellipsis,
+                         const bool output_has_ellipsis,
+                         OperandLabels* input_labels, Labels* output_labels,
+                         std::vector<DimensionType>* label_types,
+                         OperandLabelCounts* input_label_counts,
+                         LabelCounts* output_label_counts,
+                         LabelToDimSizes* label_to_dim_sizes) {
+  if (inputs.size() != input_labels->size()) {
+    return errors::InvalidArgument("Expected ", input_labels->size(),
+                                   " inputs but got: ", inputs.size());
+  }
+  const int num_inputs = inputs.size();
+
+  // We infer the number of broadcasting dimensions by taking the maximum rank
+  // among the broadcasting subshapes of the inputs.
+  int max_bcast_dims = 0;
+  const int num_named_labels = label_types->size();
+  label_to_dim_sizes->resize(num_named_labels);
+  for (int i = 0; i < num_inputs; ++i) {
+    Labels* labels = &(*input_labels)[i];
+
+    if (!input_has_ellipsis[i]) {
+      if (inputs[i].dims() != labels->size()) {
+        return errors::InvalidArgument("Expected input ", i, " to have rank ",
+                                       labels->size(),
+                                       " but got: ", inputs[i].dims());
+      }
+      for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
+        const int label = (*labels)[label_idx];
+        TF_RETURN_IF_ERROR(RecordLabelToDimension(label, label_idx, inputs[i],
+                                                  label_to_dim_sizes));
+      }
+      continue;
+    }
+
+    // Input has an ellipsis.
+    if (inputs[i].dims() + 1 < labels->size()) {
+      return errors::InvalidArgument(
+          "Expected input ", i, " to have rank at least ", labels->size() - 1,
+          " but got: ", inputs[i].dims());
+    }
+    int ellipsis_axis = -1;
+    const int num_bcast_dims = inputs[i].dims() - labels->size() + 1;
+    for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
+      const int label = (*labels)[label_idx];
+      if (label == kEllipsisLabel) {
+        ellipsis_axis = label_idx;
+        continue;
+      }
+      // Current label is not an ellipsis.
+      const int axis =
+          label_idx + (ellipsis_axis == -1 ? 0 : num_bcast_dims - 1);
+      TF_RETURN_IF_ERROR(
+          RecordLabelToDimension(label, axis, inputs[i], label_to_dim_sizes));
+    }
+    // Found an ellipsis. Replace 'kEllipsisLabel' with broadcasting dimensions.
+    if (ellipsis_axis != -1) {
+      InsertBroadcastLabels(num_bcast_dims, num_named_labels, ellipsis_axis,
+                            labels, &input_label_counts->at(i));
+      max_bcast_dims = std::max(max_bcast_dims, num_bcast_dims);
+    }
+  }
+  if (!absl::c_linear_search(input_has_ellipsis, true) &&
+      !output_has_ellipsis) {
+    return Status::OK();
+  }
+  // Insert broadcasting dimensions in the output labels.
+  auto it =
+      std::find(output_labels->begin(), output_labels->end(), kEllipsisLabel);
+  if (it != output_labels->end()) {
+    const int ellipsis_axis = it - output_labels->begin();
+    InsertBroadcastLabels(max_bcast_dims, num_named_labels, ellipsis_axis,
+                          output_labels, output_label_counts);
+  } else if (max_bcast_dims > 0) {
+    return errors::InvalidArgument("Output contains ", max_bcast_dims,
+                                   " broadcasting dimension(s) but no ellipsis "
+                                   "(...) was found in the output subscripts.");
+  }
+  // Populate DimensionType for the new broadcasting labels.
+  label_types->resize(num_named_labels + max_bcast_dims, kBroadcasting);
+  return Status::OK();
+}
+
+// Permutes the labels according to the given permutation.
+void PermuteLabels(const std::vector<int>& permutation, Labels* labels) {
+  Labels permuted_labels(labels->size());
+  for (int i = 0; i < labels->size(); ++i) {
+    permuted_labels[i] = (*labels)[permutation[i]];
+  }
+  labels->swap(permuted_labels);
+}
+
+// Returns a reshaped input Tensor. The underlying buffer is not copied.
+Status CopyFrom(const Tensor& input, const TensorShape& shape, Tensor* output) {
+  if (output->CopyFrom(input, shape)) return Status::OK();
+  return errors::Internal(
+      "Encountered error while reshaping a Tensor of shape ",
+      input.shape().DebugString(), " to shape ", shape.DebugString());
+}
+
+// Returns whether transposing would be a no-op; whether input has rank < 2 or
+// the permutation is the identity permutation.
+bool ShouldTranspose(const TensorShape& input_shape,
+                     const std::vector<int>& permutation) {
+  if (input_shape.dims() < 2) return false;
+  for (int i = 0; i < permutation.size(); ++i) {
+    if (permutation[i] != i) return true;
+  }
+  return false;
+}
+
+// Transpose the input given a permutation. Returns a reference to the input if
+// transposing is not necessary.
+template <typename Device, typename T>
+Status TransposeOperand(OpKernelContext* ctx, const Tensor& input,
+                        const std::vector<int>& permutation, Tensor* output) {
+  if (!ShouldTranspose(input.shape(), permutation)) {
+    return CopyFrom(input, input.shape(), output);
+  }
+  TensorShape transposed_shape;
+  for (int i = 0; i < input.dims(); ++i) {
+    transposed_shape.AddDim(input.dim_size(permutation[i]));
+  }
+  TF_RETURN_IF_ERROR(
+      ctx->allocate_temp(DataTypeToEnum<T>::value, transposed_shape, output));
+  const Device& device = ctx->eigen_device<Device>();
+  TF_RETURN_IF_ERROR(DoTranspose(device, input, permutation, output));
+  return Status::OK();
+}
+
+// If there are repeated labels in either the input or output, then this strides
+// the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively.
+template <typename Device, typename T>
+Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input,
+                       const Labels& labels, const LabelCounts& label_counts,
+                       const bool should_inflate, Tensor* output) {
+  // Return early if there are no repeated indices.
+  if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) {
+    return CopyFrom(input, input.shape(), output);
+  }
+  // We reshape so that each repeated label is compressed to one dimension.
+  // E.g. for iiij -> ij, the shape [3, 3, 3, 5] would be compressed to [27, 5].
+  // Striding appropriately (in this case with strides 13 (= 1 + 3 + 9) and 1)
+  // recovers the generalized diagonal of shape [3, 5].
+  ShapeVec reshape;
+  ShapeVec strides;
+  // Strided and inflated shapes correspond to input and output shapes,
+  // respectively, when should_inflate is true (and vice-versa when it is
+  // false). E.g. they are [3, 5] and [3, 3, 3, 5] in the above example.
+  ShapeVec strided_shape;
+  ShapeVec inflated_shape;
+  for (int label : labels) {
+    const int count = label_counts[label];
+    const int current_axis =
+        should_inflate ? strided_shape.size() : inflated_shape.size();
+    const int64 dim = input.dim_size(current_axis);
+    strided_shape.push_back(dim);
+    inflated_shape.insert(inflated_shape.end(), count, dim);
+    const int64 reshape_dim = MathUtil::IPow(dim, count);
+    reshape.push_back(reshape_dim);
+    // While taking the d-diagonal in a rank k Tensor, we take d equally-spaced
+    // elements including the first and last element. Then,
+    // (d - 1) * stride = d^k - 1, or, stride = (d^k - 1)/(d - 1).
+    const int64 stride =
+        (dim > 1 && count > 1) ? (reshape_dim - 1) / (dim - 1) : 1;
+    strides.push_back(stride);
+  }
+
+  TensorShape output_shape =
+      TensorShape(should_inflate ? inflated_shape : strided_shape);
+  TF_RETURN_IF_ERROR(
+      ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
+  const Device& device = ctx->eigen_device<Device>();
+  switch (reshape.size()) {
+#define NDIMS_CASE(N)                                         \
+  case N: {                                                   \
+    if (should_inflate) {                                     \
+      auto output_map = output->shaped<T, N>(reshape);        \
+      auto input_map = input.shaped<T, N>(strided_shape);     \
+      output_map.device(device) = input_map.inflate(strides); \
+    } else {                                                  \
+      auto input_map = input.shaped<T, N>(reshape);           \
+      auto output_map = output->shaped<T, N>(strided_shape);  \
+      output_map.device(device) = input_map.stride(strides);  \
+    }                                                         \
+  } break;
+    NDIMS_CASE(1);
+    NDIMS_CASE(2);
+    NDIMS_CASE(3);
+    NDIMS_CASE(4);
+    NDIMS_CASE(5);
+    NDIMS_CASE(6);
+    default:
+      return errors::Unimplemented(
+          "Unsupported rank: ", reshape.size(),
+          " while handling repeated indices. Up to rank 6 is supported.");
+#undef NDIMS_CASE
+  }
+  return Status::OK();
+}
+
+// Returns true if the input dimensions are already sorted in the order
+// [batch, contract, free, reduce]. Used to implement an optimization that
+// avoids an extra transpose by setting adj_x (or adj_y) in BatchMatMul instead.
+bool ShouldSwapFreeAndContract(const Labels& labels,
+                               const std::vector<DimensionType>& label_types) {
+  // Check that ordering is according to dimension type, with the role of
+  // free and contract dimensions swapped.
+  gtl::InlinedVector<int, 5> remap = {0, 1, 3, 2, 4};
+  for (int i = 0; i + 1 < labels.size(); ++i) {
+    const int dimtype_a = remap[label_types[labels[i]]];
+    const int dimtype_b = remap[label_types[labels[i + 1]]];
+    if (dimtype_a > dimtype_b ||
+        (dimtype_a == dimtype_b && labels[i] > labels[i + 1])) {
+      return false;
+    }
+  }
+  return true;
+}
+
+template <typename Device, typename T>
+Status ReduceOperand(OpKernelContext* ctx, const Tensor& input,
+                     const std::vector<DimensionType>& label_types,
+                     const LabelCounts& label_counts, Labels* labels,
+                     Labels* free_labels, bool* swap_free_and_contract,
+                     Tensor* output) {
+  // Find the permutation to transpose the input dimensions in the order of
+  // DimensionType; i.e. broadcasting, batch, free, contract and reduce. This
+  // makes it more convenient to invoke Reduce/Contract operations.
+  std::vector<int> permutation(input.dims());
+  absl::c_iota(permutation, 0);
+  Tensor input_transposed;
+  // Check if we can avoid the transpose; if so, we flip the adj_x (or adj_y)
+  // flag during BatchMatMul instead. This is an extra optimization that is not
+  // necessary for correctness.
+  if (ShouldSwapFreeAndContract(*labels, label_types)) {
+    *swap_free_and_contract = true;
+  } else {
+    absl::c_sort(permutation, [&](int i, int j) {
+      int label_i = (*labels)[i];
+      int label_j = (*labels)[j];
+      return std::tie(label_types[label_i], label_i) <
+             std::tie(label_types[label_j], label_j);
+    });
+  }
+  // Transpose the input so that DimensionTypes are in order.
+  TF_RETURN_IF_ERROR(
+      TransposeOperand<Device, T>(ctx, input, permutation, &input_transposed));
+  PermuteLabels(permutation, labels);
+
+  // Take the generalized diagonal for dimensions with repeated axis labels.
+  Tensor input_deduped;
+  labels->erase(std::unique(labels->begin(), labels->end()), labels->end());
+  TF_RETURN_IF_ERROR(
+      StrideOrInflate<Device, T>(ctx, input_transposed, *labels, label_counts,
+                                 false /* should_inflate */, &input_deduped));
+
+  // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, reduce]
+  // where we've compacted the dimensions of each DimensionType.
+  gtl::InlinedVector<int64, 5> reshape(5, 1);
+  // The output shape is [batch shape] + [free size, contract size].
+  // That is, the batch shape is preserved (for broadcasting while contracting)
+  // while the free dims and contract dims are compressed to one dimension each.
+  TensorShape output_shape;
+  for (int label_idx = 0; label_idx < labels->size(); ++label_idx) {
+    const int label = labels->at(label_idx);
+    int64 dim = input_deduped.dim_size(label_idx);
+    if (label_types[label] == kBroadcasting || label_types[label] == kBatch) {
+      output_shape.AddDim(dim);
+    } else if (label_types[label] == kFree) {
+      free_labels->push_back(label);
+    }
+    reshape[label_types[label]] *= dim;
+  }
+  if (*swap_free_and_contract) std::swap(reshape[kFree], reshape[kContract]);
+  output_shape.AddDim(reshape[kFree]);
+  output_shape.AddDim(reshape[kContract]);
+
+  if (reshape[kReduce] == 1) {  // No need to actually reduce.
+    return CopyFrom(input_deduped, output_shape, output);
+  }
+  TF_RETURN_IF_ERROR(
+      ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
+  using Reducer = Eigen::internal::SumReducer<T>;
+  using Index = typename TTypes<T>::Tensor::Index;
+  // Reduce along the last axis (i.e. axis 1) of the rank-2 Tensor.
+  const int64 output_size = reshape[kBroadcasting] * reshape[kBatch] *
+                            reshape[kFree] * reshape[kContract];
+  functor::ReduceFunctor<Device, Reducer>::Reduce(
+      ctx, output->shaped<T, 1>({output_size}),
+      const_cast<const Tensor&>(input_deduped)
+          .shaped<T, 2>({output_size, reshape[kReduce]}),
+      Eigen::array<Index, 1>({1}), Reducer());
+  return Status::OK();
+}
+
+// Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M].
+Status ReshapeToRank3(const Tensor& input, int batch_size, Tensor* output) {
+  const int rank = input.dims();
+  TensorShape output_shape = {batch_size, input.dim_size(rank - 2),
+                              input.dim_size(rank - 1)};
+  return CopyFrom(input, output_shape, output);
+}
+
+// Conjugates the input.
+template <typename Device, typename T>
+Status Conjugate(OpKernelContext* ctx, Tensor* input) {
+  std::vector<int> permutation(input->dims());
+  std::iota(permutation.begin(), permutation.end(), 0);
+  Tensor output;
+  TF_RETURN_IF_ERROR(
+      ctx->allocate_temp(DataTypeToEnum<T>::value, input->shape(), &output));
+  const Device& d = ctx->eigen_device<Device>();
+  TF_RETURN_IF_ERROR(DoConjugateTranspose(d, *input, permutation, &output));
+  std::swap(*input, output);
+  return Status::OK();
+}
+
+// Contracts the inputs along the last axis (or the second-to-last if the
+// corresponding value of swap_free_and_contract is true). The batch dimensions
+// are broadcast to the output shape.
+// TODO(anudhyan): Factor this function into a BatchMatMul functor and support
+// transpose_x and transpose_y attributes (in addition to adj_x and adj_y).
+// Also, the BatchMatMul might devolve into a component-wise multiplication when
+// the matrix shape is [1,1]; in this case the BatchMatMul functor would be very
+// inefficient. The functor should detect this case and perform component-wise
+// multiplication instead.
+template <typename Device, typename T>
+Status ContractOperands(OpKernelContext* ctx, absl::Span<const Tensor> inputs,
+                        absl::Span<const bool> swap_free_and_contract,
+                        Tensor* output) {
+  if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output);
+  MatMulBCast bcast(inputs[0].shape().dim_sizes(),
+                    inputs[1].shape().dim_sizes());
+  if (!bcast.IsValid()) {
+    return errors::InvalidArgument(
+        "Invalid broadcasting dimensions: ", inputs[0].shape().DebugString(),
+        " vs. ", inputs[1].shape().DebugString());
+  }
+  Tensor lhs;
+  TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[0], bcast.x_batch_size(), &lhs));
+  Tensor rhs;
+  TF_RETURN_IF_ERROR(ReshapeToRank3(inputs[1], bcast.y_batch_size(), &rhs));
+  TensorShape output_shape = bcast.output_batch_shape();
+  for (int i = 0; i < inputs.size(); ++i) {
+    const int64 free_axis =
+        inputs[i].dims() - (swap_free_and_contract[i] ? 1 : 2);
+    output_shape.AddDim(inputs[i].dim_size(free_axis));
+  }
+  bool adj_x = swap_free_and_contract[0];
+  bool adj_y = !swap_free_and_contract[1];
+  if (is_complex<T>::value) {
+    if (adj_x) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &lhs));
+    if (adj_y) TF_RETURN_IF_ERROR(Conjugate<Device, T>(ctx, &rhs));
+  }
+  TF_RETURN_IF_ERROR(
+      ctx->allocate_temp(DataTypeToEnum<T>::value, output_shape, output));
+  Tensor output_reshaped;
+  TF_RETURN_IF_ERROR(
+      ReshapeToRank3(*output, bcast.output_batch_size(), &output_reshaped));
+  LaunchBatchMatMul<Device, T>::Launch(ctx, lhs, rhs, adj_x, adj_y, bcast,
+                                       &output_reshaped);
+  return Status::OK();
+}
+}  // namespace
+
+template <typename Device, typename T>
+class EinsumOp : public OpKernel {
+ public:
+  explicit EinsumOp(OpKernelConstruction* c) : OpKernel(c) {
+    string equation;
+    OP_REQUIRES_OK(c, c->GetAttr("equation", &equation));
+    OP_REQUIRES_OK(c, ParseEquation(equation, &input_labels_, &output_labels_,
+                                    &label_types_, &input_label_counts_,
+                                    &output_label_counts_, &input_has_ellipsis_,
+                                    &output_has_ellipsis_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    OpInputList inputs;
+    OP_REQUIRES_OK(ctx, ctx->input_list("inputs", &inputs));
+
+    OperandLabels input_labels(input_labels_);
+    Labels output_labels(output_labels_);
+    std::vector<DimensionType> label_types(label_types_);
+    OperandLabelCounts input_label_counts(input_label_counts_);
+    LabelCounts output_label_counts(output_label_counts_);
+    LabelToDimSizes label_to_dim_sizes;
+
+    OP_REQUIRES_OK(ctx, ProcessDimensions(
+                            inputs, input_has_ellipsis_, output_has_ellipsis_,
+                            &input_labels, &output_labels, &label_types,
+                            &input_label_counts, &output_label_counts,
+                            &label_to_dim_sizes));
+
+    // The reduction phase (a) sums across reduction dimensions, (b) takes
+    // generalized diagonals, and (c) reshapes it into shape
+    //   [(broadcasting) batch shape] + [F,C]
+    // where F and C denote the total (compacted) size of free and contract
+    // dimensions, respectively.
+    const int num_inputs = inputs.size();
+    OperandLabels free_labels(num_inputs);
+    gtl::InlinedVector<Tensor, 2> inputs_reduced(num_inputs);
+    gtl::InlinedVector<bool, 2> swap_free_and_contract(num_inputs);
+    for (int i = 0; i < num_inputs; ++i) {
+      OP_REQUIRES_OK(ctx,
+                     ReduceOperand<Device, T>(
+                         ctx, inputs[i], label_types, input_label_counts[i],
+                         &input_labels[i], &free_labels[i],
+                         &swap_free_and_contract[i], &inputs_reduced[i]));
+    }
+
+    // After reduction, the inputs should be reshaped to Tensors suitable for
+    // contraction. If num_inputs is 1, the reduced input is simply forwarded to
+    // the output.
+    Tensor contraction_output_reshaped;
+    OP_REQUIRES_OK(ctx, ContractOperands<Device, T>(
+                            ctx, inputs_reduced, swap_free_and_contract,
+                            &contraction_output_reshaped));
+
+    // Copy the batch labels from the contraction output. Recover the batch
+    // shape, which may have been broadcasted.
+    TensorShape result_shape = contraction_output_reshaped.shape();
+    result_shape.RemoveLastDims(2);
+
+    int num_labels = label_types.size();
+    Labels result_labels;
+    // All batch dimensions should be present in the contracted result. First
+    // the broadcasting dimensions, then the named batch dimensions.
+    for (int label = 0; label < num_labels; ++label) {
+      if (label_types[label] == kBroadcasting) result_labels.push_back(label);
+    }
+    for (int label = 0; label < num_labels; ++label) {
+      if (label_types[label] == kBatch) result_labels.push_back(label);
+    }
+    for (int i = 0; i < num_inputs; ++i) {
+      for (int label : free_labels[i]) {
+        result_labels.push_back(label);
+        result_shape.AddDim(label_to_dim_sizes[label]);
+      }
+    }
+
+    // Reshape the contraction (or reduction) result to its expanded shape:
+    // [(broadcasted) batch shape] + [free shape 0] + [free shape 1].
+    Tensor contraction_output;
+    OP_REQUIRES_OK(ctx, CopyFrom(contraction_output_reshaped, result_shape,
+                                 &contraction_output));
+
+    // Inflate the output if necessary. (E.g. for the equation 'i->iii' which
+    // may arise while computing the gradient of a regular Einsum).
+    // TODO(anudhyan): It's possible that Eigen's contract and inflate can be
+    // chained here to avoid materializing an intermediate.
+    Tensor output_inflated;
+    OP_REQUIRES_OK(
+        ctx, StrideOrInflate<Device, T>(
+                 ctx, contraction_output, result_labels, output_label_counts,
+                 true /* should_inflate */, &output_inflated));
+    if (output_inflated.dims() > contraction_output.dims()) {
+      // We inflated the output. Modify result labels accordingly.
+      Labels inflated_labels;
+      for (int label : result_labels) {
+        inflated_labels.insert(inflated_labels.end(),
+                               output_label_counts[label], label);
+      }
+      result_labels.swap(inflated_labels);
+    }
+    // Find the permutation to map the result labels to the output labels. Note
+    // that both the result and the final output may have repeated labels,
+    // in which case the permutation preserves the left-to-right ordering.
+    // E.g. if result labels are [0, 0, 1] and output is [0, 1, 0] then the
+    // permutation should be [0, 2, 1]. We also use the fact that repeated
+    // labels in the result are adjacent to each other.
+    std::vector<int> output_permutation(output_labels.size());
+    std::vector<int> label_to_position(num_labels, -1);
+    for (int i = 0; i < result_labels.size(); ++i) {
+      // Remember the position of only the leftmost result label.
+      if (label_to_position[result_labels[i]] == -1) {
+        label_to_position[result_labels[i]] = i;
+      }
+    }
+    for (int i = 0; i < output_labels.size(); ++i) {
+      output_permutation[i] = label_to_position[output_labels[i]];
+      // We have found the leftmost occurrence. The next one would be adjacent.
+      label_to_position[output_labels[i]] += 1;
+    }
+    Tensor output;
+    OP_REQUIRES_OK(ctx, TransposeOperand<Device, T>(
+                            ctx, output_inflated, output_permutation, &output));
+    ctx->set_output(0, output);
+  }
+
+ private:
+  OperandLabels input_labels_;
+  Labels output_labels_;
+  std::vector<DimensionType> label_types_;
+  OperandLabelCounts input_label_counts_;
+  LabelCounts output_label_counts_;
+  gtl::InlinedVector<bool, 2> input_has_ellipsis_;
+  bool output_has_ellipsis_ = false;
+};
+
+#define REGISTER_EINSUM(D, TYPE)                                   \
+  REGISTER_KERNEL_BUILDER(                                         \
+      Name("Einsum").Device(DEVICE_##D).TypeConstraint<TYPE>("T"), \
+      EinsumOp<D##Device, TYPE>);
+
+// TODO(anudhyan): Also register GPU kernels for Einsum.
+#define REGISTER_CPU(TYPE) REGISTER_EINSUM(CPU, TYPE)
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_complex64(REGISTER_CPU);
+TF_CALL_complex128(REGISTER_CPU);
+#undef REGISTER_CPU
+#undef REGISTER_EINSUM
+
+}  // namespace tensorflow
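End to end, the kernel above runs in three phases: ReduceOperand transposes
each operand so its dimensions are grouped by DimensionType, takes generalized
diagonals for repeated labels, and sums out reduce-only labels; ContractOperands
performs a broadcasting batch matmul over the contract dimensions; finally the
result is reshaped, optionally inflated, and transposed into the output label
order. A hedged NumPy sketch of those phases for one illustrative equation (the
equation and sizes are made up, not taken from this change):

    import numpy as np

    # Equation 'abc,cad->adb': 'a' is a batch label, 'c' is contracted,
    # 'b' and 'd' are the free labels of the two operands.
    x = np.random.randn(2, 3, 4)   # labels a, b, c
    y = np.random.randn(4, 2, 5)   # labels c, a, d

    # Phase 1 (ReduceOperand): transpose so the batch label leads and the free
    # and contract labels line up for a matmul. (The kernel orders dims as
    # [batch, free, contract] and flips adj_x/adj_y as needed; reduce-only
    # labels would be summed out here.)
    lhs = x                               # [a, b, c] == [batch, free, contract]
    rhs = np.transpose(y, (1, 0, 2))      # [a, c, d] == [batch, contract, free]

    # Phase 2 (ContractOperands): broadcasting batch matmul over 'c'.
    out = np.matmul(lhs, rhs)             # [a, b, d]

    # Phase 3: transpose the result labels [a, b, d] into the requested output
    # order 'adb'. (StrideOrInflate would run first if the output subscript
    # had repeated labels, e.g. 'i->iii'.)
    out = np.transpose(out, (0, 2, 1))    # [a, d, b]

    assert np.allclose(out, np.einsum('abc,cad->adb', x, y))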
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 51ab3e268c6..f037d38ef81 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -474,6 +474,14 @@ REGISTER_OP("TridiagonalSolve")
     .Attr("T: {double, float, complex64, complex128}")
     .SetShapeFn(TridiagonalSolveShapeFn);
 
+REGISTER_OP("Einsum")
+    .Input("inputs: N * T")
+    .Output("output: T")
+    .Attr("equation: string")
+    .Attr("N: int >= 1")
+    .Attr("T: type")
+    .SetShapeFn(shape_inference::EinsumShape);
+
 // Deprecated op registrations:
 
 // Can be deleted after 3feb2017.
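Because the registration declares `inputs: N * T` and a string `equation` attr,
the generated Python wrapper takes a list of tensors plus the equation. A
hedged usage sketch (the wrapper should land in `gen_linalg_ops` since the op
is registered in linalg_ops.cc, though end users are expected to call
`tf.einsum`):

    import numpy as np
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ops import gen_linalg_ops

    a = constant_op.constant(np.ones((2, 3), dtype=np.float32))
    b = constant_op.constant(np.ones((3, 4), dtype=np.float32))
    # The hidden op takes a list of 1 or 2 inputs and an explicit-form equation.
    c = gen_linalg_ops.einsum([a, b], equation='ij,jk->ik')  # shape [2, 4]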
diff --git a/tensorflow/core/util/einsum_op_util.cc b/tensorflow/core/util/einsum_op_util.cc
new file mode 100644
index 00000000000..ec339a5128a
--- /dev/null
+++ b/tensorflow/core/util/einsum_op_util.cc
@@ -0,0 +1,47 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/einsum_op_util.h"
+
+#include <string>
+
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+
+Status ParseEinsumEquation(const string& equation,
+                           gtl::InlinedVector<string, 2>* input_subscripts,
+                           string* output_subscript) {
+  gtl::InlinedVector<string, 2> inputs_and_output_subscripts =
+      absl::StrSplit(equation, "->");
+  if (inputs_and_output_subscripts.size() != 2) {
+    return errors::InvalidArgument(
+        "Expecting exactly one '->' in einsum equation: ", equation);
+  }
+  *output_subscript = std::move(inputs_and_output_subscripts[1]);
+  *input_subscripts =
+      absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ',');
+  if (input_subscripts->size() != 1 && input_subscripts->size() != 2) {
+    return errors::InvalidArgument(
+        "Expecting 1 or 2 input subscripts in equation '", equation,
+        "' but got: ", input_subscripts->size());
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
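ParseEinsumEquation only splits the equation on '->' and ',' and checks the
piece counts; validation of the individual subscripts happens elsewhere. A
rough Python equivalent (illustrative only):

    def parse_einsum_equation(equation):
      parts = equation.split('->')
      if len(parts) != 2:
        raise ValueError(
            "Expecting exactly one '->' in einsum equation: %s" % equation)
      input_subscripts, output_subscript = parts[0].split(','), parts[1]
      if len(input_subscripts) not in (1, 2):
        raise ValueError('Expecting 1 or 2 input subscripts but got: %d' %
                         len(input_subscripts))
      return input_subscripts, output_subscript

    parse_einsum_equation('ab,bc->ac')      # (['ab', 'bc'], 'ac')
    parse_einsum_equation('...ij->ji...')   # (['...ij'], 'ji...')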
diff --git a/tensorflow/core/util/einsum_op_util.h b/tensorflow/core/util/einsum_op_util.h
new file mode 100644
index 00000000000..f12af7bb7b8
--- /dev/null
+++ b/tensorflow/core/util/einsum_op_util.h
@@ -0,0 +1,29 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace tensorflow {
+
+// Parses an einsum equation in explicit form (i.e. containing "->") into its
+// comma-separated input subscripts and its output subscript. Returns an
+// InvalidArgument error if the equation is not in that form.
+Status ParseEinsumEquation(const string& equation,
+                           gtl::InlinedVector<string, 2>* input_subscripts,
+                           string* output_subscript);
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_UTIL_EINSUM_OP_UTIL_H_
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 3399bdc9992..df8633236c5 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2136,6 +2136,23 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )
 
+cuda_py_test(
+    name = "einsum_op_test",
+    size = "small",
+    srcs = ["einsum_op_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python/ops/linalg",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
 cuda_py_test(
     name = "manip_ops_test",
     size = "small",
diff --git a/tensorflow/python/kernel_tests/einsum_op_test.py b/tensorflow/python/kernel_tests/einsum_op_test.py
new file mode 100644
index 00000000000..b51b91ddbf4
--- /dev/null
+++ b/tensorflow/python/kernel_tests/einsum_op_test.py
@@ -0,0 +1,251 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.Einsum."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import special_math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
+from tensorflow.python.platform import test
+
+
+class EinsumOpTest(test.TestCase):
+
+  def _check(self, s, *input_shapes, **kwargs):
+    dtype = kwargs.pop('dtype', np.float32)
+    r = np.random.RandomState(0)
+    inputs = []
+    for shape in input_shapes:
+      arr = np.array(r.randn(*shape)).astype(dtype)
+      if dtype == np.complex64 or dtype == np.complex128:
+        arr += 1j * np.array(r.randn(*shape)).astype(dtype)
+      inputs.append(arr)
+    input_tensors = [constant_op.constant(x, shape=x.shape) for x in inputs]
+    a = np.einsum(s, *inputs)
+    b = self.evaluate(gen_linalg_ops.einsum(input_tensors, s))
+    self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
+
+  def testUnary(self):
+    self._check('->', ())
+    self._check('aa->', (3, 3))
+    self._check('aa->a', (3, 3))
+    self._check('aaa->', (3, 3, 3))
+    self._check('aaa->a', (3, 3, 3))
+    self._check('aab->a', (3, 3, 4))
+    self._check('ab->', (3, 3))
+    self._check('ab->ab', (3, 3))
+    self._check('abc->b', (3, 4, 5))
+    self._check('abc->ca', (3, 4, 5))
+    self._check('abc->cab', (3, 4, 5))
+    self._check('aabcc->a', (3, 3, 5, 4, 4))
+    self._check('aabcc->ac', (3, 3, 5, 4, 4))
+    self._check('aabcd->ad', (3, 3, 5, 4, 4))
+
+  def testUnaryEllipsis(self):
+    self._check('...->...', ())
+    self._check('...->', ())
+    self._check('->...', ())
+
+    # Tests from dask
+    self._check('a...a->a...', (2, 2))
+    self._check('a...a->', (2, 2))
+    self._check('a...a->...', (2, 5, 1, 2))
+    self._check('a...a->a...', (2, 1, 2))
+    self._check('a...a->a...', (2, 3, 4, 5, 2))
+
+    self._check('...ijk->...ki', (3, 4, 5))
+    self._check('...ijk->...ki', (1, 3, 4, 5))
+    self._check('...ijk->...ki', (2, 2, 3, 4, 5))
+
+    # Repeated indices.
+    self._check('i...ii->...i', (3, 2, 3, 3))
+
+  def testBinary(self):
+    self._check(',->', (), ())
+    self._check('a,a->', (3,), (3,))
+    self._check('a,a->a', (3,), (3,))
+    self._check('ba,b->', (3, 2), (3,))
+    self._check('ab,b->a', (3, 4), (4,))
+    self._check('ab,ab->', (3, 4), (3, 4))
+    self._check('nij,jk->nik', (5, 2, 3), (3, 4))
+    self._check('abc,bad->abcd', (1, 2, 3), (2, 1, 4))
+    # Repeated indices.
+    self._check('ijj,k->ik', (2, 3, 3), (4,))
+    self._check('aba,a->b', (3, 4, 3), (3,))
+    # From https://github.com/dask/dask/pull/3412#discussion_r182413444
+    self._check('aab,bc->ac', (2, 2, 3), (3, 4))
+    self._check('aab,bcc->ac', (2, 2, 3), (3, 4, 4))
+    # Based on https://github.com/google/jax/issues/37#issuecomment-448572187
+    self._check('sa,shb->shab', (2, 1), (2, 3, 4))
+
+  def testBroadcasting(self):
+    # Batch matmul without broadcasting.
+    self._check('...ij,...jk->...ik', (5, 1, 2, 3), (5, 1, 3, 4))
+    # Batch matmul with broadcasting.
+    self._check('...ij,...jk->...ik', (1, 2, 3), (3, 5))
+    self._check('...ij,...jk->...ik', (2, 3), (1, 3, 5))
+    self._check('...ij,...jk->...ik', (5, 2, 3), (3, 5))
+    self._check('...ij,...jk->...ik', (2, 3), (5, 3, 5))
+    self._check('...ij,...jk->...ik', (3, 1, 2, 3), (1, 1, 7, 3, 5))
+    self._check('i...j,j...k->...ik', (2, 1, 3, 1, 3), (3, 1, 7, 5))
+    # Broadcasting with repeated indices.
+    self._check('ij,jk...k->i...', (3, 2), (2, 4, 1, 4))
+    self._check('ij,jk...k->...i', (3, 2), (2, 4, 5, 4))
+    self._check('ijj,jk...k->i...', (3, 2, 2), (2, 4, 1, 4))
+    self._check('i...jj,jk...k->i...', (3, 3, 1, 2, 2), (2, 4, 1, 5, 4))
+    # Following 2 from https://stackoverflow.com/a/19203475/1611416
+    self._check('...abc,...abcd->...d', (1, 1, 2, 3, 4), (5, 2, 3, 4, 6))
+    self._check('ab...,b->ab...', (2, 3, 1, 1, 5), (3,))
+
+  def testDtypes(self):
+    for dtype in [np.float64, np.float32, np.complex64, np.complex128]:
+      self._check('ij,jk->ik', (2, 2), (2, 2), dtype=dtype)
+      self._check('ji,jk->ik', (2, 2), (2, 2), dtype=dtype)
+      self._check('ji,kj->ik', (2, 2), (2, 2), dtype=dtype)
+      self._check('ij,jk->ki', (2, 2), (2, 2), dtype=dtype)
+      self._check('ji,kj->ki', (2, 2), (2, 2), dtype=dtype)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testInvalid(self):
+    r = np.random.RandomState(0)
+    cases = [
+        # incorrect rank.
+        ('ij,jk->ik', r.randn(1, 2, 3), r.randn(3, 4)),
+        ('...ij,jk->ik', r.randn(3), r.randn(3, 4)),
+        # inconsistent dimensions.
+        ('ij,jk->ik', r.randn(2, 3), r.randn(4, 4)),
+        # broadcasting is invalid
+        ('...ij,...jk->...ik', r.randn(5, 2, 3), r.randn(7, 3, 4)),
+        # output should have ellipsis when broadcasting shape is
+        # non-empty.
+        ('...ij,...jk->ik', r.randn(2, 2, 3), r.randn(3, 4)),
+    ]
+    for args in cases:
+      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+        _ = self.evaluate(gen_linalg_ops.einsum(args[1:], args[0]))
+
+      placeholders = [
+          array_ops.placeholder_with_default(x, shape=None) for x in args[1:]
+      ]
+      with self.assertRaises((ValueError, errors.InvalidArgumentError)):
+        _ = self.evaluate(gen_linalg_ops.einsum(placeholders, args[0]))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testPlaceholder(self):
+
+    def check(equation, *input_and_placeholder_shapes):
+      r = np.random.RandomState(0)
+      inputs = []
+      input_placeholders = []
+      for actual_shape, placeholder_shape in input_and_placeholder_shapes:
+        input_np = np.array(r.randn(*actual_shape))
+        inputs.append(input_np)
+        input_placeholders.append(
+            array_ops.placeholder_with_default(input_np, placeholder_shape))
+
+      a = np.einsum(equation, *inputs)
+      b = self.evaluate(gen_linalg_ops.einsum(input_placeholders, equation))
+      self.assertAllClose(a, b, atol=1e-4, rtol=1e-4)
+
+    check('bijl,bjkm->bik', ((9, 2, 3, 5), (None, None, None, 5)),
+          ((9, 3, 4, 7), (None, None, 4, None)))
+    check('bijl,bjkm->bik', ((9, 2, 3, 5), None), ((9, 3, 4, 7), None))
+    check('...ij,...->...i', ((4, 3, 1, 2), (None, 3, None, 2)),
+          ((4, 3), (None, 3)))
+    check('...ij,...jk->...ik', ((3, 1, 2, 3), None), ((1, 7, 3, 4), None))
+
+  def testOutputRepeatedLabels(self):
+    # Repeated output labels are the reverse of repeated input labels (the
+    # generalized diagonal); used for computing symbolic gradients of einsum.
+    r = np.random.RandomState(0)
+    a = r.randn(2, 2)
+    s = 'a->aa'
+    diag_a = np.diag(np.diag(a))
+    b = self.evaluate(gen_linalg_ops.einsum([np.diag(a)], s))
+    self.assertAllClose(diag_a, b, atol=1e-4, rtol=1e-4)
+
+
+class EinsumBenchmark(test.Benchmark):
+  cases = [
+      # Unary cases.
+      ['ijk->i', 100],
+      ['ijk->kji', 100],
+      # Regular matmul or batch matmul.
+      ['ij,jk->ik', 1000],
+      ['ji,kj->ik', 1000],
+      ['ab,ab->', 100],
+      ['ab,ba->', 100],
+      ['abc,abc->', 100],
+      ['abc,bac->', 100],
+      ['abc,cba->', 100],
+      ['bij,bjk->bik', 100],
+      ['bji,bjk->bki', 100],
+      ['ikl,kji->kl', 100],
+      ['klj,lki->ij', 100],
+      ['ijk,ilj->kli', 100],
+      ['kij,mkb->ijmb', 100],
+      ['abcd,ad->bc', 40],
+      # Larger binary contractions.
+      ['ijk,jklm->il', 40],
+      ['efabc,eabcd->efd', 30],
+      ['fabec,abcde->fde', 30],
+      ['efabc,edabc->efd', 30],
+      ['eadbf,dfebc->ecfad', 30],
+      ['abcdef,bcdfg->abcdeg', 30],
+  ]
+
+  def benchmarkEinsum(self):
+    for equation, dim in self.cases:
+      with ops.Graph().as_default(), \
+          session.Session(config=benchmark.benchmark_config()) as sess, \
+          ops.device('/cpu:0'):
+        r = np.random.RandomState(0)
+        input_subscripts = equation.split('->')[0].split(',')
+        input_vars = []
+        for subscript in input_subscripts:
+          input_shape = (dim,) * len(subscript)
+          input_vars.append(
+              variables.Variable(np.array(r.randn(*input_shape), np.float32)))
+        variables.global_variables_initializer().run()
+
+        # Call einsum_v1.
+        self.run_op_benchmark(
+            sess,
+            special_math_ops.einsum(equation, *input_vars),
+            min_iters=50,
+            name='einsum_v1_cpu_({})_{}'.format(equation, dim))
+
+        # Call gen_linalg_ops.einsum.
+        self.run_op_benchmark(
+            sess,
+            gen_linalg_ops.einsum(input_vars, equation),
+            min_iters=50,
+            name='einsum_v2_cpu_({})_{}'.format(equation, dim))
+
+
+if __name__ == '__main__':
+  test.main()
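One note on `testOutputRepeatedLabels` above: `np.einsum` rejects repeated labels in the output subscript, so that test checks against `np.diag` instead of going through `_check`. A small sketch of the behavior being exercised, assuming eager execution:

    import numpy as np
    from tensorflow.python.ops import gen_linalg_ops

    v = np.array([1., 2., 3.], dtype=np.float32)
    # 'a->aa' expands a vector into a diagonal matrix, the reverse of 'aa->a'.
    m = gen_linalg_ops.einsum([v], 'a->aa')
    # m evaluates to np.diag(v), i.e. [[1, 0, 0], [0, 2, 0], [0, 0, 3]].
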
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 53cd28b9dce..ca661f2a572 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1000,6 +1000,10 @@ tf_module {
     name: "EditDistance"
     argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Einsum"
+    argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Elu"
     argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 53cd28b9dce..ca661f2a572 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1000,6 +1000,10 @@ tf_module {
     name: "EditDistance"
     argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Einsum"
+    argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "Elu"
     argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "