diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 4e8639dfc88..cc6fbf26c2f 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1785,6 +1785,7 @@ cuda_py_test(
     size = "medium",
     srcs = ["linalg_ops_test.py"],
     additional_deps = [
+        "@absl_py//absl/testing:parameterized",
         "//third_party/py/numpy",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg_ops_test.py
index aa17f727d09..ccb3feeaf6e 100644
--- a/tensorflow/python/kernel_tests/linalg_ops_test.py
+++ b/tensorflow/python/kernel_tests/linalg_ops_test.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
+from absl.testing import parameterized
 import numpy as np
 
 from tensorflow.python.framework import dtypes
@@ -52,7 +55,7 @@ class CholeskySolveTest(test.TestCase):
   def test_works_with_five_different_random_pos_def_matrices(self):
     for n in range(1, 6):
       for np_type, atol in [(np.float32, 0.05), (np.float64, 1e-5)]:
-        with self.test_session(use_gpu=True):
+        with self.session(use_gpu=True):
           # Create 2 x n x n matrix
           array = np.array(
               [_RandomPDMatrix(n, self.rng),
@@ -76,7 +79,7 @@ class LogdetTest(test.TestCase):
                              (np.complex64, 0.05), (np.complex128, 1e-5)]:
         matrix = _RandomPDMatrix(n, self.rng, np_dtype)
         _, logdet_np = np.linalg.slogdet(matrix)
-        with self.test_session(use_gpu=True):
+        with self.session(use_gpu=True):
           # Create 2 x n x n matrix
           # matrix = np.array(
           #     [_RandomPDMatrix(n, self.rng, np_dtype),
@@ -89,7 +92,7 @@ class LogdetTest(test.TestCase):
                            (np.complex64, 0.05), (np.complex128, 1e-5)]:
       matrix = (np.eye(20) * 1e-6).astype(np_dtype)
       _, logdet_np = np.linalg.slogdet(matrix)
-      with self.test_session(use_gpu=True):
+      with self.session(use_gpu=True):
         logdet_tf = linalg.logdet(matrix)
         self.assertAllClose(logdet_np, logdet_tf.eval(), atol=atol)
 
@@ -105,7 +108,7 @@ class SlogdetTest(test.TestCase):
                              (np.complex64, 0.05), (np.complex128, 1e-5)]:
         matrix = _RandomPDMatrix(n, self.rng, np_dtype)
         sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
-        with self.test_session(use_gpu=True):
+        with self.session(use_gpu=True):
           sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
           self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
           self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
@@ -115,7 +118,7 @@ class SlogdetTest(test.TestCase):
                            (np.complex64, 0.05), (np.complex128, 1e-5)]:
       matrix = (np.eye(20) * 1e-6).astype(np_dtype)
       sign_np, log_abs_det_np = np.linalg.slogdet(matrix)
-      with self.test_session(use_gpu=True):
+      with self.session(use_gpu=True):
         sign_tf, log_abs_det_tf = linalg.slogdet(matrix)
         self.assertAllClose(log_abs_det_np, log_abs_det_tf.eval(), atol=atol)
         self.assertAllClose(sign_np, sign_tf.eval(), atol=atol)
@@ -128,66 +131,126 @@ class AdjointTest(test.TestCase):
       matrix_np = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [4 + 4j, 5 + 5j,
                                                        6 + 6j]]).astype(dtype)
       expected_transposed = np.conj(matrix_np.T)
-      with self.cached_session():
+      with self.session():
         matrix = ops.convert_to_tensor(matrix_np)
         transposed = linalg.adjoint(matrix)
         self.assertEqual((3, 2), transposed.get_shape())
         self.assertAllEqual(expected_transposed, transposed.eval())
 
 
-class EyeTest(test.TestCase):
-  pass  # Will be filled in below
+class EyeTest(parameterized.TestCase, test.TestCase):
 
+  def testShapeInferenceNoBatch(self):
+    self.assertEqual((2, 2), linalg_ops.eye(num_rows=2).shape)
+    self.assertEqual((2, 3), linalg_ops.eye(num_rows=2, num_columns=3).shape)
 
-def _GetEyeTest(num_rows, num_columns, batch_shape, dtype):
+  def testShapeInferenceStaticBatch(self):
+    batch_shape = (2, 3)
+    self.assertEqual(
+        (2, 3, 2, 2),
+        linalg_ops.eye(num_rows=2, batch_shape=batch_shape).shape)
+    self.assertEqual(
+        (2, 3, 2, 3),
+        linalg_ops.eye(
+            num_rows=2, num_columns=3, batch_shape=batch_shape).shape)
 
-  def Test(self):
+  @parameterized.named_parameters(
+      ("DynamicRow", array_ops.placeholder_with_default(2, shape=None), None),
+      ("DynamicRowStaticColumn",
+       array_ops.placeholder_with_default(2, shape=None),
+       3),
+      ("StaticRowDynamicColumn",
+       2,
+       array_ops.placeholder_with_default(3, shape=None)),
+      ("DynamicRowDynamicColumn",
+       array_ops.placeholder_with_default(2, shape=None),
+       array_ops.placeholder_with_default(3, shape=None)))
+  def testShapeInferenceStaticBatchWith(self, num_rows, num_columns):
+    batch_shape = (2, 3)
+    identity_matrix = linalg_ops.eye(
+        num_rows=num_rows,
+        num_columns=num_columns,
+        batch_shape=batch_shape)
+    self.assertEqual(4, identity_matrix.shape.ndims)
+    self.assertEqual((2, 3), identity_matrix.shape[:2])
+    if num_rows is not None and not isinstance(num_rows, ops.Tensor):
+      self.assertEqual(2, identity_matrix.shape[-2])
+
+    if num_columns is not None and not isinstance(num_columns, ops.Tensor):
+      self.assertEqual(3, identity_matrix.shape[-1])
+
+  @parameterized.parameters(
+      itertools.product(
+          # num_rows
+          [0, 1, 2, 5],
+          # num_columns
+          [None, 0, 1, 2, 5],
+          # batch_shape
+          [None, [], [2], [2, 3]],
+          # dtype
+          [
+              dtypes.int32,
+              dtypes.int64,
+              dtypes.float32,
+              dtypes.float64,
+              dtypes.complex64,
+              dtypes.complex128
+          ])
+      )
+  def test_eye_no_placeholder(self, num_rows, num_columns, batch_shape, dtype):
     eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
     if batch_shape is not None:
       eye_np = np.tile(eye_np, batch_shape + [1, 1])
-    for use_placeholder in False, True:
-      if use_placeholder and (num_columns is None or batch_shape is None):
-        return
-      with self.test_session(use_gpu=True) as sess:
-        if use_placeholder:
-          num_rows_placeholder = array_ops.placeholder(
-              dtypes.int32, name="num_rows")
-          num_columns_placeholder = array_ops.placeholder(
-              dtypes.int32, name="num_columns")
-          batch_shape_placeholder = array_ops.placeholder(
-              dtypes.int32, name="batch_shape")
-          eye = linalg_ops.eye(
-              num_rows_placeholder,
-              num_columns=num_columns_placeholder,
-              batch_shape=batch_shape_placeholder,
-              dtype=dtype)
-          eye_tf = sess.run(
-              eye,
-              feed_dict={
-                  num_rows_placeholder: num_rows,
-                  num_columns_placeholder: num_columns,
-                  batch_shape_placeholder: batch_shape
-              })
-        else:
-          eye_tf = linalg_ops.eye(
-              num_rows,
-              num_columns=num_columns,
-              batch_shape=batch_shape,
-              dtype=dtype).eval()
-        self.assertAllEqual(eye_np, eye_tf)
+    eye_tf = self.evaluate(linalg_ops.eye(
+        num_rows,
+        num_columns=num_columns,
+        batch_shape=batch_shape,
+        dtype=dtype))
+    self.assertAllEqual(eye_np, eye_tf)
 
-  return Test
+  @parameterized.parameters(
+      itertools.product(
+          # num_rows
+          [0, 1, 2, 5],
+          # num_columns
+          [0, 1, 2, 5],
+          # batch_shape
+          [[], [2], [2, 3]],
+          # dtype
+          [
+              dtypes.int32,
+              dtypes.int64,
+              dtypes.float32,
+              dtypes.float64,
+              dtypes.complex64,
+              dtypes.complex128
+          ])
+      )
+  def test_eye_with_placeholder(
+      self, num_rows, num_columns, batch_shape, dtype):
+    eye_np = np.eye(num_rows, M=num_columns, dtype=dtype.as_numpy_dtype)
+    eye_np = np.tile(eye_np, batch_shape + [1, 1])
+    num_rows_placeholder = array_ops.placeholder(
+        dtypes.int32, name="num_rows")
+    num_columns_placeholder = array_ops.placeholder(
+        dtypes.int32, name="num_columns")
+    batch_shape_placeholder = array_ops.placeholder(
+        dtypes.int32, name="batch_shape")
+    eye = linalg_ops.eye(
+        num_rows_placeholder,
+        num_columns=num_columns_placeholder,
+        batch_shape=batch_shape_placeholder,
+        dtype=dtype)
+    with self.session(use_gpu=True) as sess:
+      eye_tf = sess.run(
+          eye,
+          feed_dict={
+              num_rows_placeholder: num_rows,
+              num_columns_placeholder: num_columns,
+              batch_shape_placeholder: batch_shape
+          })
+    self.assertAllEqual(eye_np, eye_tf)
 
 
 if __name__ == "__main__":
-  for _num_rows in 0, 1, 2, 5:
-    for _num_columns in None, 0, 1, 2, 5:
-      for _batch_shape in None, [], [2], [2, 3]:
-        for _dtype in (dtypes.int32, dtypes.int64, dtypes.float32,
-                       dtypes.float64, dtypes.complex64, dtypes.complex128):
-          name = "dtype_%s_num_rows_%s_num_column_%s_batch_shape_%s_" % (
-              _dtype.name, _num_rows, _num_columns, _batch_shape)
-          _AddTest(EyeTest, "EyeTest", name,
-                   _GetEyeTest(_num_rows, _num_columns, _batch_shape, _dtype))
-
   test.main()
diff --git a/tensorflow/python/ops/linalg_ops_impl.py b/tensorflow/python/ops/linalg_ops_impl.py
index e7c89f6ae3e..37c724e0325 100644
--- a/tensorflow/python/ops/linalg_ops_impl.py
+++ b/tensorflow/python/ops/linalg_ops_impl.py
@@ -44,22 +44,31 @@ def eye(num_rows,
     is_square = num_columns is None
     batch_shape = [] if batch_shape is None else batch_shape
     num_columns = num_rows if num_columns is None else num_columns
-    if isinstance(num_rows, ops.Tensor) or isinstance(
-        num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
-      batch_shape = ops.convert_to_tensor(
-          batch_shape, name='shape', dtype=dtypes.int32)
+
+    # We cannot statically infer what the diagonal size should be:
+    if (isinstance(num_rows, ops.Tensor) or
+        isinstance(num_columns, ops.Tensor)):
       diag_size = math_ops.minimum(num_rows, num_columns)
-      diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
-      if not is_square:
-        shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
     else:
+      # We can statically infer the diagonal size, and whether it is square.
       if not isinstance(num_rows, compat.integral_types) or not isinstance(
           num_columns, compat.integral_types):
         raise TypeError(
             'num_rows and num_columns must be positive integer values.')
-      batch_shape = [dim for dim in batch_shape]
       is_square = num_rows == num_columns
-      diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
+      diag_size = np.minimum(num_rows, num_columns)
+
+    # We can not statically infer the shape of the tensor.
+    if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor):
+      batch_shape = ops.convert_to_tensor(
+          batch_shape, name='shape', dtype=dtypes.int32)
+      diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
+      if not is_square:
+        shape = array_ops.concat((batch_shape, [num_rows, num_columns]), axis=0)
+    # We can statically infer everything.
+    else:
+      batch_shape = list(batch_shape)
+      diag_shape = batch_shape + [diag_size]
       if not is_square:
         shape = batch_shape + [num_rows, num_columns]