diff --git a/tensorflow/core/api_def/base_api/api_def_Eig.pbtxt b/tensorflow/core/api_def/base_api/api_def_Eig.pbtxt
new file mode 100644
index 00000000000..b85082c0cc8
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Eig.pbtxt
@@ -0,0 +1,45 @@
+op {
+  graph_op_name: "Eig"
+  endpoint {
+    name: "Eig"
+  }
+  in_arg {
+    name: "input"
+    description: <<END
+`Tensor` input of shape `[N, N]`.
+END
+  }
+  out_arg {
+    name: "e"
+    description: <<END
+Eigenvalues. Shape is `[N]`.
+END
+  }
+  out_arg {
+    name: "v"
+    description: <<END
+Eigenvectors. Shape is `[N, N]`.
+END
+  }
+  attr {
+    name: "compute_v"
+    description: <<END
+If `True` then eigenvectors will be computed and returned in `v`.
+Otherwise, only the eigenvalues will be computed.
+END
+  }
+  summary: "Computes the eigen decomposition of one or more square matrices."
+  description: <<END
+Computes the eigenvalues and (optionally) right eigenvectors of each inner matrix in
+`input` such that `input[..., :, :] = v[..., :, :] * diag(e[..., :])`. The eigenvalues
+are sorted in non-decreasing order.
+
+```python
+# a is a tensor.
+# e is a tensor of eigenvalues.
+# v is a tensor of eigenvectors.
+e, v = eig(a)
+e = eig(a, compute_v=False)
+```
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Eig.pbtxt b/tensorflow/core/api_def/python_api/api_def_Eig.pbtxt
new file mode 100644
index 00000000000..08a413a9941
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Eig.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "Eig"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d5607f641af..e0625f3a330 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3361,6 +3361,7 @@ cc_library(
         ":cholesky_grad",
         ":cholesky_op",
         ":determinant_op",
+        ":eig_op",
         ":einsum_op",
         ":lu_op",
         ":matrix_exponential_op",
@@ -3473,6 +3474,15 @@ tf_kernel_library(
     ]),
 )
 
+tf_kernel_library(
+    name = "eig_op",
+    prefix = "eig_op",
+    deps = LINALG_DEPS + ["//tensorflow/core:lib_internal"] + if_cuda([
+        ":cast_op",
+        ":cwise_op",
+    ]),
+)
+
 tf_kernel_library(
     name = "matrix_inverse_op",
     prefix = "matrix_inverse_op",
diff --git a/tensorflow/core/kernels/eig_op_complex128.cc b/tensorflow/core/kernels/eig_op_complex128.cc
new file mode 100644
index 00000000000..988cc2f98d9
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_complex128.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<complex128, complex128>), complex128);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_complex64.cc b/tensorflow/core/kernels/eig_op_complex64.cc
new file mode 100644
index 00000000000..6a3f7928715
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_complex64.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<complex64, complex64>), complex64);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_double.cc b/tensorflow/core/kernels/eig_op_double.cc
new file mode 100644
index 00000000000..2cd931cc135
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_double.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<double, complex128>), double);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_float.cc b/tensorflow/core/kernels/eig_op_float.cc
new file mode 100644
index 00000000000..a06f76e935f
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_float.cc
@@ -0,0 +1,22 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/eig_op_impl.h"
+
+namespace tensorflow {
+
+REGISTER_LINALG_OP("Eig", (EigOp<float, complex64>), float);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/eig_op_impl.h b/tensorflow/core/kernels/eig_op_impl.h
new file mode 100644
index 00000000000..4ebb6bde08b
--- /dev/null
+++ b/tensorflow/core/kernels/eig_op_impl.h
@@ -0,0 +1,98 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
+
+// See docs in ../ops/linalg_ops.cc.
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/Eigen/Eigenvalues"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/denormal.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+template <class InputScalar, class OutputScalar>
+class EigOp : public LinearAlgebraOp<InputScalar, OutputScalar> {
+ public:
+  typedef LinearAlgebraOp<InputScalar, OutputScalar> Base;
+
+  explicit EigOp(OpKernelConstruction* context) : Base(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("compute_v", &compute_v_));
+  }
+
+  using TensorShapes = typename Base::TensorShapes;
+  using InputMatrix = typename Base::InputMatrix;
+  using InputMatrixMaps = typename Base::InputMatrixMaps;
+  using InputConstMatrixMap = typename Base::InputConstMatrixMap;
+  using InputConstMatrixMaps = typename Base::InputConstMatrixMaps;
+
+  using OutputMatrix = typename Base::OutputMatrix;
+  using OutputMatrixMaps = typename Base::OutputMatrixMaps;
+  using OutputConstMatrixMap = typename Base::OutputConstMatrixMap;
+  using OutputConstMatrixMaps = typename Base::OutputConstMatrixMaps;
+
+  TensorShapes GetOutputMatrixShapes(
+      const TensorShapes& input_matrix_shapes) const final {
+    int64 n = input_matrix_shapes[0].dim_size(0);
+    if (compute_v_) {
+      return TensorShapes({TensorShape({n}), TensorShape({n, n})});
+    } else {
+      return TensorShapes({TensorShape({n})});
+    }
+  }
+
+  void ComputeMatrix(OpKernelContext* context,
+                     const InputConstMatrixMaps& inputs,
+                     OutputMatrixMaps* outputs) final {
+    const int64 rows = inputs[0].rows();
+    if (rows == 0) {
+      // If X is an empty matrix (0 rows, 0 col), X * X' == X.
+      // Therefore, we return X.
+      return;
+    }
+
+    // This algorithm relies on denormals, so switch them back on locally.
+    port::ScopedDontFlushDenormal dont_flush_denormals;
+
+    Eigen::ComplexEigenSolver<OutputMatrix> eig(
+        inputs[0],
+        compute_v_ ? Eigen::ComputeEigenvectors : Eigen::EigenvaluesOnly);
+    // TODO(rmlarsen): Output more detailed error info on failure.
+    OP_REQUIRES(
+        context, eig.info() == Eigen::Success,
+        errors::InvalidArgument("Eigen decomposition was not "
+                                "successful. The input might not be valid."));
+
+    outputs->at(0) = eig.eigenvalues().template cast<OutputScalar>();
+    if (compute_v_) {
+      outputs->at(1) = eig.eigenvectors();
+    }
+  }
+
+ private:
+  bool compute_v_;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_EIG_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index b58bcf58348..3836ff796eb 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -29,8 +29,8 @@ limitations under the License.
 namespace tensorflow {
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleMatrix(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
               errors::InvalidArgument("Expected a single input matrix, got %d.",
@@ -40,8 +40,8 @@ void LinearAlgebraOp<Scalar>::ValidateSingleMatrix(
 }
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSingleSquareMatrix(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 1,
               errors::InvalidArgument("Expected a single input matrix, got %d.",
@@ -51,8 +51,8 @@ void LinearAlgebraOp<Scalar>::ValidateSingleSquareMatrix(
 }
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSolver(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSolver(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
               errors::InvalidArgument("Expected two input matrices, got %d.",
@@ -68,8 +68,8 @@ void LinearAlgebraOp<Scalar>::ValidateSolver(
 }
 
 // static
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ValidateSquareSolver(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes) {
   OP_REQUIRES(context, input_matrix_shapes.size() == 2,
               errors::InvalidArgument("Expected two input matrices, got %d.",
@@ -85,8 +85,9 @@ void LinearAlgebraOp<Scalar>::ValidateSquareSolver(
       errors::InvalidArgument("Input matrix and rhs are incompatible."));
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::Compute(
+    OpKernelContext* context) {
   TensorInputs inputs;
   TensorShapes input_matrix_shapes;
   TensorShape batch_shape;
@@ -110,11 +111,10 @@ void LinearAlgebraOp<Scalar>::Compute(OpKernelContext* context) {
         batch_shape.num_elements(), GetCostPerUnit(input_matrix_shapes), shard);
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
-                                            TensorInputs* inputs,
-                                            TensorShapes* input_matrix_shapes,
-                                            TensorShape* batch_shape) {
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::AnalyzeInputs(
+    OpKernelContext* context, TensorInputs* inputs,
+    TensorShapes* input_matrix_shapes, TensorShape* batch_shape) {
   int input_rank = -1;
   for (int i = 0; i < NumMatrixInputs(context); ++i) {
     const Tensor& in = context->input(i);
@@ -155,8 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
   ValidateInputMatrixShapes(context, *input_matrix_shapes);
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::PrepareOutputs(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::PrepareOutputs(
     OpKernelContext* context, const TensorShapes& input_matrix_shapes,
     const TensorShape& batch_shape, TensorOutputs* outputs,
     TensorShapes* output_matrix_shapes) {
@@ -214,22 +214,22 @@ void LinearAlgebraOp<Scalar>::PrepareOutputs(
   }
 }
 
-template <typename Scalar>
-void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
+template <class InputScalar, class OutputScalar>
+void LinearAlgebraOp<InputScalar, OutputScalar>::ComputeTensorSlice(
     OpKernelContext* context, int64 matrix_index, const TensorInputs& inputs,
     const TensorShapes& input_matrix_shapes, const TensorOutputs& outputs,
     const TensorShapes& output_matrix_shapes) {
-  ConstMatrixMaps matrix_inputs;
+  InputConstMatrixMaps matrix_inputs;
   for (size_t i = 0; i < inputs.size(); ++i) {
     // TODO(kalakris): Handle alignment if possible. Eigen::Map is
     // unaligned by default.
     matrix_inputs.emplace_back(
-        inputs[i]->flat<Scalar>().data() +
+        inputs[i]->flat<InputScalar>().data() +
             matrix_index * input_matrix_shapes[i].num_elements(),
         input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
   }
 
-  MatrixMaps matrix_outputs;
+  OutputMatrixMaps matrix_outputs;
   for (size_t i = 0; i < output_matrix_shapes.size(); ++i) {
     // The output matrix shape may not be a matrix.
     int num_output_rows = output_matrix_shapes[i].dims() >= 1
@@ -239,7 +239,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
                               ? output_matrix_shapes[i].dim_size(1)
                               : 1;
     matrix_outputs.emplace_back(
-        outputs[i]->flat<Scalar>().data() +
+        outputs[i]->flat<OutputScalar>().data() +
             matrix_index * output_matrix_shapes[i].num_elements(),
         num_output_rows, num_output_cols);
   }
@@ -251,5 +251,7 @@ template class LinearAlgebraOp<float>;
 template class LinearAlgebraOp<double>;
 template class LinearAlgebraOp<complex64>;
 template class LinearAlgebraOp<complex128>;
+template class LinearAlgebraOp<float, complex64>;
+template class LinearAlgebraOp<double, complex128>;
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index 11ecf7d676e..65c2fb90f0e 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -36,7 +36,7 @@ limitations under the License.
 namespace tensorflow {
 
 // Base class for linear algebra operators.
-template <typename Scalar>
+template <class InputScalar, class OutputScalar = InputScalar>
 class LinearAlgebraOp : public OpKernel {
  public:
   explicit LinearAlgebraOp(OpKernelConstruction* context) : OpKernel(context) {}
@@ -109,6 +109,28 @@ class LinearAlgebraOp : public OpKernel {
   // and expect the kernel to perform the computation inplace.
   virtual bool EnableInputForwarding() const { return true; }
 
+  using InputMatrix = Eigen::Matrix<InputScalar, Eigen::Dynamic, Eigen::Dynamic,
+                                    Eigen::RowMajor>;
+  using InputConstMatrixMap = Eigen::Map<const InputMatrix>;
+  using InputMatrixMap = Eigen::Map<InputMatrix>;
+  using InputConstVectorMap =
+      Eigen::Map<const Eigen::Matrix<InputScalar, 1, Eigen::Dynamic>>;
+  using InputConstMatrixMaps = gtl::InlinedVector<InputConstMatrixMap, 4>;
+  using InputMatrixMaps = gtl::InlinedVector<InputMatrixMap, 4>;
+  using InputRealScalar = typename Eigen::NumTraits<InputScalar>::Real;
+
+  using OutputMatrix = Eigen::Matrix<OutputScalar, Eigen::Dynamic,
+                                     Eigen::Dynamic, Eigen::RowMajor>;
+  using OutputConstMatrixMap = Eigen::Map<const OutputMatrix>;
+  using OutputMatrixMap = Eigen::Map<OutputMatrix>;
+  using OutputConstVectorMap =
+      Eigen::Map<const Eigen::Matrix<OutputScalar, 1, Eigen::Dynamic>>;
+  using OutputConstMatrixMaps = gtl::InlinedVector<OutputConstMatrixMap, 4>;
+  using OutputMatrixMaps = gtl::InlinedVector<OutputMatrixMap, 4>;
+  using OutputRealScalar = typename Eigen::NumTraits<OutputScalar>::Real;
+
+  // backward compatibility
+  using Scalar = OutputScalar;
   using Matrix =
       Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
   using ConstMatrixMap = Eigen::Map<const Matrix>;
@@ -126,8 +148,8 @@ class LinearAlgebraOp : public OpKernel {
   // parallelized. The number of threads used is determined by a cost model from
   // the value returned by GetCostPerUnit().
   virtual void ComputeMatrix(OpKernelContext* context,
-                             const ConstMatrixMaps& inputs,
-                             MatrixMaps* outputs) = 0;
+                             const InputConstMatrixMaps& inputs,
+                             OutputMatrixMaps* outputs) = 0;
 
  private:
   using TensorInputs = gtl::InlinedVector<const Tensor*, 4>;
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index f037d38ef81..4572df279b7 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -383,6 +383,15 @@ REGISTER_OP("SelfAdjointEig")
       return Status::OK();
     });
 
+REGISTER_OP("Eig")
+    .Input("input: T")
+    .Output("e: Tout")
+    .Output("v: Tout")
+    .Attr("compute_v: bool = True")
+    .Attr("T: {float, double, complex64, complex128}")
+    .Attr("Tout: {complex64, complex128}")
+    .SetShapeFn(SelfAdjointEigV2ShapeFn);
+
 REGISTER_OP("SelfAdjointEigV2")
     .Input("input: T")
     .Output("e: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 617634980dd..1da4cef1557 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -12206,6 +12206,50 @@ op {
     type: "type"
   }
 }
+op {
+  name: "Eig"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "e"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "v"
+    type_attr: "T"
+  }
+  attr {
+    name: "compute_v"
+    type: "bool"
+    default_value {
+      b: true
+    }
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+  attr {
+    name: "Tout"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_COMPLEX64
+        type: DT_COMPLEX128
+      }
+    }
+  }
+}
 op {
   name: "Elu"
   input_arg {
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index ca50ed1d566..7176b894246 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -3355,6 +3355,27 @@ cuda_py_test(
     shard_count = 20,
 )
 
+tf_py_test(
+    name = "eig_op_test",
+    size = "medium",
+    srcs = ["eig_op_test.py"],
+    additional_deps = [
+        "//third_party/py/numpy",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:linalg_ops",
+        "//tensorflow/python:math_ops",
+    ],
+    data = ["//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files"],
+    shard_count = 20,
+    tags = [
+        "no_rocm",  # flaky test
+        "no_windows",
+    ],
+    # b/127344411: xla_enable_strict_auto_jit = True,
+)
+
 cuda_py_test(
     name = "self_adjoint_eig_op_test",
     size = "medium",
diff --git a/tensorflow/python/kernel_tests/eig_op_test.py b/tensorflow/python/kernel_tests/eig_op_test.py
new file mode 100644
index 00000000000..ffc61b7bcfe
--- /dev/null
+++ b/tensorflow/python/kernel_tests/eig_op_test.py
@@ -0,0 +1,198 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.linalg_ops.eig."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+def _AddTest(test_class, op_name, testcase_name, fn):
+  test_name = "_".join(["test", op_name, testcase_name])
+  if hasattr(test_class, test_name):
+    raise RuntimeError("Test %s defined more than once" % test_name)
+  setattr(test_class, test_name, fn)
+
+
+class EigTest(test.TestCase):
+
+  @test_util.run_deprecated_v1
+  def testWrongDimensions(self):
+    # The input to self_adjoint_eig should be a tensor of
+    # at least rank 2.
+    scalar = constant_op.constant(1.)
+    with self.assertRaises(ValueError):
+      linalg_ops.eig(scalar)
+    vector = constant_op.constant([1., 2.])
+    with self.assertRaises(ValueError):
+      linalg_ops.eig(vector)
+
+  @test_util.run_deprecated_v1
+  def testConcurrentExecutesWithoutError(self):
+    all_ops = []
+    with self.session(use_gpu=True) as sess:
+      for compute_v_ in True, False:
+        matrix1 = random_ops.random_normal([5, 5], seed=42)
+        matrix2 = random_ops.random_normal([5, 5], seed=42)
+        if compute_v_:
+          e1, v1 = linalg_ops.eig(matrix1)
+          e2, v2 = linalg_ops.eig(matrix2)
+          all_ops += [e1, v1, e2, v2]
+        else:
+          e1 = linalg_ops.eigvals(matrix1)
+          e2 = linalg_ops.eigvals(matrix2)
+          all_ops += [e1, e2]
+      val = self.evaluate(all_ops)
+      self.assertAllEqual(val[0], val[2])
+      # The algorithm is slightly different for compute_v being True and False,
+      # so require approximate equality only here.
+      self.assertAllClose(val[2], val[4])
+      self.assertAllEqual(val[4], val[5])
+      self.assertAllEqual(val[1], val[3])
+
+  def testMatrixThatFailsWhenFlushingDenormsToZero(self):
+    # Test a 32x32 matrix which is known to fail if denorm floats are flushed to
+    # zero.
+    matrix = np.genfromtxt(
+        test.test_src_dir_path(
+            "python/kernel_tests/testdata/"
+            "self_adjoint_eig_fail_if_denorms_flushed.txt")).astype(np.float32)
+    self.assertEqual(matrix.shape, (32, 32))
+    matrix_tensor = constant_op.constant(matrix)
+    with self.session(use_gpu=True) as sess:
+      (e, v) = self.evaluate(linalg_ops.self_adjoint_eig(matrix_tensor))
+      self.assertEqual(e.size, 32)
+      self.assertAllClose(
+          np.matmul(v, v.transpose()), np.eye(32, dtype=np.float32), atol=2e-3)
+      self.assertAllClose(matrix,
+                          np.matmul(np.matmul(v, np.diag(e)), v.transpose()))
+
+
+def SortEigenValues(e):
+  perm = np.argsort(e.real + e.imag, -1)
+  return np.take(e, perm, -1)
+
+
+def SortEigenDecomposition(e, v):
+  if v.ndim < 2:
+    return e, v
+  else:
+    perm = np.argsort(e.real + e.imag, -1)
+    return np.take(e, perm, -1), np.take(v, perm, -1)
+
+
+def EquilibrateEigenVectorPhases(x, y):
+  """Equilibrate the phase of the Eigenvectors in the columns of `x` and `y`.
+
+  Eigenvectors are only unique up to an arbitrary phase. This function rotates x
+  such that it matches y. Precondition: The coluns of x and y differ by a
+  multiplicative complex phase factor only.
+
+  Args:
+    x: `np.ndarray` with Eigenvectors
+    y: `np.ndarray` with Eigenvectors
+
+  Returns:
+    `np.ndarray` containing an equilibrated version of x.
+  """
+  phases = np.sum(np.conj(x) * y, -2, keepdims=True)
+  phases /= np.abs(phases)
+  return phases * x
+
+
+def _GetEigTest(dtype_, shape_, compute_v_):
+
+  def CompareEigenVectors(self, x, y, tol):
+    x = EquilibrateEigenVectorPhases(x, y)
+    self.assertAllClose(x, y, atol=tol)
+
+  def CompareEigenDecompositions(self, x_e, x_v, y_e, y_v, tol):
+    num_batches = int(np.prod(x_e.shape[:-1]))
+    n = x_e.shape[-1]
+    x_e = np.reshape(x_e, [num_batches] + [n])
+    x_v = np.reshape(x_v, [num_batches] + [n, n])
+    y_e = np.reshape(y_e, [num_batches] + [n])
+    y_v = np.reshape(y_v, [num_batches] + [n, n])
+    for i in range(num_batches):
+      x_ei, x_vi = SortEigenDecomposition(x_e[i, :], x_v[i, :, :])
+      y_ei, y_vi = SortEigenDecomposition(y_e[i, :], y_v[i, :, :])
+      self.assertAllClose(x_ei, y_ei, atol=tol, rtol=tol)
+      CompareEigenVectors(self, x_vi, y_vi, tol)
+
+  def Test(self):
+    np.random.seed(1)
+    n = shape_[-1]
+    batch_shape = shape_[:-2]
+    np_dtype = dtype_.as_numpy_dtype
+    # most of matrices are diagonalizable # TODO
+    a = np.random.uniform(
+        low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+    if dtype_.is_complex:
+      a += 1j * np.random.uniform(
+          low=-1.0, high=1.0, size=n * n).reshape([n, n]).astype(np_dtype)
+    a = np.tile(a, batch_shape + (1, 1))
+    if dtype_ in (dtypes_lib.float32, dtypes_lib.complex64):
+      atol = 1e-4
+    else:
+      atol = 1e-12
+    np_e, np_v = np.linalg.eig(a)
+    with self.session(use_gpu=True):
+      if compute_v_:
+        tf_e, tf_v = linalg_ops.eig(constant_op.constant(a))
+
+        # Check that V*diag(E)*V^(-1) is close to A.
+        a_ev = math_ops.matmul(
+            math_ops.matmul(tf_v, array_ops.matrix_diag(tf_e)),
+            linalg_ops.matrix_inverse(tf_v))
+        self.assertAllClose(self.evaluate(a_ev), a, atol=atol)
+
+        # Compare to numpy.linalg.eig.
+        CompareEigenDecompositions(self, np_e, np_v, self.evaluate(tf_e),
+                                   self.evaluate(tf_v), atol)
+      else:
+        tf_e = linalg_ops.eigvals(constant_op.constant(a))
+        self.assertAllClose(
+            SortEigenValues(np_e),
+            SortEigenValues(self.evaluate(tf_e)),
+            atol=atol)
+
+  return Test
+
+
+if __name__ == "__main__":
+  dtypes_to_test = [dtypes_lib.float32, dtypes_lib.float64]
+  if not test.is_built_with_rocm():
+    # ROCm does not support BLAS operations for complex types
+    dtypes_to_test += [dtypes_lib.complex64, dtypes_lib.complex128]
+  for compute_v in True, False:
+    for dtype in dtypes_to_test:
+      for size in 1, 2, 5, 10:
+        for batch_dims in [(), (3,)] + [(3, 2)] * (max(size, size) < 10):
+          shape = batch_dims + (size, size)
+          name = "%s_%s_%s" % (dtype.name, "_".join(map(str, shape)), compute_v)
+          _AddTest(EigTest, "Eig", name, _GetEigTest(dtype, shape, compute_v))
+          # No gradient yet
+  test.main()
diff --git a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
index a42d7922bfb..0ada446e84b 100644
--- a/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
+++ b/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+"""Tests for tensorflow.ops.linalg_ops.self_adjoint_eig."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 914e5748534..e49434ffd4e 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -306,6 +306,62 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
         matrix, rhs, l2_regularizer, fast=fast, name=name)
 
 
+@tf_export('eig', 'linalg.eig', v1=[])
+def eig(tensor, name=None):
+  """Computes the eigen decomposition of a batch of matrices.
+
+  The eigenvalues
+  and eigenvectors for a non-Hermitian matrix in general are complex. The
+  eigenvectors are not guaranteed to be linearly independent.
+
+  Computes the eigenvalues and right eigenvectors of the innermost
+  N-by-N matrices in `tensor` such that
+  `tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i]`, for i=0...N-1.
+
+  Args:
+    tensor: `Tensor` of shape `[..., N, N]`. Only the lower triangular part of
+      each inner inner matrix is referenced.
+    name: string, optional name of the operation.
+
+  Returns:
+    e: Eigenvalues. Shape is `[..., N]`. Sorted in non-decreasing order.
+    v: Eigenvectors. Shape is `[..., N, N]`. The columns of the inner most
+      matrices contain eigenvectors of the corresponding matrices in `tensor`
+  """
+  if tensor.dtype == dtypes.float32 or tensor.dtype == dtypes.complex64:
+    out_dtype = dtypes.complex64
+  elif tensor.dtype == dtypes.float64 or tensor.dtype == dtypes.complex128:
+    out_dtype = dtypes.complex128
+  e, v = gen_linalg_ops.eig(tensor, Tout=out_dtype, compute_v=True, name=name)
+  return e, v
+
+
+@tf_export('eigvals', 'linalg.eigvals', v1=[])
+def eigvals(tensor, name=None):
+  """Computes the eigenvalues of one or more matrices.
+
+  Note: If your program backpropagates through this function, you should replace
+  it with a call to tf.linalg.eig (possibly ignoring the second output) to
+  avoid computing the eigen decomposition twice. This is because the
+  eigenvectors are used to compute the gradient w.r.t. the eigenvalues. See
+  _SelfAdjointEigV2Grad in linalg_grad.py.
+
+  Args:
+    tensor: `Tensor` of shape `[..., N, N]`.
+    name: string, optional name of the operation.
+
+  Returns:
+    e: Eigenvalues. Shape is `[..., N]`. The vector `e[..., :]` contains the `N`
+      eigenvalues of `tensor[..., :, :]`.
+  """
+  if tensor.dtype == dtypes.float32 or tensor.dtype == dtypes.complex64:
+    out_dtype = dtypes.complex64
+  elif tensor.dtype == dtypes.float64 or tensor.dtype == dtypes.complex128:
+    out_dtype = dtypes.complex128
+  e, _ = gen_linalg_ops.eig(tensor, Tout=out_dtype, compute_v=False, name=name)
+  return e
+
+
 @tf_export('linalg.eigh', v1=['linalg.eigh', 'self_adjoint_eig'])
 @deprecation.deprecated_endpoints('self_adjoint_eig')
 def self_adjoint_eig(tensor, name=None):
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index e7d5f1aec78..8ae11431a08 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1132,6 +1132,10 @@ tf_module {
     name: "EditDistance"
     argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Eig"
+    argspec: "args=[\'input\', \'Tout\', \'compute_v\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
   member_method {
     name: "Einsum"
     argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index 3150ea14464..a25583d7fdd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -108,10 +108,18 @@ tf_module {
     name: "diag_part"
     argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
   }
+  member_method {
+    name: "eig"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "eigh"
     argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "eigvals"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "eigvalsh"
     argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index f3d5aec9215..d67870a92b8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -604,6 +604,14 @@ tf_module {
     name: "edit_distance"
     argspec: "args=[\'hypothesis\', \'truth\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'edit_distance\'], "
   }
+  member_method {
+    name: "eig"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "eigvals"
+    argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
   member_method {
     name: "einsum"
     argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index e7d5f1aec78..8ae11431a08 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1132,6 +1132,10 @@ tf_module {
     name: "EditDistance"
     argspec: "args=[\'hypothesis_indices\', \'hypothesis_values\', \'hypothesis_shape\', \'truth_indices\', \'truth_values\', \'truth_shape\', \'normalize\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
   }
+  member_method {
+    name: "Eig"
+    argspec: "args=[\'input\', \'Tout\', \'compute_v\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
+  }
   member_method {
     name: "Einsum"
     argspec: "args=[\'inputs\', \'equation\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "