diff --git a/tensorflow/compiler/tests/matrix_diag_ops_test.py b/tensorflow/compiler/tests/matrix_diag_ops_test.py
index 343f367c1d0..69ae03a06cf 100644
--- a/tensorflow/compiler/tests/matrix_diag_ops_test.py
+++ b/tensorflow/compiler/tests/matrix_diag_ops_test.py
@@ -26,9 +26,82 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import googletest
 
 
+# LINT.IfChange
+matrix_diag_v3_forward_compat_date = (2019, 12, 6)
+# LINT.ThenChange(
+#   //tensorflow/python/kernel_tests/diag_op_test.py,
+#   //tensorflow/python/ops/array_ops.py,
+#   //tensorflow/python/ops/parallel_for/array_test.py
+# )
+
+default_v2_alignment = "LEFT_LEFT"
+alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT"]
+
+
+def zip_to_first_list_length(a, b):
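+  # Pairs each element of `a` with an element of `b`, truncating `b` or
+  # padding it with None so that the result always has len(a) pairs.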
+  if len(b) > len(a):
+    return zip(a, b[:len(a)])
+  return zip(a, b + [None] * (len(a) - len(b)))
+
+
+# Routines to convert test cases to have diagonals in a specified alignment.
+# Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py
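+# For example (illustrative values): with diag_index=(0, 1) in a 3x3 matrix,
+# the LEFT_LEFT-packed superdiagonal row [2, 3, 0] is repacked to [0, 2, 3]
+# under "RIGHT_LEFT", while full-length rows such as the main diagonal are
+# left unchanged.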
+def repack_diagonals(packed_diagonals,
+                     diag_index,
+                     num_rows,
+                     num_cols,
+                     align=None):
+  # The original test cases are LEFT_LEFT aligned.
+  if align == default_v2_alignment or align is None:
+    return packed_diagonals
+
+  align = align.split("_")
+  d_lower, d_upper = diag_index
+  batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1)
+  max_diag_len = packed_diagonals.shape[-1]
+  index = (slice(None),) * batch_dims
+  repacked_diagonals = np.zeros_like(packed_diagonals)
+
+  # Aligns each diagonal row-by-row.
+  for diag_index in range(d_lower, d_upper + 1):
+    diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index))
+    row_index = d_upper - diag_index
+    padding_len = max_diag_len - diag_len
+    left_align = (diag_index >= 0 and
+                  align[0] == "LEFT") or (diag_index <= 0 and
+                                          align[1] == "LEFT")
+    # Prepares index tuples.
+    extra_dim = tuple() if d_lower == d_upper else (row_index,)
+    packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
+    repacked_last_dim = (slice(None),) if left_align else (slice(
+        padding_len, max_diag_len, 1),)
+    packed_index = index + extra_dim + packed_last_dim
+    repacked_index = index + extra_dim + repacked_last_dim
+
+    # Repacks the diagonal.
+    repacked_diagonals[repacked_index] = packed_diagonals[packed_index]
+  return repacked_diagonals
+
+
+def repack_diagonals_in_tests(tests, align=None):
+  # The original test cases are LEFT_LEFT aligned.
+  if align == default_v2_alignment or align is None:
+    return tests
+
+  new_tests = dict()
+  # Loops through each case.
+  for diag_index, (packed_diagonals, padded_diagonals) in tests.items():
+    num_rows, num_cols = padded_diagonals.shape[-2:]
+    repacked_diagonals = repack_diagonals(
+        packed_diagonals, diag_index, num_rows, num_cols, align=align)
+    new_tests[diag_index] = (repacked_diagonals, padded_diagonals)
+
+  return new_tests
+
+
 # Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
 # Copied from //third_party/tensorflow/python/kernel_tests/diag_op_test.py
-def square_cases():
+def square_cases(align=None):
   # pyformat: disable
   mat = np.array([[[1, 2, 3, 4, 5],
                    [6, 7, 8, 9, 1],
@@ -103,10 +176,10 @@ def square_cases():
                             [0, 0, 0, 0, 0],
                             [0, 0, 0, 0, 0]]]))
   # pyformat: enable
-  return (mat, tests)
+  return (mat, repack_diagonals_in_tests(tests, align))
 
 
-def tall_cases():
+def tall_cases(align=None):
   # pyformat: disable
   mat = np.array([[[1, 2, 3],
                    [4, 5, 6],
@@ -191,10 +264,10 @@ def tall_cases():
                             [0, 0, 0],
                             [0, 0, 0]]]))
   # pyformat: enable
-  return (mat, tests)
+  return (mat, repack_diagonals_in_tests(tests, align))
 
 
-def fat_cases():
+def fat_cases(align=None):
   # pyformat: disable
   mat = np.array([[[1, 2, 3, 4],
                    [5, 6, 7, 8],
@@ -259,7 +332,11 @@ def fat_cases():
                             [0, 9, 1, 2],
                             [0, 0, 5, 6]]]))
   # pyformat: enable
-  return (mat, tests)
+  return (mat, repack_diagonals_in_tests(tests, align))
+
+
+def all_tests(align=None):
+  return [square_cases(align), tall_cases(align), fat_cases(align)]
 
 
 class MatrixDiagTest(xla_test.XLATestCase):
@@ -327,39 +404,31 @@ class MatrixDiagTest(xla_test.XLATestCase):
 
   # From here onwards are v2-only tests.
   def testSquare(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for _, tests in [square_cases()]:
-        for diag_index, (vecs, solution) in tests.items():
-          self._assertOpOutputMatchesExpected(
-              {
-                  "diagonal": vecs[0],
-                  "k": diag_index
-              }, solution[0])
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for align in alignment_list:
+        for _, tests in [square_cases(align)]:
+          for diag_index, (vecs, solution) in tests.items():
+            params = {"diagonal": vecs[0], "k": diag_index, "align": align}
+            self._assertOpOutputMatchesExpected(params, solution[0])
 
   def testSquareBatch(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for _, tests in [square_cases()]:
-        for diag_index, (vecs, solution) in tests.items():
-          self._assertOpOutputMatchesExpected(
-              {
-                  "diagonal": vecs,
-                  "k": diag_index
-              }, solution)
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for align in alignment_list:
+        for _, tests in [square_cases(align)]:
+          for diag_index, (vecs, solution) in tests.items():
+            params = {"diagonal": vecs, "k": diag_index, "align": align}
+            self._assertOpOutputMatchesExpected(params, solution)
 
   def testRectangularBatch(self):
-    # LINT.IfChange
-    if not compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
+    if not compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
       return
 
     # Stores expected num_rows and num_cols (when the other is given).
     # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
     test_list = list()
 
+    # Do not align the test cases here. Re-alignment needs to happen after the
+    # solution shape is updated.
     # Square cases:
     expected = {
         (-1, -1): (5, 4),
@@ -389,43 +458,65 @@ class MatrixDiagTest(xla_test.XLATestCase):
     test_list.append((expected, fat_cases()))
 
     # Giving both num_rows and num_cols
-    for _, tests in [tall_cases(), fat_cases()]:
+    align = alignment_list[0]
+    for _, tests in [tall_cases(align), fat_cases(align)]:
       for diag_index, (vecs, solution) in tests.items():
         self._assertOpOutputMatchesExpected(
             {
                 "diagonal": vecs,
                 "k": diag_index,
                 "num_rows": solution.shape[-2],
-                "num_cols": solution.shape[-1]
+                "num_cols": solution.shape[-1],
+                "align": align
             }, solution)
 
+    # We go through each alignment in a round-robin manner.
+    align_index = 0
+
     # Giving just num_rows or num_cols.
     for expected, (_, tests) in test_list:
       for diag_index, (new_num_rows, new_num_cols) in expected.items():
+        align = alignment_list[align_index]
+        align_index = (align_index + 1) % len(alignment_list)
         vecs, solution = tests[diag_index]
         solution_given_num_rows = solution.take(
             indices=range(new_num_cols), axis=-1)
+        # Repacks the diagonal input according to the new solution shape.
+        vecs_given_num_rows = repack_diagonals(
+            vecs,
+            diag_index,
+            solution_given_num_rows.shape[-2],
+            new_num_cols,
+            align=align)
         self._assertOpOutputMatchesExpected(
             {
-                "diagonal": vecs,
+                "diagonal": vecs_given_num_rows,
                 "k": diag_index,
-                "num_rows": solution_given_num_rows.shape[-2]
+                "num_rows": solution_given_num_rows.shape[-2],
+                "align": align
             }, solution_given_num_rows)
         solution_given_num_cols = solution.take(
             indices=range(new_num_rows), axis=-2)
+        # Repacks the diagonal input according to the new solution shape.
+        vecs_given_num_cols = repack_diagonals(
+            vecs,
+            diag_index,
+            new_num_rows,
+            solution_given_num_cols.shape[-1],
+            align=align)
         self._assertOpOutputMatchesExpected(
             {
-                "diagonal": vecs,
+                "diagonal": vecs_given_num_cols,
                 "k": diag_index,
-                "num_cols": solution_given_num_cols.shape[-1]
+                "num_cols": solution_given_num_cols.shape[-1],
+                "align": align
             }, solution_given_num_cols)
 
   def testPadding(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for padding_value in [555, -11]:
-        for _, tests in [square_cases(), tall_cases(), fat_cases()]:
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for padding_value, align in zip_to_first_list_length([555, -11],
+                                                           alignment_list):
+        for _, tests in all_tests(align):
           for diag_index, (vecs, solution) in tests.items():
             mask = (solution == 0)
             solution = solution + (mask * padding_value)
@@ -435,7 +526,8 @@ class MatrixDiagTest(xla_test.XLATestCase):
                     "k": diag_index,
                     "num_rows": solution.shape[-2],
                     "num_cols": solution.shape[-1],
-                    "padding_value": padding_value
+                    "padding_value": padding_value,
+                    "align": align
                 }, solution)
 
 
@@ -542,36 +634,36 @@ class MatrixSetDiagTest(xla_test.XLATestCase):
 
   # From here onwards are v2-only tests.
   def testSingleMatrix(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for _, tests in [square_cases(), tall_cases(), fat_cases()]:
-        for diag_index, (vecs, banded_mat) in tests.items():
-          mask = (banded_mat[0] == 0)
-          input_mat = np.random.randint(10, size=mask.shape)
-          solution = input_mat * mask + banded_mat[0]
-          self._assertOpOutputMatchesExpected(
-              {
-                  "input": input_mat,
-                  "diagonal": vecs[0],
-                  "k": diag_index
-              }, solution)
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for align in alignment_list:
+        for _, tests in all_tests(align):
+          for diag_index, (vecs, banded_mat) in tests.items():
+            mask = (banded_mat[0] == 0)
+            input_mat = np.random.randint(10, size=mask.shape)
+            solution = input_mat * mask + banded_mat[0]
+            self._assertOpOutputMatchesExpected(
+                {
+                    "input": input_mat,
+                    "diagonal": vecs[0],
+                    "k": diag_index,
+                    "align": align
+                }, solution)
 
   def testBatch(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for _, tests in [square_cases(), tall_cases(), fat_cases()]:
-        for diag_index, (vecs, banded_mat) in tests.items():
-          mask = (banded_mat == 0)
-          input_mat = np.random.randint(10, size=mask.shape)
-          solution = input_mat * mask + banded_mat
-          self._assertOpOutputMatchesExpected(
-              {
-                  "input": input_mat,
-                  "diagonal": vecs,
-                  "k": diag_index
-              }, solution)
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for align in alignment_list:
+        for _, tests in all_tests(align):
+          for diag_index, (vecs, banded_mat) in tests.items():
+            mask = (banded_mat == 0)
+            input_mat = np.random.randint(10, size=mask.shape)
+            solution = input_mat * mask + banded_mat
+            self._assertOpOutputMatchesExpected(
+                {
+                    "input": input_mat,
+                    "diagonal": vecs,
+                    "k": diag_index,
+                    "align": align
+                }, solution)
 
 
 class MatrixDiagPartTest(xla_test.XLATestCase):
@@ -613,33 +705,35 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
 
   # From here onwards are v2-only tests.
   def testSingleMatrix(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
-        for diag_index, (solution, _) in tests.items():
-          self._assertOpOutputMatchesExpected({
-              "input": mat[0],
-              "k": diag_index
-          }, solution[0])
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for align in alignment_list:
+        test_list = [square_cases(align), tall_cases(align), fat_cases(align)]
+        for mat, tests in test_list:
+          for diag_index, (solution, _) in tests.items():
+            self._assertOpOutputMatchesExpected(
+                {
+                    "input": mat[0],
+                    "k": diag_index,
+                    "align": align
+                }, solution[0])
 
   def testBatch(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
-        for diag_index, (solution, _) in tests.items():
-          self._assertOpOutputMatchesExpected({
-              "input": mat,
-              "k": diag_index
-          }, solution)
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for align in alignment_list:
+        for mat, tests in all_tests(align):
+          for diag_index, (solution, _) in tests.items():
+            self._assertOpOutputMatchesExpected(
+                {
+                    "input": mat,
+                    "k": diag_index,
+                    "align": align
+                }, solution)
 
   def testPadding(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-      for padding_value in [555, -11]:
-        for mat, tests in [square_cases(), tall_cases(), fat_cases()]:
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      for padding_value, align in zip_to_first_list_length([555, -11],
+                                                           alignment_list):
+        for mat, tests in all_tests(align):
           for diag_index, (solution, _) in tests.items():
             mask = (solution == 0)
             solution = solution + (mask * padding_value)
@@ -647,7 +741,8 @@ class MatrixDiagPartTest(xla_test.XLATestCase):
                 {
                     "input": mat,
                     "k": diag_index,
-                    "padding_value": padding_value
+                    "padding_value": padding_value,
+                    "align": align
                 }, solution)
 
 
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc
index c6babf0d5f7..7cf9da0c057 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc
@@ -26,6 +26,30 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
+// Computes the length of the diagonal at the given diag_index.
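+// For example, with num_rows = 3 and num_cols = 4, the diagonal at index 1
+// has length 3 and the diagonal at index -1 has length 2.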
+static inline int ComputeDiagLen(int diag_index, int num_rows, int num_cols) {
+  return std::min(num_rows + std::min(0, diag_index),
+                  num_cols - std::max(0, diag_index));
+}
+
+// Checks if a diagonal is to be aligned left or right.
+static inline bool IsLeftAligned(int diag_index, bool left_align_superdiagonal,
+                                 bool left_align_subdiagonal) {
+  return (diag_index >= 0 && left_align_superdiagonal) ||
+         (diag_index <= 0 && left_align_subdiagonal);
+}
+
+// Reads the diagonal packing alignment.
+void ReadAlignment(OpKernelConstruction* context,
+                   bool* left_align_superdiagonal,
+                   bool* left_align_subdiagonal) {
+  string align;
+  OP_REQUIRES_OK(context, context->GetAttr("align", &align));
+
+  *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
+  *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
+}
+
 // Reads or infers lower_diag_index and upper_diag_index from kernel's input
 // parameter "k". Also validates their values.
 std::pair<int64, int64> ProcessDiagIndex(XlaOpKernelContext* context) {
@@ -94,7 +118,9 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
                          const TensorShape& input_shape, const int64 diag_rank,
                          const int64 num_diags, const int64 lower_diag_index,
                          const int64 upper_diag_index, const int64 max_diag_len,
-                         const int64 num_rows, const int64 num_cols) {
+                         const int64 num_rows, const int64 num_cols,
+                         const bool left_align_superdiagonal,
+                         const bool left_align_subdiagonal) {
   // Creates a padding config.
   const int input_rank = input_shape.dims();
   xla::PaddingConfig padding_config;
@@ -108,29 +134,26 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
   // XLA can fuse multiple Broadcasts and Selects so this shouldn't be slow.
   //
   // For example,
-  //   diag = [[2, 3, 0], k = (-1, 1), and num_rows = 4.
+  //   diag = [[0, 2, 3], k = (0, 2), num_cols = 4, and align="RIGHT_LEFT".
   //           [4, 5, 6],
   //           [7, 8, 9]]
-  // The expected output is [[4, 2, 0],
-  //                         [7, 5, 4],
-  //                         [0, 8, 6],
-  //                         [0, 0, 9]]
+  // The expected output is [[7, 4, 2, 0],
+  //                         [0, 8, 5, 3],
+  //                         [0, 0, 9, 6]].
   // The 1st diagonal is created by:
-  // 1) Extracting diag_slice = [1, 2, 0].
-  // 2) Padding the vector to be as long as num_rows,
-  //      diag_slice = [1, 2, 0, 0],
-  //    then broadcasting diag_slice row-wise to a full matrix,
-  //      diag_broadcast = [[1, 1, 1],
-  //                        [2, 2, 2],
-  //                        [0, 0, 0],
-  //                        [0, 0, 0]]
+  // 1) Extracting diag_slice = [0, 2, 3] which is right-aligned.
+  // 2) Padding the vector (in the same direction) to be as long as num_cols,
+  //      diag_slice = [0, 0, 2, 3],
+  //    then broadcasting diag_slice column-wise to a full matrix,
+  //      diag_broadcast = [[0, 0, 2, 3],
+  //                        [0, 0, 2, 3],
+  //                        [0, 0, 2, 3]].
   //    The padding value can be anything because it will not appear in the
   //    results after masking. Here, we use zero.
   // 3) Masking diag_broadcast with a mask of the shape of the 1st diagonal.
-  //      mask = [[0, 1, 0],  -->  output = [[x, 2, x],
-  //              [0, 0, 1],                 [x, x, 3],
-  //              [0, 0, 0],                 [x, x, x],
-  //              [0, 0, 0]]                 [x, x, x]],
+  //      mask = [[0, 0, 1, 0],  -->  output = [[x, x, 2, x],
+  //              [0, 0, 0, 1],                 [x, x, x, 3],
+  //              [0, 0, 0, 0]]                 [x, x, x, x]],
   //    where x denotes the existing input contents.
   std::vector<int64> broadcast_dimensions(input_rank - 1);
   absl::c_iota(broadcast_dimensions, 0);
@@ -140,6 +163,8 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
     // Extracts a single diagonal.
     auto diag_slice = diag;
     if (num_diags > 1) {
+      // The result of SliceInDim has dims: [<batch_dim>, 1, max_diag_len].
+      // We call Collapse to make the dims: [<batch_dim>, max_diag_len].
       const int64 mapped_diag_index = upper_diag_index - diag_index;
       diag_slice = xla::Collapse(
           xla::SliceInDim(diag, mapped_diag_index, mapped_diag_index + 1, 1,
@@ -147,20 +172,51 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
           {diag_rank - 2, diag_rank - 1});
     }
 
-    // Pads if necessary. Always pad at the end because shorter diagonals in
-    // the input come padded at the end.
-    const int64 padding_length =
-        ((diag_index <= 0) ? num_cols : num_rows) - max_diag_len;
-    const xla::XlaOp zero = xla::ScalarLike(input, 0);
-    if (padding_length > 0) {
+    // Pad if necessary.
+    // - If the diagonal has the longest length, i.e., min(num_rows, num_cols),
+    //   no padding is necessary. It is broadcast column-wise if it is a sub-
+    //   diagonal, row-wise if superdiagonal.
+    // - Otherwise, pad and keep the old alignment (shorter diagonals in the
+    //   input come pre-padded). max_len in the table refers to max_diag_len.
+    //   -------------------------------------------------------------------
+    //   | Diag  | Align | Broadcast   |   padding_low   |   padding_high  |
+    //   -------------------------------------------------------------------
+    //   | Super | Left  | Row-wise    |        0        | #rows - max_len |
+    //   |       | Right | Column-wise | #cols - max_len |        0        |
+    //   -------------------------------------------------------------------
+    //   | Sub   | Left  | Column-wise |        0        | #cols - max_len |
+    //   |       | Right | Row-wise    | #rows - max_len |        0        |
+    //   -------------------------------------------------------------------
+    if (num_cols - num_rows <= diag_index && diag_index <= 0) {
+      broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
+    } else if (0 <= diag_index && diag_index <= num_cols - num_rows) {
+      broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
+    } else {
+      int length_to_pad_to;
+      if ((diag_index > 0 && left_align_superdiagonal) ||
+          (diag_index < 0 && !left_align_subdiagonal)) {
+        length_to_pad_to = num_rows;
+        broadcast_dimensions.back() = input_rank - 2;  // Row-wise.
+      } else {
+        length_to_pad_to = num_cols;
+        broadcast_dimensions.back() = input_rank - 1;  // Column-wise.
+      }
+      int padding_low = length_to_pad_to - max_diag_len;
+      int padding_high = 0;
+      if (IsLeftAligned(diag_index, left_align_superdiagonal,
+                        left_align_subdiagonal)) {
+        std::swap(padding_low, padding_high);
+      }
       padding_config.mutable_dimensions(input_rank - 2)
-          ->set_edge_padding_high(padding_length);
+          ->set_edge_padding_low(padding_low);
+      padding_config.mutable_dimensions(input_rank - 2)
+          ->set_edge_padding_high(padding_high);
+
+      const xla::XlaOp zero = xla::ScalarLike(input, 0);
       diag_slice = xla::Pad(diag_slice, zero, padding_config);
     }
 
-    // Broadcasts column-wise for subdiagonals; row-wise for superdiagonals.
-    broadcast_dimensions.back() =
-        (diag_index <= 0) ? input_rank - 1 : input_rank - 2;
+    // Broadcast and mask.
     xla::XlaOp diag_broadcast = xla::BroadcastInDim(
         diag_slice, input_shape.dim_sizes(), broadcast_dimensions);
     const auto mask = xla::GetDiagonalMask(output, diag_index);
@@ -173,11 +229,17 @@ xla::XlaOp SetMatrixDiag(const xla::XlaOp input, const xla::XlaOp diag,
 
 class MatrixDiagOp : public XlaOpKernel {
  public:
-  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+  explicit MatrixDiagOp(OpKernelConstruction* context) : XlaOpKernel(context) {
+    // MatrixDiagV3-specific.
+    if (context->HasAttr("align")) {
+      ReadAlignment(context, &left_align_superdiagonal_,
+                    &left_align_subdiagonal_);
+    }
+  }
 
   void Compile(XlaOpKernelContext* context) override {
     OP_REQUIRES(
-        context, context->num_inputs() >= 1,
+        context, context->num_inputs() >= kNumV1Inputs,
         errors::InvalidArgument("MatrixDiag op must have at least one input"));
     const TensorShape diag_shape = context->InputShape(0);
     OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(diag_shape),
@@ -199,7 +261,7 @@ class MatrixDiagOp : public XlaOpKernel {
     // MatrixDiag and MatrixDiagV2 both use this OpKernel. MatrixDiag only has
     // one input, so we have to check the number of inputs before reading
     // additional parameters for MatrixDiagV2.
-    if (context->num_inputs() > 1) {
+    if (context->num_inputs() > kNumV1Inputs) {
       std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
       OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &num_rows));
       OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(3, &num_cols));
@@ -252,8 +314,14 @@ class MatrixDiagOp : public XlaOpKernel {
     context->SetOutput(
         0, SetMatrixDiag(output, diag, output_shape, diag_rank, num_diags,
                          lower_diag_index, upper_diag_index, max_diag_len,
-                         num_rows, num_cols));
+                         num_rows, num_cols, left_align_superdiagonal_,
+                         left_align_subdiagonal_));
   }
+
+ private:
+  bool left_align_superdiagonal_ = true;
+  bool left_align_subdiagonal_ = true;
+  static constexpr int kNumV1Inputs = 1;
 };
 
 REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp);
@@ -263,12 +331,24 @@ REGISTER_XLA_OP(Name("MatrixDiagV2")
                     .CompileTimeConstantInput("num_cols")
                     .CompileTimeConstantInput("padding_value"),
                 MatrixDiagOp);
+REGISTER_XLA_OP(Name("MatrixDiagV3")
+                    .CompileTimeConstantInput("k")
+                    .CompileTimeConstantInput("num_rows")
+                    .CompileTimeConstantInput("num_cols")
+                    .CompileTimeConstantInput("padding_value"),
+                MatrixDiagOp);
 
 class MatrixDiagPartOp : public XlaOpKernel {
  public:
   explicit MatrixDiagPartOp(OpKernelConstruction* context)
       : XlaOpKernel(context),
-        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {}
+        is_gpu_(context->device_type().type_string() == DEVICE_GPU_XLA_JIT) {
+    // MatrixDiagPartV3-specific.
+    if (context->HasAttr("align")) {
+      ReadAlignment(context, &left_align_superdiagonal_,
+                    &left_align_subdiagonal_);
+    }
+  }
 
   void Compile(XlaOpKernelContext* context) override {
     const TensorShape input_shape = context->InputShape(0);
@@ -290,7 +370,7 @@ class MatrixDiagPartOp : public XlaOpKernel {
     // MatrixDiagPart and MatrixDiagPartV2 both use this OpKernel.
     // MatrixDiagPart only has one input, so we have to check the number of
     // inputs before reading additional parameters in MatrixDiagV2.
-    if (context->num_inputs() > 1) {
+    if (context->num_inputs() > kNumV1Inputs) {
       std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
       padding_value = context->Input(2);
     }
@@ -314,25 +394,34 @@ class MatrixDiagPartOp : public XlaOpKernel {
     // Computes output.
     xla::XlaOp input = context->Input(0);
     std::vector<xla::XlaOp> diag_list;
-    xla::PaddingConfig padding_config;
+    xla::PaddingConfig padding_config =
+        xla::MakeNoPaddingConfig(input_rank - 1);
     if (num_diags == 1) {
       context->SetOutput(
           0, is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, upper_diag_index)
                      : xla::GetMatrixDiagonal(input, upper_diag_index));
       return;
     }
-    padding_config = xla::MakeNoPaddingConfig(input_rank - 1);
     for (int diag_index = upper_diag_index; diag_index >= lower_diag_index;
          --diag_index) {
       xla::XlaOp single_diag =
           is_gpu_ ? xla::GetMatrixDiagonalViaGather(input, diag_index)
                   : xla::GetMatrixDiagonal(input, diag_index);
-      const int64 diag_length =
-          (diag_index >= 0) ? (num_cols - diag_index) : (num_rows + diag_index);
-      const int64 padding_length = max_diag_len - diag_length;
-      if (padding_length > 0) {
-        padding_config.mutable_dimensions(input_rank - 2)
-            ->set_edge_padding_high(padding_length);
+      const int64 diag_len = ComputeDiagLen(diag_index, num_rows, num_cols);
+      const int64 padding_len = max_diag_len - diag_len;
+      if (padding_len > 0) {
+        if (IsLeftAligned(diag_index, left_align_superdiagonal_,
+                          left_align_subdiagonal_)) {
+          padding_config.mutable_dimensions(input_rank - 2)
+              ->set_edge_padding_low(0);
+          padding_config.mutable_dimensions(input_rank - 2)
+              ->set_edge_padding_high(padding_len);
+        } else {
+          padding_config.mutable_dimensions(input_rank - 2)
+              ->set_edge_padding_low(padding_len);
+          padding_config.mutable_dimensions(input_rank - 2)
+              ->set_edge_padding_high(0);
+        }
         single_diag = xla::Pad(single_diag, padding_value, padding_config);
       }
       diag_list.emplace_back(single_diag);
@@ -344,6 +433,9 @@ class MatrixDiagPartOp : public XlaOpKernel {
 
  private:
   const bool is_gpu_;
+  bool left_align_superdiagonal_ = true;
+  bool left_align_subdiagonal_ = true;
+  static constexpr int kNumV1Inputs = 1;
 };
 
 REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp);
@@ -351,11 +443,21 @@ REGISTER_XLA_OP(Name("MatrixDiagPartV2")
                     .CompileTimeConstantInput("k")
                     .CompileTimeConstantInput("padding_value"),
                 MatrixDiagPartOp);
+REGISTER_XLA_OP(Name("MatrixDiagPartV3")
+                    .CompileTimeConstantInput("k")
+                    .CompileTimeConstantInput("padding_value"),
+                MatrixDiagPartOp);
 
 class MatrixSetDiagOp : public XlaOpKernel {
  public:
   explicit MatrixSetDiagOp(OpKernelConstruction* context)
-      : XlaOpKernel(context) {}
+      : XlaOpKernel(context) {
+    // MatrixSetDiagV3-specific.
+    if (context->HasAttr("align")) {
+      ReadAlignment(context, &left_align_superdiagonal_,
+                    &left_align_subdiagonal_);
+    }
+  }
 
   void Compile(XlaOpKernelContext* context) override {
     const TensorShape input_shape = context->InputShape(0);
@@ -378,7 +480,7 @@ class MatrixSetDiagOp : public XlaOpKernel {
     // reading additional parameters in MatrixSetDiagV2.
     int64 lower_diag_index = 0;
     int64 upper_diag_index = 0;
-    if (context->num_inputs() > 2) {
+    if (context->num_inputs() > kNumV1Inputs) {
       std::tie(lower_diag_index, upper_diag_index) = ProcessDiagIndex(context);
     }
 
@@ -419,15 +521,21 @@ class MatrixSetDiagOp : public XlaOpKernel {
     context->SetOutput(
         0, SetMatrixDiag(input, diag, input_shape, diag_rank, num_diags,
                          lower_diag_index, upper_diag_index, max_diag_len,
-                         num_rows, num_cols));
+                         num_rows, num_cols, left_align_superdiagonal_,
+                         left_align_subdiagonal_));
   }
 
  private:
+  bool left_align_superdiagonal_ = true;
+  bool left_align_subdiagonal_ = true;
+  static constexpr int kNumV1Inputs = 2;
   TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
 };
 
 REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp);
 REGISTER_XLA_OP(Name("MatrixSetDiagV2").CompileTimeConstantInput("k"),
                 MatrixSetDiagOp);
+REGISTER_XLA_OP(Name("MatrixSetDiagV3").CompileTimeConstantInput("k"),
+                MatrixSetDiagOp);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixDiagPartV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixDiagPartV3.pbtxt
new file mode 100644
index 00000000000..a9fe8802e66
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixDiagPartV3.pbtxt
@@ -0,0 +1,141 @@
+op {
+  graph_op_name: "MatrixDiagPartV3"
+  in_arg {
+    name: "input"
+    description: "Rank `r` tensor where `r >= 2`."
+  }
+  in_arg {
+    name: "k"
+    description: <<END
+Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
+diagonal, and negative value means subdiagonals. `k` can be a single integer
+(for a single diagonal) or a pair of integers specifying the low and high ends
+of a matrix band. `k[0]` must not be larger than `k[1]`.
+END
+  }
+  in_arg {
+    name: "padding_value"
+    description: <<END
+The value to fill the area outside the specified diagonal band with.
+Default is 0.
+END
+  }
+  out_arg {
+    name: "diagonal"
+    description: "The extracted diagonal(s)."
+  }
+  attr {
+    name: "align"
+    description: <<END
+Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is
+a string specifying how superdiagonals and subdiagonals should be aligned,
+respectively. There are four possible alignments: "RIGHT_LEFT" (default),
+"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals
+to the right (left-pads the row) and subdiagonals to the left (right-pads the
+row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is
+the opposite alignment.
+END
+  }
+  summary: "Returns the batched diagonal part of a batched tensor."
+  description: <<END
+Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
+`input`.
+
+Assume `input` has `r` dimensions `[I, J, ..., L, M, N]`.
+Let `max_diag_len` be the maximum length among all diagonals to be extracted,
+`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
+Let `num_diags` be the number of diagonals to extract,
+`num_diags = k[1] - k[0] + 1`.
+
+If `num_diags == 1`, the output tensor is of rank `r - 1` with shape
+`[I, J, ..., L, max_diag_len]` and values:
+
+```
+diagonal[i, j, ..., l, n]
+  = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
+    padding_value                 ; otherwise.
+```
+where `y = max(-k[1], 0)`, `x = max(k[1], 0)`.
+
+Otherwise, the output tensor has rank `r` with dimensions
+`[I, J, ..., L, num_diags, max_diag_len]` with values:
+
+```
+diagonal[i, j, ..., l, m, n]
+  = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
+    padding_value                 ; otherwise.
+```
+where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
+
+`offset` is zero except when the alignment of the diagonal is to the right.
+```
+offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
+                                           and `d >= 0`) or
+                                         (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
+                                           and `d <= 0`)
+         0                          ; otherwise
+```
+where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
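+
+For example, with the default "RIGHT_LEFT" alignment, a 3x4 `input`, and
+`k = (-1, 2)` (as in the band example below): `max_diag_len = 3` and
+`diag_len(2) = 2`, so `offset = 1` for the `d = 2` diagonal and `0` for the
+other three diagonals.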
+
+The input must be at least a matrix.
+
+For example:
+
+```
+input = np.array([[[1, 2, 3, 4],  # Input shape: (2, 3, 4)
+                   [5, 6, 7, 8],
+                   [9, 8, 7, 6]],
+                  [[5, 4, 3, 2],
+                   [1, 2, 3, 4],
+                   [5, 6, 7, 8]]])
+
+# A main diagonal from each batch.
+tf.matrix_diag_part(input) ==> [[1, 6, 7],  # Output shape: (2, 3)
+                                [5, 2, 7]]
+
+# A superdiagonal from each batch.
+tf.matrix_diag_part(input, k = 1)
+  ==> [[2, 7, 6],  # Output shape: (2, 3)
+       [4, 3, 8]]
+
+# A band from each batch.
+tf.matrix_diag_part(input, k = (-1, 2))
+  ==> [[[0, 3, 8],  # Output shape: (2, 4, 3)
+        [2, 7, 6],
+        [1, 6, 7],
+        [5, 8, 0]],
+       [[0, 3, 4],
+        [4, 3, 8],
+        [5, 2, 7],
+        [1, 6, 0]]]
+
+# LEFT_RIGHT alignment.
+tf.matrix_diag_part(input, k = (-1, 2), align="LEFT_RIGHT")
+  ==> [[[3, 8, 0],  # Output shape: (2, 4, 3)
+        [2, 7, 6],
+        [1, 6, 7],
+        [0, 5, 8]],
+       [[3, 4, 0],
+        [4, 3, 8],
+        [5, 2, 7],
+        [0, 1, 6]]]
+
+# max_diag_len can be shorter than the main diagonal.
+tf.matrix_diag_part(input, k = (-2, -1))
+  ==> [[[5, 8],
+        [9, 0]],
+       [[1, 6],
+        [5, 0]]]
+
+# padding_value = 9
+tf.matrix_diag_part(input, k = (1, 3), padding_value = 9)
+  ==> [[[9, 9, 4],  # Output shape: (2, 3, 3)
+        [9, 3, 8],
+        [2, 7, 6]],
+       [[9, 9, 2],
+        [9, 3, 4],
+        [4, 3, 8]]]
+
+```
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixDiagV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixDiagV3.pbtxt
new file mode 100644
index 00000000000..d96e48fdbea
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixDiagV3.pbtxt
@@ -0,0 +1,177 @@
+op {
+  graph_op_name: "MatrixDiagV3"
+  in_arg {
+    name: "diagonal"
+    description: "Rank `r`, where `r >= 1`"
+  }
+  in_arg {
+    name: "k"
+    description: <<END
+Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
+diagonal, and negative value means subdiagonals. `k` can be a single integer
+(for a single diagonal) or a pair of integers specifying the low and high ends
+of a matrix band. `k[0]` must not be larger than `k[1]`.
+END
+  }
+  in_arg {
+    name: "num_rows"
+    description: <<END
+The number of rows of the output matrix. If it is not provided, the op assumes
+the output matrix is a square matrix and infers the matrix size from k and the
+innermost dimension of `diagonal`.
+END
+  }
+  in_arg {
+    name: "num_cols"
+    description: <<END
+The number of columns of the output matrix. If it is not provided, the op
+assumes the output matrix is a square matrix and infers the matrix size from
+k and the innermost dimension of `diagonal`.
+END
+  }
+  in_arg {
+    name: "padding_value"
+    description: <<END
+The number to fill the area outside the specified diagonal band with.
+Default is 0.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise.
+END
+  }
+  attr {
+    name: "align"
+    description: <<END
+Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is
+a string specifying how superdiagonals and subdiagonals should be aligned,
+respectively. There are four possible alignments: "RIGHT_LEFT" (default),
+"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals
+to the right (left-pads the row) and subdiagonals to the left (right-pads the
+row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is
+the opposite alignment.
+END
+  }
+  summary:
+    "Returns a batched diagonal tensor with given batched diagonal values."
+  description: <<END
+Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
+diagonals of a matrix, with everything else padded with `padding_value`. `num_rows`
+and `num_cols` specify the dimension of the innermost matrix of the output. If
+both are not specified, the op assumes the innermost matrix is square and infers
+its size from `k` and the innermost dimension of `diagonal`. If only one of them
+is specified, the op assumes the unspecified value is the smallest possible
+based on other criteria.
+
+Let `diagonal` have `r` dimensions `[I, J, ..., L, M, N]`. The output tensor has
+rank `r+1` with shape `[I, J, ..., L, M, num_rows, num_cols]` when only one
+diagonal is given (`k` is an integer or `k[0] == k[1]`). Otherwise, it has rank
+`r` with shape `[I, J, ..., L, num_rows, num_cols]`.
+
+The second innermost dimension of `diagonal` has double meaning.
+When `k` is scalar or `k[0] == k[1]`, `M` is part of the batch size
+[I, J, ..., M], and the output tensor is:
+
+```
+output[i, j, ..., l, m, n]
+  = diagonal[i, j, ..., l, n-max(d_upper, 0)] ; if n - m == d_upper
+    padding_value                             ; otherwise
+```
+
+Otherwise, `M` is treated as the number of diagonals for the matrix in the
+same batch (`M = k[1]-k[0]+1`), and the output tensor is:
+
+```
+output[i, j, ..., l, m, n]
+  = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
+    padding_value                                     ; otherwise
+```
+where `d = n - m`, `diag_index = k[1] - d`, and
+`index_in_diag = n - max(d, 0) + offset`.
+
+`offset` is zero except when the alignment of the diagonal is to the right.
+```
+offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
+                                           and `d >= 0`) or
+                                         (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
+                                           and `d <= 0`)
+         0                          ; otherwise
+```
+where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
+
+For example:
+
+```
+# The main diagonal.
+diagonal = np.array([[1, 2, 3, 4],            # Input shape: (2, 4)
+                     [5, 6, 7, 8]])
+tf.matrix_diag(diagonal) ==> [[[1, 0, 0, 0],  # Output shape: (2, 4, 4)
+                               [0, 2, 0, 0],
+                               [0, 0, 3, 0],
+                               [0, 0, 0, 4]],
+                              [[5, 0, 0, 0],
+                               [0, 6, 0, 0],
+                               [0, 0, 7, 0],
+                               [0, 0, 0, 8]]]
+
+# A superdiagonal (per batch).
+diagonal = np.array([[1, 2, 3],  # Input shape: (2, 3)
+                     [4, 5, 6]])
+tf.matrix_diag(diagonal, k = 1)
+  ==> [[[0, 1, 0, 0],  # Output shape: (2, 4, 4)
+        [0, 0, 2, 0],
+        [0, 0, 0, 3],
+        [0, 0, 0, 0]],
+       [[0, 4, 0, 0],
+        [0, 0, 5, 0],
+        [0, 0, 0, 6],
+        [0, 0, 0, 0]]]
+
+# A tridiagonal band (per batch).
+diagonals = np.array([[[0, 8, 9],  # Input shape: (2, 3, 3)
+                       [1, 2, 3],
+                       [4, 5, 0]],
+                      [[0, 2, 3],
+                       [6, 7, 9],
+                       [9, 1, 0]]])
+tf.matrix_diag(diagonals, k = (-1, 1))
+  ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
+        [4, 2, 9],
+        [0, 5, 3]],
+       [[6, 2, 0],
+        [9, 7, 3],
+        [0, 1, 9]]]
+
+# LEFT_RIGHT alignment.
+diagonals = np.array([[[8, 9, 0],  # Input shape: (2, 3, 3)
+                       [1, 2, 3],
+                       [0, 4, 5]],
+                      [[2, 3, 0],
+                       [6, 7, 9],
+                       [0, 9, 1]]])
+tf.matrix_diag(diagonals, k = (-1, 1), align="LEFT_RIGHT")
+  ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
+        [4, 2, 9],
+        [0, 5, 3]],
+       [[6, 2, 0],
+        [9, 7, 3],
+        [0, 1, 9]]]
+
+# Rectangular matrix.
+diagonal = np.array([1, 2])  # Input shape: (2)
+tf.matrix_diag(diagonal, k = -1, num_rows = 3, num_cols = 4)
+  ==> [[0, 0, 0, 0],  # Output shape: (3, 4)
+       [1, 0, 0, 0],
+       [0, 2, 0, 0]]
+
+# Rectangular matrix with inferred num_cols and padding_value = 9.
+tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9)
+  ==> [[9, 9],  # Output shape: (3, 2)
+       [1, 9],
+       [9, 2]]
+
+```
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixSetDiagV3.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixSetDiagV3.pbtxt
new file mode 100644
index 00000000000..b3b5980ad44
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixSetDiagV3.pbtxt
@@ -0,0 +1,148 @@
+op {
+  graph_op_name: "MatrixSetDiagV3"
+  in_arg {
+    name: "input"
+    description: "Rank `r+1`, where `r >= 1`."
+  }
+  in_arg {
+    name: "diagonal"
+    description: <<END
+Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`.
+`r >= 1`.
+END
+  }
+  in_arg {
+    name: "k"
+    description: <<END
+Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main
+diagonal, and negative value means subdiagonals. `k` can be a single integer
+(for a single diagonal) or a pair of integers specifying the low and high ends
+of a matrix band. `k[0]` must not be larger than `k[1]`.
+END
+  }
+  out_arg {
+    name: "output"
+    description: <<END
+Rank `r+1`, with `output.shape = input.shape`.
+END
+  }
+  attr {
+    name: "align"
+    description: <<END
+Some diagonals are shorter than `max_diag_len` and need to be padded. `align` is
+a string specifying how superdiagonals and subdiagonals should be aligned,
+respectively. There are four possible alignments: "RIGHT_LEFT" (default),
+"LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT" aligns superdiagonals
+to the right (left-pads the row) and subdiagonals to the left (right-pads the
+row). It is the packing format LAPACK uses. cuSPARSE uses "LEFT_RIGHT", which is
+the opposite alignment.
+END
+  }
+  summary: "Returns a batched matrix tensor with new batched diagonal values."
+  description: <<END
+Given `input` and `diagonal`, this operation returns a tensor with the
+same shape and values as `input`, except for the specified diagonals of the
+innermost matrices. These will be overwritten by the values in `diagonal`.
+
+`input` has `r+1` dimensions `[I, J, ..., L, M, N]`. When `k` is scalar or
+`k[0] == k[1]`, `diagonal` has `r` dimensions `[I, J, ..., L, max_diag_len]`.
+Otherwise, it has `r+1` dimensions `[I, J, ..., L, num_diags, max_diag_len]`.
+`num_diags` is the number of diagonals, `num_diags = k[1] - k[0] + 1`.
+`max_diag_len` is the longest diagonal in the range `[k[0], k[1]]`,
+`max_diag_len = min(M + min(k[1], 0), N + min(-k[0], 0))`
+
+The output is a tensor of rank `r+1` with dimensions `[I, J, ..., L, M, N]`.
+If `k` is scalar or `k[0] == k[1]`:
+
+```
+output[i, j, ..., l, m, n]
+  = diagonal[i, j, ..., l, n-max(k[1], 0)] ; if n - m == k[1]
+    input[i, j, ..., l, m, n]              ; otherwise
+```
+
+Otherwise,
+
+```
+output[i, j, ..., l, m, n]
+  = diagonal[i, j, ..., l, diag_index, index_in_diag] ; if k[0] <= d <= k[1]
+    input[i, j, ..., l, m, n]                         ; otherwise
+```
+where `d = n - m`, `diag_index = k[1] - d`, and
+`index_in_diag = n - max(d, 0) + offset`.
+
+`offset` is zero except when the alignment of the diagonal is to the right.
+```
+offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
+                                           and `d >= 0`) or
+                                         (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
+                                           and `d <= 0`)
+         0                          ; otherwise
+```
+where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
+
+For example:
+
+```
+# The main diagonal.
+input = np.array([[[7, 7, 7, 7],              # Input shape: (2, 3, 4)
+                   [7, 7, 7, 7],
+                   [7, 7, 7, 7]],
+                  [[7, 7, 7, 7],
+                   [7, 7, 7, 7],
+                   [7, 7, 7, 7]]])
+diagonal = np.array([[1, 2, 3],               # Diagonal shape: (2, 3)
+                     [4, 5, 6]])
+tf.matrix_set_diag(input, diagonal)
+  ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
+        [7, 2, 7, 7],
+        [7, 7, 3, 7]],
+       [[4, 7, 7, 7],
+        [7, 5, 7, 7],
+        [7, 7, 6, 7]]]
+
+# A superdiagonal (per batch).
+tf.matrix_set_diag(input, diagonal, k = 1)
+  ==> [[[7, 1, 7, 7],  # Output shape: (2, 3, 4)
+        [7, 7, 2, 7],
+        [7, 7, 7, 3]],
+       [[7, 4, 7, 7],
+        [7, 7, 5, 7],
+        [7, 7, 7, 6]]]
+
+# A band of diagonals.
+diagonals = np.array([[[0, 9, 1],  # Diagonal shape: (2, 4, 3)
+                       [6, 5, 8],
+                       [1, 2, 3],
+                       [4, 5, 0]],
+                      [[0, 1, 2],
+                       [5, 6, 4],
+                       [6, 1, 2],
+                       [3, 4, 0]]])
+tf.matrix_set_diag(input, diagonals, k = (-1, 2))
+  ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
+        [4, 2, 5, 1],
+        [7, 5, 3, 8]],
+       [[6, 5, 1, 7],
+        [3, 1, 6, 2],
+        [7, 4, 2, 4]]]
+
+# LEFT_RIGHT alignment.
+diagonals = np.array([[[9, 1, 0],  # Diagonal shape: (2, 4, 3)
+                       [6, 5, 8],
+                       [1, 2, 3],
+                       [0, 4, 5]],
+                      [[1, 2, 0],
+                       [5, 6, 4],
+                       [6, 1, 2],
+                       [0, 3, 4]]])
+tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT")
+  ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
+        [4, 2, 5, 1],
+        [7, 5, 3, 8]],
+       [[6, 5, 1, 7],
+        [3, 1, 6, 2],
+        [7, 4, 2, 4]]]
+
+```
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt
index 753cfe5e235..180fa21ba6d 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV2.pbtxt
@@ -1,11 +1,4 @@
 op {
   graph_op_name: "MatrixDiagPartV2"
-  endpoint {
-    name: "linalg.diag_part"
-  }
-  endpoint {
-    name: "matrix_diag_part"
-    deprecation_version: 2
-  }
   visibility: HIDDEN
 }
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV3.pbtxt
new file mode 100644
index 00000000000..386502ceea2
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagPartV3.pbtxt
@@ -0,0 +1,11 @@
+op {
+  graph_op_name: "MatrixDiagPartV3"
+  endpoint {
+    name: "linalg.diag_part"
+  }
+  endpoint {
+    name: "matrix_diag_part"
+    deprecation_version: 2
+  }
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt
index 7f76ddba0c4..b6bbdf54a81 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV2.pbtxt
@@ -1,11 +1,4 @@
 op {
   graph_op_name: "MatrixDiagV2"
-  endpoint {
-    name: "linalg.diag"
-  }
-  endpoint {
-    name: "matrix_diag"
-    deprecation_version: 2
-  }
   visibility: HIDDEN
 }
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixDiagV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV3.pbtxt
new file mode 100644
index 00000000000..e1ed2061f06
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixDiagV3.pbtxt
@@ -0,0 +1,11 @@
+op {
+  graph_op_name: "MatrixDiagV3"
+  endpoint {
+    name: "linalg.diag"
+  }
+  endpoint {
+    name: "matrix_diag"
+    deprecation_version: 2
+  }
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt
index eed43da2bb9..ecf7353c9d8 100644
--- a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV2.pbtxt
@@ -1,11 +1,4 @@
 op {
   graph_op_name: "MatrixSetDiagV2"
-  endpoint {
-    name: "linalg.set_diag"
-  }
-  endpoint {
-    name: "matrix_set_diag"
-    deprecation_version: 2
-  }
   visibility: HIDDEN
 }
diff --git a/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV3.pbtxt b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV3.pbtxt
new file mode 100644
index 00000000000..a92ff977c58
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_MatrixSetDiagV3.pbtxt
@@ -0,0 +1,11 @@
+op {
+  graph_op_name: "MatrixSetDiagV3"
+  endpoint {
+    name: "linalg.set_diag"
+  }
+  endpoint {
+    name: "matrix_set_diag"
+    deprecation_version: 2
+  }
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index ea8678de5cf..e78f3970b89 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1189,6 +1189,254 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) {
   return Status::OK();
 }
 
+Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
+                     int32* lower_diag_index, int32* upper_diag_index) {
+  // This function assumes that the shape of diag_index_tensor is fully defined.
+  if (diag_index_tensor->dims() == 0) {
+    *lower_diag_index = diag_index_tensor->scalar<int32>()();
+    *upper_diag_index = *lower_diag_index;
+  } else {
+    int32 num_elements = diag_index_tensor->dim_size(0);
+    if (num_elements == 1) {
+      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
+      *upper_diag_index = *lower_diag_index;
+    } else if (num_elements == 2) {
+      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
+      *upper_diag_index = diag_index_tensor->vec<int32>()(1);
+    } else {
+      return errors::InvalidArgument(
+          "diag_index must be a vector with one or two elements. It has ",
+          num_elements, " elements.");
+    }
+  }
+  return Status::OK();
+}
+
+Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) {
+  ShapeHandle input_shape, diag_index_shape, unused_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
+  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
+
+  const Tensor* diag_index_tensor = c->input_tensor(1);
+  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
+      diag_index_tensor == nullptr) {
+    c->set_output(0, c->UnknownShape());
+    return Status::OK();
+  }
+  int32 lower_diag_index = 0;
+  int32 upper_diag_index = 0;
+  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
+                                   &upper_diag_index));
+  if (lower_diag_index > upper_diag_index) {
+    return errors::InvalidArgument(
+        "lower_diag_index is greater than upper_diag_index");
+  }
+
+  // Validates lower_diag_index and upper_diag_index.
+  const int32 input_rank = c->Rank(input_shape);
+  const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
+  const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
+  if (num_rows != InferenceContext::kUnknownDim &&
+      num_cols != InferenceContext::kUnknownDim) {
+    if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
+        (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
+      return errors::InvalidArgument("lower_diag_index is out of bound.");
+    }
+    if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
+        (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
+      return errors::InvalidArgument("upper_diag_index is out of bound.");
+    }
+  }
+
+  const int32 max_diag_len = std::min(num_rows + std::min(upper_diag_index, 0),
+                                      num_cols - std::max(lower_diag_index, 0));
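+  // For example, a 3x4 input with k = (-1, 2) gives max_diag_len = 3 and an
+  // output of shape [..., 4, 3].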
+  std::vector<DimensionHandle> dims;
+  dims.reserve(input_rank - 2);
+  for (int i = 0; i < input_rank - 2; ++i) {
+    dims.push_back(c->Dim(input_shape, i));
+  }
+  if (lower_diag_index < upper_diag_index) {
+    dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
+  }
+  dims.push_back(c->MakeDim(max_diag_len));
+  c->set_output(0, c->MakeShape(dims));
+  return Status::OK();
+}
+
+Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) {
+  // Checks input ranks.
+  ShapeHandle input_shape, diag_index_shape, unused_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
+  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
+
+  // Reads the diagonal indices.
+  const Tensor* diag_index_tensor = c->input_tensor(1);
+  if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
+      diag_index_tensor == nullptr) {
+    c->set_output(0, c->UnknownShape());
+    return Status::OK();
+  }
+  int32 lower_diag_index = 0;
+  int32 upper_diag_index = 0;
+  TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
+                                   &upper_diag_index));
+  if (lower_diag_index > upper_diag_index) {
+    return errors::InvalidArgument(
+        "lower_diag_index is greater than upper_diag_index");
+  }
+
+  // Checks that the number of diagonals provided matches the number implied
+  // by lower_diag_index and upper_diag_index.
+  const int32 input_rank = c->Rank(input_shape);
+  if (lower_diag_index < upper_diag_index) {
+    const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
+    const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
+
+    if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
+      return errors::InvalidArgument(
+          "The number of rows of `diagonal` doesn't match the number of "
+          "diagonals implied from `d_lower` and `d_upper`.\n",
+          "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
+          ", d_upper = ", upper_diag_index, " ", input_rank, " ", other_dim);
+    }
+  }
+
+  // Reads num_rows and num_cols.
+  const Tensor* num_rows_tensor = c->input_tensor(2);
+  const Tensor* num_cols_tensor = c->input_tensor(3);
+  int64 num_rows = -1;
+  int64 num_cols = -1;
+  if (num_rows_tensor != nullptr) {
+    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
+  }
+  if (num_cols_tensor != nullptr) {
+    TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
+  }
+
+  // Infers the missing num_rows or num_cols: If both are missing, assume
+  // output is square. Otherwise, use the smallest possible value. Also
+  // validates the provided values.
+  const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
+  const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
+  const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
+  if (num_rows == -1 && num_cols == -1) {  // Special case.
+    num_rows = std::max(min_num_rows, min_num_cols);
+    num_cols = num_rows;
+  }
+  if (num_rows == -1) {
+    num_rows = min_num_rows;
+  } else if (num_rows < min_num_rows) {
+    return errors::InvalidArgument("num_rows is too small");
+  }
+  if (num_cols == -1) {
+    num_cols = min_num_cols;
+  } else if (num_cols < min_num_cols) {
+    return errors::InvalidArgument("num_cols is too small.");
+  }
+  // At least one of them must match the minimum length.
+  if (num_rows != min_num_rows && num_cols != min_num_cols) {
+    return errors::InvalidArgument(
+        "num_rows and num_cols are not consistent with lower_diag_index, "
+        "upper_diag_index, and the length of the given diagonals.\n",
+        "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
+        ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
+  }
+
+  // Sets output shape.
+  ShapeHandle output_shape;
+  const DimensionHandle output_row_dim = c->MakeDim(num_rows);
+  const DimensionHandle output_col_dim = c->MakeDim(num_cols);
+  if (lower_diag_index == upper_diag_index) {
+    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
+                                     output_row_dim, &output_shape));
+    TF_RETURN_IF_ERROR(
+        c->Concatenate(output_shape, c->Vector(output_col_dim), &output_shape));
+  } else {
+    TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
+                                     output_row_dim, &output_shape));
+    TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
+                                     output_col_dim, &output_shape));
+  }
+  c->set_output(0, output_shape);
+  return Status::OK();
+}
+
+Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) {
+  ShapeHandle input_shape, diag_shape, diag_index_shape;
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
+  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
+  TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
+
+  int32 lower_diag_index = 0;
+  int32 upper_diag_index = 0;
+  bool diag_index_known = false;
+  const Tensor* diag_index_tensor = c->input_tensor(2);
+  if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
+    diag_index_known = true;
+    TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
+                                     &upper_diag_index));
+    if (lower_diag_index > upper_diag_index) {
+      return errors::InvalidArgument(
+          "lower_diag_index is greater than upper_diag_index");
+    }
+  }
+
+  // Do more checks when input rank is known.
+  if (c->RankKnown(input_shape)) {
+    int32 input_rank = c->Rank(input_shape);
+
+    // If diag_index is set, we know the exact rank of diagonal.
+    if (diag_index_known) {
+      TF_RETURN_IF_ERROR(c->WithRank(
+          c->input(1),
+          (lower_diag_index == upper_diag_index) ? input_rank - 1 : input_rank,
+          &diag_shape));
+    } else {
+      TF_RETURN_IF_ERROR(
+          c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
+      TF_RETURN_IF_ERROR(
+          c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
+    }
+
+    // Validates lower_diag_index and upper_diag_index.
+    const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
+    const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
+    if (num_rows != InferenceContext::kUnknownDim &&
+        num_cols != InferenceContext::kUnknownDim) {
+      if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
+          (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
+        return errors::InvalidArgument("lower_diag_index is out of bound.");
+      }
+      if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
+          (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
+        return errors::InvalidArgument("upper_diag_index is out of bound.");
+      }
+    }
+  }
+
+  ShapeHandle output_shape = input_shape;
+  if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
+    // Try to infer parts of shape from diag.
+    ShapeHandle diag_prefix;
+    TF_RETURN_IF_ERROR(c->Subshape(
+        diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
+        &diag_prefix));
+
+    // The inner matrices can be rectangular, so their exact height and width
+    // cannot be pinned down from lower_diag_index, upper_diag_index, and the
+    // length of the longest given diagonal alone.
+    TF_RETURN_IF_ERROR(
+        c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
+    TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
+  }
+  c->set_output(0, output_shape);
+  return Status::OK();
+}
+
 Status MaxPoolShape(shape_inference::InferenceContext* c) {
   string data_format_str;
   TensorFormat data_format;
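
The num_rows/num_cols inference in MatrixDiagV2Shape above is compact but easy
to misread. Below is a minimal, standalone Python sketch of the same
arithmetic; the function name and the asserts are illustrative only and are
not part of this change.

def infer_output_rows_cols(max_diag_len, d_lower, d_upper,
                           num_rows=-1, num_cols=-1):
  # Smallest matrix that can hold diagonals d_lower..d_upper of this length.
  min_num_rows = max_diag_len - min(d_upper, 0)
  min_num_cols = max_diag_len + max(d_lower, 0)
  if num_rows == -1 and num_cols == -1:  # Neither given: assume square output.
    num_rows = num_cols = max(min_num_rows, min_num_cols)
  if num_rows == -1:
    num_rows = min_num_rows
  elif num_rows < min_num_rows:
    raise ValueError("num_rows is too small.")
  if num_cols == -1:
    num_cols = min_num_cols
  elif num_cols < min_num_cols:
    raise ValueError("num_cols is too small.")
  # At least one dimension must match its minimum value.
  if num_rows != min_num_rows and num_cols != min_num_cols:
    raise ValueError("num_rows and num_cols are inconsistent with k.")
  return num_rows, num_cols

# A length-3 superdiagonal (k = 1) with no sizes given implies a 4x4 output.
assert infer_output_rows_cols(3, 1, 1) == (4, 4)
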
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 590aa98b60b..434948bafa2 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -270,6 +270,15 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c);
 // Shape function for FusedBatchNormGrad and FusedBatchNormGradV2 operations.
 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c);
 
+// Shape function for MatrixDiagPartV2 and MatrixDiagPartV3 operations.
+Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c);
+
+// Shape function for MatrixDiagV2 and MatrixDiagV3 operations.
+Status MatrixDiagV2Shape(shape_inference::InferenceContext* c);
+
+// Shape function for MatrixSetDiagV2 and MatrixSetDiagV3 operations.
+Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c);
+
 // Shape function for MaxPool-like operations.
 Status MaxPoolShape(shape_inference::InferenceContext* c);
 
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ac9c6299833..3945d5d7f55 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1158,7 +1158,7 @@ tf_kernel_library(
 tf_kernel_library(
     name = "matrix_set_diag_op",
     prefix = "matrix_set_diag_op",
-    deps = ARRAY_DEPS,
+    deps = ARRAY_DEPS + [":matrix_diag_op"],
 )
 
 tf_kernel_library(
diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc
index ae69e7752f1..9796fd25e39 100644
--- a/tensorflow/core/kernels/matrix_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_diag_op.cc
@@ -46,8 +46,13 @@ typedef Eigen::GpuDevice GPUDevice;
 template <typename Device, typename T>
 class MatrixDiagPartOp : public OpKernel {
  public:
-  explicit MatrixDiagPartOp(OpKernelConstruction* context)
-      : OpKernel(context) {}
+  explicit MatrixDiagPartOp(OpKernelConstruction* context) : OpKernel(context) {
+    // MatrixDiagPartV3-specific.
+    if (context->HasAttr("align")) {
+      functor::ReadAlignment(context, &left_align_superdiagonal_,
+                             &left_align_subdiagonal_);
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
@@ -60,7 +65,7 @@ class MatrixDiagPartOp : public OpKernel {
     T padding_value(0);
 
     // MatrixDiagPartV2-specific.
-    if (context->num_inputs() > 1) {
+    if (context->num_inputs() > kNumV1Inputs) {
       auto& diag_index = context->input(1);
       OP_REQUIRES(context,
                   TensorShapeUtils::IsScalar(diag_index.shape()) ||
@@ -132,17 +137,26 @@ class MatrixDiagPartOp : public OpKernel {
     functor::MatrixDiagPart<Device, T>::Compute(
         context, context->eigen_device<Device>(), input_reshaped,
         output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
-        padding_value);
+        padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
   }
 
  private:
+  bool left_align_superdiagonal_ = true;
+  bool left_align_subdiagonal_ = true;
+  static constexpr int kNumV1Inputs = 1;
   TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagPartOp);
 };
 
 template <typename Device, typename T>
 class MatrixDiagOp : public OpKernel {
  public:
-  explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {}
+  explicit MatrixDiagOp(OpKernelConstruction* context) : OpKernel(context) {
+    // MatrixDiagV3-specific.
+    if (context->HasAttr("align")) {
+      functor::ReadAlignment(context, &left_align_superdiagonal_,
+                             &left_align_subdiagonal_);
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     const Tensor& diagonal = context->input(0);
@@ -157,7 +171,7 @@ class MatrixDiagOp : public OpKernel {
     T padding_value(0);
 
     // MatrixDiagOpV2-specific.
-    if (context->num_inputs() > 1) {
+    if (context->num_inputs() > kNumV1Inputs) {
       auto& diag_index = context->input(1);
       OP_REQUIRES(context,
                   TensorShapeUtils::IsScalar(diag_index.shape()) ||
@@ -242,10 +256,13 @@ class MatrixDiagOp : public OpKernel {
     functor::MatrixDiag<Device, T>::Compute(
         context, context->eigen_device<Device>(), diag_reshaped,
         output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
-        padding_value);
+        padding_value, left_align_superdiagonal_, left_align_subdiagonal_);
   }
 
  private:
+  bool left_align_superdiagonal_ = true;
+  bool left_align_subdiagonal_ = true;
+  static constexpr int kNumV1Inputs = 1;
   TF_DISALLOW_COPY_AND_ASSIGN(MatrixDiagOp);
 };
 
@@ -256,12 +273,19 @@ class MatrixDiagOp : public OpKernel {
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("MatrixDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
       MatrixDiagOp<CPUDevice, type>);                                        \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("MatrixDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
+      MatrixDiagOp<CPUDevice, type>);                                        \
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"),   \
       MatrixDiagPartOp<CPUDevice, type>);                                    \
   REGISTER_KERNEL_BUILDER(                                                   \
       Name("MatrixDiagPartV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      MatrixDiagPartOp<CPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                                   \
+      Name("MatrixDiagPartV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
       MatrixDiagPartOp<CPUDevice, type>);
+
 TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
 #undef REGISTER_MATRIX_DIAG
 
@@ -280,6 +304,28 @@ TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
 
 // Implementation of the functor specialization for CPU.
 namespace functor {
+
+void ReadAlignment(OpKernelConstruction* context,
+                   bool* left_align_superdiagonal,
+                   bool* left_align_subdiagonal) {
+  string align;
+  OP_REQUIRES_OK(context, context->GetAttr("align", &align));
+
+  *left_align_superdiagonal = align == "LEFT_LEFT" || align == "LEFT_RIGHT";
+  *left_align_subdiagonal = align == "LEFT_LEFT" || align == "RIGHT_LEFT";
+}
+
+std::pair<int, int> ComputeDiagLenAndContentOffset(
+    int diag_index, int max_diag_len, int num_rows, int num_cols,
+    bool left_align_superdiagonal, bool left_align_subdiagonal) {
+  const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
+                          (diag_index <= 0 && left_align_subdiagonal);
+  const int diag_len = std::min(num_rows + std::min(0, diag_index),
+                                num_cols - std::max(0, diag_index));
+  const int content_offset = (left_align) ? 0 : (max_diag_len - diag_len);
+  return {diag_len, content_offset};
+}
+
 template <typename T>
 struct MatrixDiag<CPUDevice, T> {
   static void Compute(OpKernelContext* context, const CPUDevice& device,
@@ -287,16 +333,21 @@ struct MatrixDiag<CPUDevice, T> {
                       typename TTypes<T, 3>::Tensor& output,
                       const Eigen::Index lower_diag_index,
                       const Eigen::Index upper_diag_index,
-                      const Eigen::Index max_diag_len, const T padding_value) {
+                      const Eigen::Index max_diag_len, const T padding_value,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal) {
     // 10 in cost_per_batch is from existing heuristic.
     // TODO(penporn): Tune for the best constant in cost_per_batch.
     const Eigen::Index num_batches = output.dimension(0);
-    const Eigen::Index cost_per_batch =
-        10 * output.dimension(1) * output.dimension(2);
+    const Eigen::Index num_rows = output.dimension(1);
+    const Eigen::Index num_cols = output.dimension(2);
+    const Eigen::Index cost_per_batch = 10 * num_rows * num_cols;
 
-    auto compute_shard = [&output, &diag, &lower_diag_index, &upper_diag_index,
-                          &max_diag_len, &padding_value](Eigen::Index begin,
-                                                         Eigen::Index end) {
+    auto compute_shard = [&output, &num_rows, &num_cols, &diag,
+                          &lower_diag_index, &upper_diag_index, &max_diag_len,
+                          &padding_value, &left_align_superdiagonal,
+                          &left_align_subdiagonal](Eigen::Index begin,
+                                                   Eigen::Index end) {
       const int num_diags = upper_diag_index - lower_diag_index + 1;
       const int diag_elements_in_batch = num_diags * max_diag_len;
       Eigen::Index diag_batch_base_index = begin * diag_elements_in_batch;
@@ -305,8 +356,12 @@ struct MatrixDiag<CPUDevice, T> {
           for (Eigen::Index j = 0; j < output.dimension(2); ++j) {
             const int diag_index = j - i;
             const int diag_index_in_input = upper_diag_index - diag_index;
+            int diag_len, content_offset;
+            std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
+                diag_index, max_diag_len, num_rows, num_cols,
+                left_align_superdiagonal, left_align_subdiagonal);
             const int index_in_the_diagonal =
-                j - std::max<Eigen::Index>(diag_index, 0);
+                j - std::max<Eigen::Index>(diag_index, 0) + content_offset;
             if (lower_diag_index <= diag_index &&
                 diag_index <= upper_diag_index) {
               output(batch, i, j) = diag(diag_batch_base_index +
@@ -334,7 +389,9 @@ struct MatrixDiagPart<CPUDevice, T> {
                       typename TTypes<T>::Tensor& output,
                       const Eigen::Index lower_diag_index,
                       const Eigen::Index upper_diag_index,
-                      const Eigen::Index max_diag_len, const T padding_value) {
+                      const Eigen::Index max_diag_len, const T padding_value,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal) {
     // 10 in cost_per_batch is from existing heuristic.
     // TODO(penporn): Tune for the best constant in cost_per_batch.
     const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
@@ -346,24 +403,32 @@ struct MatrixDiagPart<CPUDevice, T> {
 
     auto compute_shard = [&output, &input, &num_rows, &num_cols,
                           &upper_diag_index, &max_diag_len, &num_diags,
-                          &output_elements_in_batch, &padding_value](
+                          &output_elements_in_batch, &padding_value,
+                          &left_align_superdiagonal, &left_align_subdiagonal](
                              Eigen::Index begin, Eigen::Index end) {
       Eigen::Index output_base_index = begin * output_elements_in_batch;
       for (Eigen::Index batch = begin; batch < end; ++batch) {
         for (Eigen::Index m = 0; m < num_diags; ++m) {
-          const Eigen::Index d = upper_diag_index - m;
-          Eigen::Index n = 0;
-          // Make two separate cases to save some index calculations.
-          if (d >= 0) {
-            for (; n < std::min(num_rows, num_cols - d); ++n) {
-              output(output_base_index + n) = input(batch, n, n + d);
-            }
-          } else {
-            for (; n < std::min(num_rows + d, num_cols); ++n) {
-              output(output_base_index + n) = input(batch, n - d, n);
-            }
+          const Eigen::Index diag_index = upper_diag_index - m;
+          Eigen::Index y_offset = std::max<Eigen::Index>(0, -diag_index);
+          Eigen::Index x_offset = std::max<Eigen::Index>(0, diag_index);
+          int diag_len, content_offset;
+          std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
+              diag_index, max_diag_len, num_rows, num_cols,
+              left_align_superdiagonal, left_align_subdiagonal);
+
+          // Fills the diagonal.
+          for (Eigen::Index n = 0; n < diag_len; ++n) {
+            output(output_base_index + content_offset + n) =
+                input(batch, n + y_offset, n + x_offset);
           }
-          for (; n < max_diag_len; ++n) {  // Padding.
+
+          // Padding.
+          const bool left_align = (content_offset == 0);
+          const Eigen::Index padding_start = (left_align) ? diag_len : 0;
+          const Eigen::Index padding_end =
+              (left_align) ? max_diag_len : content_offset;
+          for (Eigen::Index n = padding_start; n < padding_end; ++n) {
             output(output_base_index + n) = padding_value;
           }
           output_base_index += max_diag_len;
@@ -391,7 +456,8 @@ namespace functor {
       typename TTypes<T, 3>::Tensor& output,                                   \
       const Eigen::Index lower_diag_index,                                     \
       const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,    \
-      const T padding_value);                                                  \
+      const T padding_value, const bool left_align_superdiagonal,              \
+      const bool left_align_subdiagonal);                                      \
   extern template struct MatrixDiag<GPUDevice, T>;                             \
   template <>                                                                  \
   void MatrixDiagPart<GPUDevice, T>::Compute(                                  \
@@ -399,7 +465,8 @@ namespace functor {
       typename TTypes<T, 3>::ConstTensor& input,                               \
       typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index, \
       const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,    \
-      const T padding_value);                                                  \
+      const T padding_value, const bool left_align_superdiagonal,              \
+      const bool left_align_subdiagonal);                                      \
   extern template struct MatrixDiagPart<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@@ -422,6 +489,14 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
                               .HostMemory("num_cols")                      \
                               .HostMemory("padding_value"),                \
                           MatrixDiagOp<GPUDevice, type>);                  \
+  REGISTER_KERNEL_BUILDER(Name("MatrixDiagV3")                             \
+                              .Device(DEVICE_GPU)                          \
+                              .TypeConstraint<type>("T")                   \
+                              .HostMemory("k")                             \
+                              .HostMemory("num_rows")                      \
+                              .HostMemory("num_cols")                      \
+                              .HostMemory("padding_value"),                \
+                          MatrixDiagOp<GPUDevice, type>);                  \
   REGISTER_KERNEL_BUILDER(                                                 \
       Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
       MatrixDiagPartOp<GPUDevice, type>);                                  \
@@ -430,6 +505,12 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
                               .TypeConstraint<type>("T")                   \
                               .HostMemory("k")                             \
                               .HostMemory("padding_value"),                \
+                          MatrixDiagPartOp<GPUDevice, type>);              \
+  REGISTER_KERNEL_BUILDER(Name("MatrixDiagPartV3")                         \
+                              .Device(DEVICE_GPU)                          \
+                              .TypeConstraint<type>("T")                   \
+                              .HostMemory("k")                             \
+                              .HostMemory("padding_value"),                \
                           MatrixDiagPartOp<GPUDevice, type>);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
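
The new ReadAlignment/ComputeDiagLenAndContentOffset pair above decides where
each packed diagonal's content starts inside its max_diag_len-long row. The
following standalone Python sketch mirrors that logic; the names are
illustrative and the snippet is not part of the patch.

def read_alignment(align):
  # Mirrors functor::ReadAlignment: "align" is one of LEFT_LEFT, LEFT_RIGHT,
  # RIGHT_LEFT, RIGHT_RIGHT (superdiagonal alignment listed first).
  left_align_superdiagonal = align in ("LEFT_LEFT", "LEFT_RIGHT")
  left_align_subdiagonal = align in ("LEFT_LEFT", "RIGHT_LEFT")
  return left_align_superdiagonal, left_align_subdiagonal

def diag_len_and_content_offset(diag_index, max_diag_len, num_rows, num_cols,
                                left_align_superdiagonal,
                                left_align_subdiagonal):
  # Mirrors functor::ComputeDiagLenAndContentOffset.
  left_align = ((diag_index >= 0 and left_align_superdiagonal) or
                (diag_index <= 0 and left_align_subdiagonal))
  diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index))
  content_offset = 0 if left_align else max_diag_len - diag_len
  return diag_len, content_offset

# 3x3 matrix, k = (0, 1), so packed rows have max_diag_len = 3. Under the
# "RIGHT_LEFT" alignment the superdiagonal is right-aligned (padded on the
# left), while the main diagonal fills its whole row.
sup, sub = read_alignment("RIGHT_LEFT")
assert diag_len_and_content_offset(1, 3, 3, 3, sup, sub) == (2, 1)
assert diag_len_and_content_offset(0, 3, 3, 3, sup, sub) == (3, 0)
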
diff --git a/tensorflow/core/kernels/matrix_diag_op.h b/tensorflow/core/kernels/matrix_diag_op.h
index 619fb5855eb..707fd9b6c14 100644
--- a/tensorflow/core/kernels/matrix_diag_op.h
+++ b/tensorflow/core/kernels/matrix_diag_op.h
@@ -19,6 +19,7 @@ limitations under the License.
 // Generator definition for MatrixDiagOp, must be compilable by nvcc.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/platform/types.h"
@@ -26,26 +27,43 @@ limitations under the License.
 namespace tensorflow {
 namespace functor {
 
+// Reads the diagonal packing alignment.
+void ReadAlignment(OpKernelConstruction* context,
+                   bool* left_align_superdiagonal,
+                   bool* left_align_subdiagonal);
+
+// Calculates the length and content offset (from alignment) of a diagonal.
+// Returns a pair of integers {diag_len, content_offset}:
+//   - diag_len: The length of the diag_index-th diagonal.
+//   - content_offset: Each diagonal is stored as a row in the compact format.
+//     If the diagonal is shorter than max_diag_len, its content is aligned
+//     either to the left or right. content_offset is the index in the row
+//     where the first element of the diag_index-th diagonal is stored. It is
+//     always zero when the diagonal is left-aligned.
+std::pair<int, int> ComputeDiagLenAndContentOffset(
+    int diag_index, int max_diag_len, int num_rows, int num_cols,
+    bool left_align_superdiagonal, bool left_align_subdiagonal);
+
 template <typename Device, typename T>
 struct MatrixDiagPart {
   EIGEN_ALWAYS_INLINE static void Compute(
       OpKernelContext* context, const Device& device,
       typename TTypes<T, 3>::ConstTensor& input,
-      typename TTypes<T>::Tensor& output_original, const Eigen::Index d_lower,
-      const Eigen::Index d_upper, const Eigen::Index max_diag_len,
-      const T padding);
+      typename TTypes<T>::Tensor& output, const Eigen::Index lower_diag_index,
+      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,
+      const T padding_value, const bool left_align_superdiagonal,
+      const bool left_align_subdiagonal);
 };
 
 template <typename Device, typename T>
 struct MatrixDiag {
-  EIGEN_ALWAYS_INLINE static void Compute(OpKernelContext* context,
-                                          const Device& device,
-                                          typename TTypes<T>::ConstTensor& diag,
-                                          typename TTypes<T, 3>::Tensor& output,
-                                          const Eigen::Index d_lower,
-                                          const Eigen::Index d_upper,
-                                          const Eigen::Index max_diag_len,
-                                          const T padding);
+  EIGEN_ALWAYS_INLINE static void Compute(
+      OpKernelContext* context, const Device& device,
+      typename TTypes<T>::ConstTensor& diag,
+      typename TTypes<T, 3>::Tensor& output,
+      const Eigen::Index lower_diag_index, const Eigen::Index upper_diag_index,
+      const Eigen::Index max_diag_len, const T padding_value,
+      const bool left_align_superdiagonal, const bool left_align_subdiagonal);
 };
 
 }  // namespace functor
diff --git a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc
index 9f6d4a0ea87..53cd2d2dc46 100644
--- a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc
@@ -26,13 +26,28 @@ namespace functor {
 
 typedef Eigen::GpuDevice GPUDevice;
 
+__device__ inline int ComputeContentOffset(const int diag_index,
+                                           const int max_diag_len,
+                                           const int num_rows,
+                                           const int num_cols,
+                                           const bool left_align_superdiagonal,
+                                           const bool left_align_subdiagonal) {
+  const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
+                          (diag_index <= 0 && left_align_subdiagonal);
+  if (left_align) return 0;
+  const int y_offset = min(0, diag_index);
+  const int x_offset = max(0, diag_index);
+  const int diag_len = min(num_rows + y_offset, num_cols - x_offset);
+  return max_diag_len - diag_len;
+}
+
 template <typename T>
-__global__ void MatrixDiagKernel(const int num_threads, const int num_rows,
-                                 const int num_cols, const int num_diags,
-                                 const int max_diag_len,
-                                 const int lower_diag_index,
-                                 const int upper_diag_index, const T padding,
-                                 const T* diag_ptr, T* output_ptr) {
+__global__ void MatrixDiagKernel(
+    const int num_threads, const int num_rows, const int num_cols,
+    const int num_diags, const int max_diag_len, const int lower_diag_index,
+    const int upper_diag_index, const T padding_value,
+    const bool left_align_superdiagonal, const bool left_align_subdiagonal,
+    const T* diag_ptr, T* output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int batch_and_row_index = index / num_cols;
     const int col = index - batch_and_row_index * num_cols;
@@ -40,13 +55,16 @@ __global__ void MatrixDiagKernel(const int num_threads, const int num_rows,
     const int row = batch_and_row_index - batch * num_rows;
     const int diag_index = col - row;
     const int diag_index_in_input = upper_diag_index - diag_index;
-    const int index_in_the_diagonal = col - max(diag_index, 0);
+    const int content_offset =
+        ComputeContentOffset(diag_index, max_diag_len, num_rows, num_cols,
+                             left_align_superdiagonal, left_align_subdiagonal);
+    const int index_in_the_diagonal = col - max(diag_index, 0) + content_offset;
     if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) {
       output_ptr[index] =
           diag_ptr[batch * num_diags * max_diag_len +
                    diag_index_in_input * max_diag_len + index_in_the_diagonal];
     } else {
-      output_ptr[index] = padding;
+      output_ptr[index] = padding_value;
     }
   }
 }
@@ -58,7 +76,9 @@ struct MatrixDiag<GPUDevice, T> {
                       typename TTypes<T, 3>::Tensor& output,
                       const Eigen::Index lower_diag_index,
                       const Eigen::Index upper_diag_index,
-                      const Eigen::Index max_diag_len, const T padding) {
+                      const Eigen::Index max_diag_len, const T padding_value,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal) {
     const int batch_size = output.dimension(0);
     const int num_rows = output.dimension(1);
     const int num_cols = output.dimension(2);
@@ -72,19 +92,19 @@ struct MatrixDiag<GPUDevice, T> {
     TF_CHECK_OK(GpuLaunchKernel(
         MatrixDiagKernel<T>, config.block_count, config.thread_per_block, 0,
         device.stream(), config.virtual_thread_count, num_rows, num_cols,
-        num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding,
+        num_diags, max_diag_len, lower_diag_index, upper_diag_index,
+        padding_value, left_align_superdiagonal, left_align_subdiagonal,
         diag.data(), output.data()));
   }
 };
 
 template <typename T>
-__global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows,
-                                     const int num_cols, const int num_diags,
-                                     const int max_diag_len,
-                                     const int lower_diag_index,
-                                     const int upper_diag_index,
-                                     const T padding, const T* input_ptr,
-                                     T* output_ptr) {
+__global__ void MatrixDiagPartKernel(
+    const int num_threads, const int num_rows, const int num_cols,
+    const int num_diags, const int max_diag_len, const int lower_diag_index,
+    const int upper_diag_index, const T padding_value,
+    const bool left_align_superdiagonal, const bool left_align_subdiagonal,
+    const T* input_ptr, T* output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int batch_and_mapped_diag_index = index / max_diag_len;
     const int index_in_the_diagonal =
@@ -93,13 +113,19 @@ __global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows,
     const int mapped_diag_index =
         batch_and_mapped_diag_index - batch * num_diags;
     const int diag_index = upper_diag_index - mapped_diag_index;
-    const int y_index = index_in_the_diagonal + max(0, -diag_index);
-    const int x_index = index_in_the_diagonal + max(0, diag_index);
-    if (y_index < num_rows && x_index < num_cols) {
+    const int content_offset =
+        ComputeContentOffset(diag_index, max_diag_len, num_rows, num_cols,
+                             left_align_superdiagonal, left_align_subdiagonal);
+    const int y_offset = max(0, -diag_index);
+    const int x_offset = max(0, diag_index);
+    const int y_index = index_in_the_diagonal + y_offset - content_offset;
+    const int x_index = index_in_the_diagonal + x_offset - content_offset;
+    if (0 <= y_index && y_index < num_rows && 0 <= x_index &&
+        x_index < num_cols) {
       output_ptr[index] =
           input_ptr[batch * num_rows * num_cols + y_index * num_cols + x_index];
     } else {
-      output_ptr[index] = padding;
+      output_ptr[index] = padding_value;
     }
   }
 }
@@ -111,7 +137,9 @@ struct MatrixDiagPart<GPUDevice, T> {
                       typename TTypes<T>::Tensor& output,
                       const Eigen::Index lower_diag_index,
                       const Eigen::Index upper_diag_index,
-                      const Eigen::Index max_diag_len, const T padding) {
+                      const Eigen::Index max_diag_len, const T padding_value,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal) {
     const int batch_size = input.dimension(0);
     const int num_rows = input.dimension(1);
     const int num_cols = input.dimension(2);
@@ -125,7 +153,8 @@ struct MatrixDiagPart<GPUDevice, T> {
     TF_CHECK_OK(GpuLaunchKernel(
         MatrixDiagPartKernel<T>, config.block_count, config.thread_per_block, 0,
         device.stream(), config.virtual_thread_count, num_rows, num_cols,
-        num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding,
+        num_diags, max_diag_len, lower_diag_index, upper_diag_index,
+        padding_value, left_align_superdiagonal, left_align_subdiagonal,
         input.data(), output.data()));
   }
 };
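
The rewritten MatrixDiagPartKernel above maps every element of the packed
output back to a (row, column) position in the input, writing padding_value
for elements that fall into the alignment padding. A standalone Python
reference of that mapping for a single matrix (illustrative names, not part of
the patch):

def matrix_diag_part_ref(input_mat, d_lower, d_upper, padding_value,
                         left_align_superdiagonal, left_align_subdiagonal):
  num_rows, num_cols = len(input_mat), len(input_mat[0])
  max_diag_len = min(num_rows + min(0, d_upper), num_cols - max(0, d_lower))
  out = []
  for diag_index in range(d_upper, d_lower - 1, -1):  # Upper to lower.
    left_align = ((diag_index >= 0 and left_align_superdiagonal) or
                  (diag_index <= 0 and left_align_subdiagonal))
    diag_len = min(num_rows + min(0, diag_index),
                   num_cols - max(0, diag_index))
    offset = 0 if left_align else max_diag_len - diag_len
    row = [padding_value] * max_diag_len
    for i in range(max_diag_len):
      y = i + max(0, -diag_index) - offset  # Same index math as the kernel.
      x = i + max(0, diag_index) - offset
      if 0 <= y < num_rows and 0 <= x < num_cols:
        row[i] = input_mat[y][x]
    out.append(row)
  return out

# k = (-1, 1) with RIGHT_LEFT alignment: the superdiagonal row is padded on
# the left, the subdiagonal row on the right.
mat = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
assert matrix_diag_part_ref(mat, -1, 1, 0, False, True) == [
    [0, 2, 6], [1, 5, 9], [4, 8, 0]]
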
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index 6507fca3403..2701ff788f7 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/matrix_diag_op.h"
 #include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
@@ -42,7 +43,13 @@ typedef Eigen::GpuDevice GPUDevice;
 template <typename Device, typename T>
 class MatrixSetDiagOp : public OpKernel {
  public:
-  explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {}
+  explicit MatrixSetDiagOp(OpKernelConstruction* context) : OpKernel(context) {
+    // MatrixSetDiagV3-specific.
+    if (context->HasAttr("align")) {
+      functor::ReadAlignment(context, &left_align_superdiagonal_,
+                             &left_align_subdiagonal_);
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
@@ -55,7 +62,7 @@ class MatrixSetDiagOp : public OpKernel {
     int32 upper_diag_index = 0;
 
     // MatrixSetDiagV2-specific.
-    if (context->num_inputs() > 2) {
+    if (context->num_inputs() > kNumV1Inputs) {
       auto& diag_index = context->input(2);
       OP_REQUIRES(context,
                   TensorShapeUtils::IsScalar(diag_index.shape()) ||
@@ -155,10 +162,14 @@ class MatrixSetDiagOp : public OpKernel {
     auto output_reshaped = output->flat_inner_dims<T, 3>();
     functor::MatrixSetDiag<Device, T>::Compute(
         context, context->eigen_device<Device>(), input_reshaped, diag_reshaped,
-        output_reshaped, lower_diag_index, upper_diag_index, max_diag_len);
+        output_reshaped, lower_diag_index, upper_diag_index, max_diag_len,
+        left_align_superdiagonal_, left_align_subdiagonal_);
   }
 
  private:
+  bool left_align_superdiagonal_ = true;
+  bool left_align_subdiagonal_ = true;
+  static constexpr int kNumV1Inputs = 2;
   TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp);
 };
 
@@ -168,7 +179,11 @@ class MatrixSetDiagOp : public OpKernel {
       MatrixSetDiagOp<CPUDevice, type>);                                    \
   REGISTER_KERNEL_BUILDER(                                                  \
       Name("MatrixSetDiagV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      MatrixSetDiagOp<CPUDevice, type>);                                    \
+  REGISTER_KERNEL_BUILDER(                                                  \
+      Name("MatrixSetDiagV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
       MatrixSetDiagOp<CPUDevice, type>);
+
 TF_CALL_POD_TYPES(REGISTER_MATRIX_SET_DIAG);
 #undef REGISTER_MATRIX_SET_DIAG
 
@@ -192,29 +207,38 @@ struct MatrixSetDiag<CPUDevice, T> {
                       typename TTypes<T, 3>::Tensor& output,
                       const Eigen::Index lower_diag_index,
                       const Eigen::Index upper_diag_index,
-                      const Eigen::Index max_diag_len) {
+                      const Eigen::Index max_diag_len,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal) {
     if (input.data() != output.data()) {
       output.device(device) = input;
     }
     const Eigen::Index num_diags = upper_diag_index - lower_diag_index + 1;
     auto compute_shard = [&output, &diag, &upper_diag_index, &max_diag_len,
-                          &num_diags](Eigen::Index begin, Eigen::Index end) {
+                          &num_diags, &left_align_superdiagonal,
+                          &left_align_subdiagonal](Eigen::Index begin,
+                                                   Eigen::Index end) {
       const Eigen::Index num_rows = output.dimension(1);
       const Eigen::Index num_cols = output.dimension(2);
       Eigen::Index diag_base_index = begin * num_diags * max_diag_len;
       for (Eigen::Index batch = begin; batch < end; ++batch) {
         for (Eigen::Index m = 0; m < num_diags; ++m) {
-          const Eigen::Index d = upper_diag_index - m;
+          const Eigen::Index diag_index = upper_diag_index - m;
+          int diag_len, content_offset;
+          std::tie(diag_len, content_offset) = ComputeDiagLenAndContentOffset(
+              diag_index, max_diag_len, num_rows, num_cols,
+              left_align_superdiagonal, left_align_subdiagonal);
+
           // Make two separate cases to save some index calculations.
-          if (d >= 0) {
-            for (Eigen::Index n = 0; n < std::min(num_rows, num_cols - d);
-                 ++n) {
-              output(batch, n, n + d) = diag(diag_base_index + n);
+          if (diag_index >= 0) {
+            for (Eigen::Index n = 0; n < diag_len; ++n) {
+              output(batch, n, n + diag_index) =
+                  diag(diag_base_index + n + content_offset);
             }
           } else {
-            for (Eigen::Index n = 0; n < std::min(num_rows + d, num_cols);
-                 ++n) {
-              output(batch, n - d, n) = diag(diag_base_index + n);
+            for (Eigen::Index n = 0; n < diag_len; ++n) {
+              output(batch, n - diag_index, n) =
+                  diag(diag_base_index + n + content_offset);
             }
           }
           diag_base_index += max_diag_len;
@@ -236,15 +260,16 @@ struct MatrixSetDiag<CPUDevice, T> {
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                                  \
-  template <>                                                                \
-  void MatrixSetDiag<GPUDevice, T>::Compute(                                 \
-      OpKernelContext* context, const GPUDevice& device,                     \
-      typename TTypes<T, 3>::ConstTensor& input,                             \
-      typename TTypes<T>::ConstTensor& diag,                                 \
-      typename TTypes<T, 3>::Tensor& output,                                 \
-      const Eigen::Index lower_diag_index,                                   \
-      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len); \
+#define DECLARE_GPU_SPEC(T)                                                    \
+  template <>                                                                  \
+  void MatrixSetDiag<GPUDevice, T>::Compute(                                   \
+      OpKernelContext* context, const GPUDevice& device,                       \
+      typename TTypes<T, 3>::ConstTensor& input,                               \
+      typename TTypes<T>::ConstTensor& diag,                                   \
+      typename TTypes<T, 3>::Tensor& output,                                   \
+      const Eigen::Index lower_diag_index,                                     \
+      const Eigen::Index upper_diag_index, const Eigen::Index max_diag_len,    \
+      const bool left_align_superdiagonal, const bool left_align_subdiagonal); \
   extern template struct MatrixSetDiag<GPUDevice, T>;
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
@@ -263,7 +288,13 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
                               .Device(DEVICE_GPU)                         \
                               .TypeConstraint<type>("T")                  \
                               .HostMemory("k"),                           \
+                          MatrixSetDiagOp<GPUDevice, type>);              \
+  REGISTER_KERNEL_BUILDER(Name("MatrixSetDiagV3")                         \
+                              .Device(DEVICE_GPU)                         \
+                              .TypeConstraint<type>("T")                  \
+                              .HostMemory("k"),                           \
                           MatrixSetDiagOp<GPUDevice, type>);
+
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
 TF_CALL_bool(REGISTER_MATRIX_SET_DIAG_GPU);
 TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
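
The updated CPU MatrixSetDiag functor above reads each packed diagonal
starting at content_offset before writing it onto its diagonal of the output.
A standalone Python sketch of that write loop for a single matrix
(illustrative names, not part of the patch); it round-trips with the
extraction sketch shown earlier:

def matrix_set_diag_ref(matrix, packed_diags, d_lower, d_upper,
                        left_align_superdiagonal, left_align_subdiagonal):
  num_rows, num_cols = len(matrix), len(matrix[0])
  max_diag_len = len(packed_diags[0])
  out = [row[:] for row in matrix]
  for m, packed_row in enumerate(packed_diags):
    diag_index = d_upper - m  # Packed rows run from d_upper down to d_lower.
    left_align = ((diag_index >= 0 and left_align_superdiagonal) or
                  (diag_index <= 0 and left_align_subdiagonal))
    diag_len = min(num_rows + min(0, diag_index),
                   num_cols - max(0, diag_index))
    offset = 0 if left_align else max_diag_len - diag_len
    for n in range(diag_len):
      if diag_index >= 0:
        out[n][n + diag_index] = packed_row[n + offset]
      else:
        out[n - diag_index][n] = packed_row[n + offset]
  return out

# Writing back the RIGHT_LEFT-aligned diagonals of a 3x3 matrix, k = (-1, 1).
zeros = [[0] * 3 for _ in range(3)]
diags = [[0, 2, 6],   # k = 1, right-aligned
         [1, 5, 9],   # k = 0
         [4, 8, 0]]   # k = -1, left-aligned
assert matrix_set_diag_ref(zeros, diags, -1, 1, False, True) == [
    [1, 2, 0], [4, 5, 6], [0, 8, 9]]
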
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h
index db30aaee669..04877cd34ca 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.h
+++ b/tensorflow/core/kernels/matrix_set_diag_op.h
@@ -29,8 +29,11 @@ struct MatrixSetDiag {
                       typename TTypes<T, 3>::ConstTensor& input,
                       typename TTypes<T>::ConstTensor& diag,
                       typename TTypes<T, 3>::Tensor& output,
-                      const Eigen::Index d_lower, const Eigen::Index d_upper,
-                      const Eigen::Index max_diag_len);
+                      const Eigen::Index lower_diag_index,
+                      const Eigen::Index upper_diag_index,
+                      const Eigen::Index max_diag_len,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal);
 };
 
 }  // namespace functor
diff --git a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
index 55489073f93..4f742b90bff 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
@@ -26,26 +26,42 @@ namespace functor {
 
 typedef Eigen::GpuDevice GPUDevice;
 
+// TODO(penporn): Merge this file with matrix_diag_op_gpu.cu.cc.
+__device__ inline int ComputeContentOffset(const int diag_index,
+                                           const int max_diag_len,
+                                           const int num_rows,
+                                           const int num_cols,
+                                           const bool left_align_superdiagonal,
+                                           const bool left_align_subdiagonal) {
+  const bool left_align = (diag_index >= 0 && left_align_superdiagonal) ||
+                          (diag_index <= 0 && left_align_subdiagonal);
+  if (left_align) return 0;
+  const int y_offset = min(0, diag_index);
+  const int x_offset = max(0, diag_index);
+  const int diag_len = min(num_rows + y_offset, num_cols - x_offset);
+  return max_diag_len - diag_len;
+}
+
 template <typename Scalar>
-__global__ void MatrixSetDiagKernel(const int num_threads, const int m,
-                                    const int n, const int num_diags,
-                                    const int max_diag_len,
-                                    const int upper_diag_index,
-                                    const Scalar* __restrict__ diag_ptr,
-                                    Scalar* __restrict__ output_ptr) {
+__global__ void MatrixSetDiagKernel(
+    const int num_threads, const int m, const int n, const int num_diags,
+    const int max_diag_len, const int upper_diag_index,
+    const bool left_align_superdiagonal, const bool left_align_subdiagonal,
+    const Scalar* __restrict__ diag_ptr, Scalar* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int batch_and_diag_index = index / max_diag_len;
-    const int index_in_the_diagonal =
-        index - batch_and_diag_index * max_diag_len;
+    int index_in_the_diagonal = index - batch_and_diag_index * max_diag_len;
     const int batch = batch_and_diag_index / num_diags;
     const int diag_index_in_input = batch_and_diag_index - batch * num_diags;
     const int diag_index = upper_diag_index - diag_index_in_input;
-    const int y_index = index_in_the_diagonal + max(0, -diag_index);
+    index_in_the_diagonal -=
+        ComputeContentOffset(diag_index, max_diag_len, m, n,
+                             left_align_superdiagonal, left_align_subdiagonal);
+    const int y_index = index_in_the_diagonal - min(0, diag_index);
     const int x_index = index_in_the_diagonal + max(0, diag_index);
 
     // Upper-bound checks for diagonals shorter than max_diag_len.
-    // y_index and x_index are nonnegative by construction.
-    if (y_index < m && x_index < n) {
+    if (index_in_the_diagonal >= 0 && y_index < m && x_index < n) {
       const int out_index = batch * m * n + y_index * n + x_index;
       output_ptr[out_index] = diag_ptr[index];
     }
@@ -56,17 +72,21 @@ template <typename Scalar>
 __global__ void MatrixCopyInputAndSetDiagKernel(
     const int num_threads, const int m, const int n, const int num_diags,
     const int max_diag_len, const int lower_diag_index,
-    const int upper_diag_index, const Scalar* __restrict__ input_ptr,
+    const int upper_diag_index, const bool left_align_superdiagonal,
+    const bool left_align_subdiagonal, const Scalar* __restrict__ input_ptr,
     const Scalar* __restrict__ diag_ptr, Scalar* __restrict__ output_ptr) {
   GPU_1D_KERNEL_LOOP(index, num_threads) {
     const int batch_and_row_index = index / n;
     const int col = index - batch_and_row_index * n;
     const int batch = batch_and_row_index / m;
     const int row = batch_and_row_index - batch * m;
-    const int d = col - row;
-    const int diag_index_in_input = upper_diag_index - d;
-    const int index_in_the_diagonal = col - max(d, 0);
-    if (lower_diag_index <= d && d <= upper_diag_index) {
+    const int diag_index = col - row;
+    const int diag_index_in_input = upper_diag_index - diag_index;
+    const int index_in_the_diagonal =
+        col - max(0, diag_index) +
+        ComputeContentOffset(diag_index, max_diag_len, m, n,
+                             left_align_superdiagonal, left_align_subdiagonal);
+    if (lower_diag_index <= diag_index && diag_index <= upper_diag_index) {
       output_ptr[index] =
           diag_ptr[batch * num_diags * max_diag_len +
                    diag_index_in_input * max_diag_len + index_in_the_diagonal];
@@ -84,7 +104,9 @@ struct MatrixSetDiag<GPUDevice, Scalar> {
                       typename TTypes<Scalar, 3>::Tensor& output,
                       const Eigen::Index lower_diag_index,
                       const Eigen::Index upper_diag_index,
-                      const Eigen::Index max_diag_len) {
+                      const Eigen::Index max_diag_len,
+                      const bool left_align_superdiagonal,
+                      const bool left_align_subdiagonal) {
     const int batch_size = input.dimension(0);
     const int m = input.dimension(1);
     const int n = input.dimension(2);
@@ -98,15 +120,16 @@ struct MatrixSetDiag<GPUDevice, Scalar> {
           MatrixSetDiagKernel<Scalar>, config.block_count,
           config.thread_per_block, 0, device.stream(),
           config.virtual_thread_count, m, n, num_diags, max_diag_len,
-          upper_diag_index, diag.data(), output.data()));
+          upper_diag_index, left_align_superdiagonal, left_align_subdiagonal,
+          diag.data(), output.data()));
     } else {
       GpuLaunchConfig config = GetGpuLaunchConfig(batch_size * m * n, device);
       TF_CHECK_OK(GpuLaunchKernel(
           MatrixCopyInputAndSetDiagKernel<Scalar>, config.block_count,
           config.thread_per_block, 0, device.stream(),
           config.virtual_thread_count, m, n, num_diags, max_diag_len,
-          lower_diag_index, upper_diag_index, input.data(), diag.data(),
-          output.data()));
+          lower_diag_index, upper_diag_index, left_align_superdiagonal,
+          left_align_subdiagonal, input.data(), diag.data(), output.data()));
     }
   }
 };
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index c48fecc6147..dbe357dbfe2 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -279,29 +279,6 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
   return Status::OK();
 }
 
-Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor,
-                     int32* lower_diag_index, int32* upper_diag_index) {
-  // This function assumes that the shape of diag_index_tensor is fully defined.
-  if (diag_index_tensor->dims() == 0) {
-    *lower_diag_index = diag_index_tensor->scalar<int32>()();
-    *upper_diag_index = *lower_diag_index;
-  } else {
-    int32 num_elements = diag_index_tensor->dim_size(0);
-    if (num_elements == 1) {
-      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
-      *upper_diag_index = *lower_diag_index;
-    } else if (num_elements == 2) {
-      *lower_diag_index = diag_index_tensor->vec<int32>()(0);
-      *upper_diag_index = diag_index_tensor->vec<int32>()(1);
-    } else {
-      return errors::InvalidArgument(
-          "diag_index must be a vector with one or two elements. It has ",
-          num_elements, " elements.");
-    }
-  }
-  return Status::OK();
-}
-
 }  // namespace
 
 REGISTER_OP("ParallelConcat")
@@ -861,107 +838,20 @@ REGISTER_OP("MatrixDiagV2")
     .Input("padding_value: T")
     .Output("output: T")
     .Attr("T: type")
-    .SetShapeFn([](InferenceContext* c) {
-      // Checks input ranks.
-      ShapeHandle input_shape, diag_index_shape, unused_shape;
-      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
-      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
+    .SetShapeFn(shape_inference::MatrixDiagV2Shape);
 
-      // Reads the diagonal indices.
-      const Tensor* diag_index_tensor = c->input_tensor(1);
-      if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
-          diag_index_tensor == nullptr) {
-        c->set_output(0, c->UnknownShape());
-        return Status::OK();
-      }
-      int32 lower_diag_index = 0;
-      int32 upper_diag_index = 0;
-      TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
-                                       &upper_diag_index));
-      if (lower_diag_index > upper_diag_index) {
-        return errors::InvalidArgument(
-            "lower_diag_index is greater than upper_diag_index");
-      }
-
-      // Checks if the number of diagonals provided matches what we imply from
-      // lower_diag_index and upper_diag_index.
-      const int32 input_rank = c->Rank(input_shape);
-      if (lower_diag_index < upper_diag_index) {
-        const int32 num_diags = c->Value(c->Dim(input_shape, input_rank - 2));
-        const int32 other_dim = c->Value(c->Dim(input_shape, input_rank - 1));
-
-        if (num_diags != (upper_diag_index - lower_diag_index + 1)) {
-          return errors::InvalidArgument(
-              "The number of rows of `diagonal` doesn't match the number of "
-              "diagonals implied from `d_lower` and `d_upper`.\n",
-              "num_diags = ", num_diags, ", d_lower = ", lower_diag_index,
-              ", d_upper = ", upper_diag_index, " ", input_rank, " ",
-              other_dim);
-        }
-      }
-
-      // Reads num_rows and num_cols.
-      const Tensor* num_rows_tensor = c->input_tensor(2);
-      const Tensor* num_cols_tensor = c->input_tensor(3);
-      int64 num_rows = -1;
-      int64 num_cols = -1;
-      if (num_rows_tensor != nullptr) {
-        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_rows_tensor, &num_rows));
-      }
-      if (num_cols_tensor != nullptr) {
-        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(num_cols_tensor, &num_cols));
-      }
-
-      // Infers the missing num_rows or num_cols: If both are missing, assume
-      // output is square. Otherwise, use the smallest possible value. Also
-      // validates the provided values.
-      const int32 max_diag_len = c->Value(c->Dim(input_shape, input_rank - 1));
-      const int32 min_num_rows = max_diag_len - std::min(upper_diag_index, 0);
-      const int32 min_num_cols = max_diag_len + std::max(lower_diag_index, 0);
-      if (num_rows == -1 && num_cols == -1) {  // Special case.
-        num_rows = std::max(min_num_rows, min_num_cols);
-        num_cols = num_rows;
-      }
-      if (num_rows == -1) {
-        num_rows = min_num_rows;
-      } else if (num_rows < min_num_rows) {
-        return errors::InvalidArgument("num_rows is too small");
-      }
-      if (num_cols == -1) {
-        num_cols = min_num_cols;
-      } else if (num_cols < min_num_cols) {
-        return errors::InvalidArgument("num_cols is too small.");
-      }
-      // At least one of them must match the minimum length.
-      if (num_rows != min_num_rows && num_cols != min_num_cols) {
-        return errors::InvalidArgument(
-            "num_rows and num_cols are not consistent with lower_diag_index, "
-            "upper_diag_index, and the length of the given diagonals.\n",
-            "num_rows = ", num_rows, " != min_num_rows = ", min_num_rows,
-            ", num_cols = ", num_cols, " != min_num_cols = ", min_num_cols);
-      }
-
-      // Sets output shape.
-      ShapeHandle output_shape;
-      const DimensionHandle output_row_dim = c->MakeDim(num_rows);
-      const DimensionHandle output_col_dim = c->MakeDim(num_cols);
-      if (lower_diag_index == upper_diag_index) {
-        TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 1,
-                                         output_row_dim, &output_shape));
-        TF_RETURN_IF_ERROR(c->Concatenate(
-            output_shape, c->Vector(output_col_dim), &output_shape));
-      } else {
-        TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, input_rank - 2,
-                                         output_row_dim, &output_shape));
-        TF_RETURN_IF_ERROR(c->ReplaceDim(output_shape, input_rank - 1,
-                                         output_col_dim, &output_shape));
-      }
-      c->set_output(0, output_shape);
-      return Status::OK();
-    });
+REGISTER_OP("MatrixDiagV3")
+    .Input("diagonal: T")
+    .Input("k: int32")
+    .Input("num_rows: int32")
+    .Input("num_cols: int32")
+    .Input("padding_value: T")
+    .Output("output: T")
+    .Attr("T: type")
+    .Attr(
+        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
+        "'RIGHT_LEFT'")
+    .SetShapeFn(shape_inference::MatrixDiagV2Shape);
 
 // --------------------------------------------------------------------------
 REGISTER_OP("MatrixSetDiag")
@@ -995,84 +885,25 @@ REGISTER_OP("MatrixSetDiag")
       c->set_output(0, output);
       return Status::OK();
     });
+
 REGISTER_OP("MatrixSetDiagV2")
     .Input("input: T")
     .Input("diagonal: T")
     .Input("k: int32")
     .Output("output: T")
     .Attr("T: type")
-    .SetShapeFn([](InferenceContext* c) {
-      ShapeHandle input_shape, diag_shape, diag_index_shape;
-      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
-      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &diag_shape));
-      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &diag_index_shape));
+    .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
 
-      int32 lower_diag_index = 0;
-      int32 upper_diag_index = 0;
-      bool diag_index_known = false;
-      const Tensor* diag_index_tensor = c->input_tensor(2);
-      if (diag_index_tensor != nullptr && c->FullyDefined(diag_index_shape)) {
-        diag_index_known = true;
-        TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor,
-                                         &lower_diag_index, &upper_diag_index));
-        if (lower_diag_index > upper_diag_index) {
-          return errors::InvalidArgument(
-              "lower_diag_index is greater than upper_diag_index");
-        }
-      }
-
-      // Do more checks when input rank is known.
-      if (c->RankKnown(input_shape)) {
-        int32 input_rank = c->Rank(input_shape);
-
-        // If diag_index is set, we know the exact rank of diagonal.
-        if (diag_index_known) {
-          TF_RETURN_IF_ERROR(c->WithRank(c->input(1),
-                                         (lower_diag_index == upper_diag_index)
-                                             ? input_rank - 1
-                                             : input_rank,
-                                         &diag_shape));
-        } else {
-          TF_RETURN_IF_ERROR(
-              c->WithRankAtLeast(c->input(1), input_rank - 1, &diag_shape));
-          TF_RETURN_IF_ERROR(
-              c->WithRankAtMost(c->input(1), input_rank, &diag_shape));
-        }
-
-        // Validates lower_diag_index and upper_diag_index.
-        const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
-        const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
-        if (num_rows != InferenceContext::kUnknownDim &&
-            num_cols != InferenceContext::kUnknownDim) {
-          if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
-              (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
-            return errors::InvalidArgument("lower_diag_index is out of bound.");
-          }
-          if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
-              (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
-            return errors::InvalidArgument("upper_diag_index is out of bound.");
-          }
-        }
-      }
-
-      ShapeHandle output_shape = input_shape;
-      if (c->RankKnown(diag_shape) && !c->FullyDefined(input_shape)) {
-        // Try to infer parts of shape from diag.
-        ShapeHandle diag_prefix;
-        TF_RETURN_IF_ERROR(c->Subshape(
-            diag_shape, 0, (lower_diag_index == upper_diag_index) ? -1 : -2,
-            &diag_prefix));
-
-        // The inner matrices can be rectangular, so we can't pinpoint their
-        // exact height and width by just lower_diag_index, upper_diag_index,
-        // and the longest length of given diagonals.
-        TF_RETURN_IF_ERROR(
-            c->Concatenate(diag_prefix, c->UnknownShapeOfRank(2), &diag_shape));
-        TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape));
-      }
-      c->set_output(0, output_shape);
-      return Status::OK();
-    });
+REGISTER_OP("MatrixSetDiagV3")
+    .Input("input: T")
+    .Input("diagonal: T")
+    .Input("k: int32")
+    .Output("output: T")
+    .Attr("T: type")
+    .Attr(
+        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
+        "'RIGHT_LEFT'")
+    .SetShapeFn(shape_inference::MatrixSetDiagV2Shape);
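The same alignment semantics apply when writing a band back into a matrix. A small sketch, not part of the patch, assuming an eager TF 2.x build where tf.linalg.set_diag already forwards k and align to MatrixSetDiagV3:

import numpy as np
import tensorflow as tf

mat = np.zeros((3, 3), dtype=np.int32)
band = np.array([[0, 2, 6],    # k = 1, right-aligned (leading 0 is padding)
                 [1, 5, 9],    # k = 0
                 [4, 8, 0]],   # k = -1, left-aligned (trailing 0 is padding)
                dtype=np.int32)
out = tf.linalg.set_diag(mat, band, k=(-1, 1), align="RIGHT_LEFT")
print(np.array(out))
# [[1 2 0]
#  [4 5 6]
#  [0 8 9]]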
 
 // --------------------------------------------------------------------------
 REGISTER_OP("MatrixDiagPart")
@@ -1105,58 +936,18 @@ REGISTER_OP("MatrixDiagPartV2")
     .Input("padding_value: T")
     .Output("diagonal: T")
     .Attr("T: type")
-    .SetShapeFn([](InferenceContext* c) {
-      ShapeHandle input_shape, diag_index_shape, unused_shape;
-      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
-      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &diag_index_shape));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_shape));
+    .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
 
-      const Tensor* diag_index_tensor = c->input_tensor(1);
-      if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) ||
-          diag_index_tensor == nullptr) {
-        c->set_output(0, c->UnknownShape());
-        return Status::OK();
-      }
-      int32 lower_diag_index = 0;
-      int32 upper_diag_index = 0;
-      TF_RETURN_IF_ERROR(ReadDiagIndex(c, diag_index_tensor, &lower_diag_index,
-                                       &upper_diag_index));
-      if (lower_diag_index > upper_diag_index) {
-        return errors::InvalidArgument(
-            "lower_diag_index is greater than upper_diag_index");
-      }
-
-      // Validates lower_diag_index and upper_diag_index.
-      const int32 input_rank = c->Rank(input_shape);
-      const int32 num_rows = c->Value(c->Dim(input_shape, input_rank - 2));
-      const int32 num_cols = c->Value(c->Dim(input_shape, input_rank - 1));
-      if (num_rows != InferenceContext::kUnknownDim &&
-          num_cols != InferenceContext::kUnknownDim) {
-        if (lower_diag_index != 0 &&  // For when num_rows or num_cols == 0.
-            (-num_rows >= lower_diag_index || lower_diag_index >= num_cols)) {
-          return errors::InvalidArgument("lower_diag_index is out of bound.");
-        }
-        if (upper_diag_index != 0 &&  // For when num_rows or num_cols == 0.
-            (-num_rows >= upper_diag_index || upper_diag_index >= num_cols)) {
-          return errors::InvalidArgument("upper_diag_index is out of bound.");
-        }
-      }
-
-      const int32 max_diag_len =
-          std::min(num_rows + std::min(upper_diag_index, 0),
-                   num_cols - std::max(lower_diag_index, 0));
-      std::vector<DimensionHandle> dims;
-      dims.reserve(input_rank - 2);
-      for (int i = 0; i < input_rank - 2; ++i) {
-        dims.push_back(c->Dim(input_shape, i));
-      }
-      if (lower_diag_index < upper_diag_index) {
-        dims.push_back(c->MakeDim(upper_diag_index - lower_diag_index + 1));
-      }
-      dims.push_back(c->MakeDim(max_diag_len));
-      c->set_output(0, c->MakeShape(dims));
-      return Status::OK();
-    });
+REGISTER_OP("MatrixDiagPartV3")
+    .Input("input: T")
+    .Input("k: int32")
+    .Input("padding_value: T")
+    .Output("diagonal: T")
+    .Attr("T: type")
+    .Attr(
+        "align: {'LEFT_RIGHT', 'RIGHT_LEFT', 'LEFT_LEFT', 'RIGHT_RIGHT'} = "
+        "'RIGHT_LEFT'")
+    .SetShapeFn(shape_inference::MatrixDiagPartV2Shape);
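Reading a band out of a matrix shows the alignment most clearly: under the default RIGHT_LEFT, superdiagonals are padded on the left and subdiagonals on the right. A sketch, not part of the patch, assuming an eager TF 2.x build where tf.linalg.diag_part already forwards k, padding_value and align to MatrixDiagPartV3:

import numpy as np
import tensorflow as tf

mat = np.array([[1, 2, 0],
                [4, 5, 6],
                [0, 8, 9]])
band = tf.linalg.diag_part(mat, k=(-1, 1), align="RIGHT_LEFT")
print(np.array(band))
# [[0 2 6]    k = 1: right-aligned, leading 0 is padding
#  [1 5 9]    k = 0
#  [4 8 0]]   k = -1: left-aligned, trailing 0 is padding
# With align="LEFT_LEFT" (the V2 packing), the first row would be [2 6 0].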
 
 // --------------------------------------------------------------------------
 REGISTER_OP("MatrixBandPart")
diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD
index 0a579b48d47..82e362f514e 100644
--- a/tensorflow/lite/toco/BUILD
+++ b/tensorflow/lite/toco/BUILD
@@ -180,8 +180,8 @@ cc_library(
     name = "graph_transformations",
     srcs = [
         "graph_transformations/convert_expanddims_to_reshape.cc",
-        "graph_transformations/convert_matrix_diag_v2_to_v1.cc",
-        "graph_transformations/convert_matrix_set_diag_v2_to_v1.cc",
+        "graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc",
+        "graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc",
         "graph_transformations/convert_pure_conv_to_depthwise.cc",
         "graph_transformations/convert_reorder_axes.cc",
         "graph_transformations/convert_squeeze_to_reshape.cc",
diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
similarity index 88%
rename from tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_to_v1.cc
rename to tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
index 5037b05f6fd..c0dbea05a26 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_to_v1.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc
@@ -20,13 +20,16 @@ limitations under the License.
 
 namespace toco {
 
-::tensorflow::Status ConvertMatrixDiagV2ToV1::Run(Model* model,
-                                                  std::size_t op_index,
-                                                  bool* modified) {
+// V3 differs from V2 only in its extra attribute (align). The attribute does
+// not affect V1, so we do not need to keep track of it here.
+::tensorflow::Status ConvertMatrixDiagV2OrV3ToV1::Run(Model* model,
+                                                      std::size_t op_index,
+                                                      bool* modified) {
   *modified = false;
   auto it = model->operators.begin() + op_index;
   const auto* op = it->get();
-  if (op->type != OperatorType::kMatrixDiagV2) {
+  if (op->type != OperatorType::kMatrixDiagV2 &&
+      op->type != OperatorType::kMatrixDiagV3) {
     return ::tensorflow::Status::OK();
   }
 
diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
similarity index 83%
rename from tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc
rename to tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
index 61288f626b6..a3801398d71 100644
--- a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_to_v1.cc
+++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc
@@ -26,13 +26,16 @@ limitations under the License.
 
 namespace toco {
 
-::tensorflow::Status ConvertMatrixSetDiagV2ToV1::Run(Model* model,
-                                                     std::size_t op_index,
-                                                     bool* modified) {
+// V3 differs from V2 only in its extra attribute (align). The attribute does
+// not affect V1, so we do not need to keep track of it here.
+::tensorflow::Status ConvertMatrixSetDiagV2OrV3ToV1::Run(Model* model,
+                                                         std::size_t op_index,
+                                                         bool* modified) {
   *modified = false;
   auto it = model->operators.begin() + op_index;
   const auto* op = it->get();
-  if (op->type != OperatorType::kMatrixSetDiagV2) {
+  if (op->type != OperatorType::kMatrixSetDiagV2 &&
+      op->type != OperatorType::kMatrixSetDiagV3) {
     return ::tensorflow::Status::OK();
   }
 
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
index 4a1334f41f1..0b765b1f507 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
@@ -123,8 +123,8 @@ inline void RunGraphTransformations(
 
 // List of all graph transformations
 DECLARE_GRAPH_TRANSFORMATION(ConvertExpandDimsToReshape)
-DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2ToV1)
-DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2ToV1)
+DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixSetDiagV2OrV3ToV1)
+DECLARE_GRAPH_TRANSFORMATION(ConvertMatrixDiagV2OrV3ToV1)
 DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
 DECLARE_GRAPH_TRANSFORMATION(ConvertReorderAxes)
 DECLARE_GRAPH_TRANSFORMATION(ConvertSqueezeToReshape)
diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index d7a56e6d4b2..4b1a6fab607 100644
--- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -2434,6 +2434,14 @@ void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
       // MatrixDiagV2 operators are converted to MatrixDiag, after which their
       // shapes are propagated.
       break;
+    case OperatorType::kMatrixDiagV3:
+      // MatrixDiagV3 operators are converted to MatrixDiag, after which their
+      // shapes are propagated.
+      break;
+    case OperatorType::kMatrixSetDiagV3:
+      // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
+      // their shapes are propagated.
+      break;
     default:
       // Unimplemented, another graph transformation should drop it.
       LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index fa8cc9799e1..dd7a9e3d835 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -2560,8 +2560,16 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
       {"MatMul", ConvertMatMulOperator},
       {"MatrixDiag", ConvertSimpleOperator<MatrixDiagOperator, 1, 1>},
       {"MatrixDiagV2", ConvertSimpleOperator<MatrixDiagV2Operator, 5, 1>},
+      // `MatrixDiagV3` has an `align` attribute. However, Toco only converts
+      // `MatrixDiagV3` to `MatrixDiag` with default `k, num_rows, num_cols,
+      // padding_value` inputs. In this case, `align` can be ignored.
+      {"MatrixDiagV3", ConvertSimpleOperator<MatrixDiagV3Operator, 5, 1>},
       {"MatrixSetDiag", ConvertSimpleOperator<MatrixSetDiagOperator, 2, 1>},
       {"MatrixSetDiagV2", ConvertSimpleOperator<MatrixSetDiagV2Operator, 3, 1>},
+      // `MatrixSetDiagV3` has an `align` attribute. However, Toco only converts
+      // `MatrixSetDiagV3` to `MatrixSetDiag` with default `k` inputs. In this
+      // case, `align` can be ignored.
+      {"MatrixSetDiagV3", ConvertSimpleOperator<MatrixSetDiagV3Operator, 3, 1>},
       {"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
       {"MaxPool", ConvertMaxPoolOperator},
       {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2, 1>},
diff --git a/tensorflow/lite/toco/model.h b/tensorflow/lite/toco/model.h
index ef717ee4e18..d8f4b73115c 100644
--- a/tensorflow/lite/toco/model.h
+++ b/tensorflow/lite/toco/model.h
@@ -175,7 +175,9 @@ enum class OperatorType : uint8 {
   kMatrixDiag,
   kMatrixSetDiag,
   kMatrixDiagV2,
-  kMatrixSetDiagV2
+  kMatrixSetDiagV2,
+  kMatrixDiagV3,
+  kMatrixSetDiagV3
 };
 
 // Helper to deal with TensorFlow arrays using a different ordering of
@@ -2122,12 +2124,24 @@ struct MatrixDiagOperator : Operator {
 
 // Matrix Diag Operator V2:
 // Construct a batched diagonal tensor with given batched diagonal values.
-// Not fully supported, constains 4 extra inputs compared to MatrixDiag, support
-// default parameters settings which performs the same as MatrixDiag
+// Not fully supported; contains 4 extra inputs compared to MatrixDiag.
+// Behaves like MatrixDiag when default parameters are used.
 struct MatrixDiagV2Operator : Operator {
   MatrixDiagV2Operator() : Operator(OperatorType::kMatrixDiagV2) {}
 };
 
+// Matrix Diag Operator V3:
+// Construct a batched diagonal tensor with given batched diagonal values.
+// Not fully supported; contains 4 extra inputs and an extra attribute (align)
+// compared to MatrixDiag. Behaves like MatrixDiag with default parameters.
+// V3 differs from V2 only in the extra attribute (align), which controls the
+// alignment of diagonals in the band matrix (compact) format.
+// The default alignment in V2 conflicts with the default in V3, so V2 is
+// skipped. (It has never been, and should never be, exposed in the public API.)
+struct MatrixDiagV3Operator : Operator {
+  MatrixDiagV3Operator() : Operator(OperatorType::kMatrixDiagV3) {}
+};
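A quick check, not part of the patch, of the "behaves like MatrixDiag when default parameters are used" claim that makes the V2/V3-to-V1 graph transformation safe; it assumes an eager TF 2.x build (tf.linalg.diag dispatches to whichever MatrixDiag version the build supports, and the result is the same when all optional parameters are left at their defaults).

import numpy as np
import tensorflow as tf

v = np.array([1.0, 2.0, 3.0])
# With every optional parameter left at its default, the V2/V3 kernels produce
# exactly the classic MatrixDiag result (np.diag here), so the extra inputs and
# the align attribute can simply be dropped during conversion.
assert np.array_equal(np.array(tf.linalg.diag(v)), np.diag(v))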
+
 // Matrix Set Diag Operator:
 // Construct a batched diagonal tensor with given input and diagonal values.
 // Input is a rank (k+1) tensor of values.
@@ -2140,12 +2154,24 @@ struct MatrixSetDiagOperator : Operator {
 
 // Matrix Set Diag Operator V2:
 // Construct a batched diagonal tensor with given input and diagonal values.
-// Not fully supported, constains 1 extra inputs compared to MatrixSetDiag,
-// support default parameters settings which performs the same as MatrixSetDiag
+// Not fully supported; contains 1 extra input compared to MatrixSetDiag.
+// Behaves like MatrixSetDiag when default parameters are used.
 struct MatrixSetDiagV2Operator : Operator {
   MatrixSetDiagV2Operator() : Operator(OperatorType::kMatrixSetDiagV2) {}
 };
 
+// Matrix Set Diag Operator V3:
+// Construct a batched diagonal tensor with given input and diagonal values.
+// Not fully supported; contains 1 extra input and an extra attribute (align)
+// compared to MatrixSetDiag. Behaves like it when defaults are used.
+// V3 differs from V2 only in the extra attribute (align), which controls the
+// alignment of diagonals in the band matrix (compact) format.
+// The default alignment in V2 conflicts with the default in V3, so V2 is
+// skipped. (It has never been, and should never be, exposed in the public API.)
+struct MatrixSetDiagV3Operator : Operator {
+  MatrixSetDiagV3Operator() : Operator(OperatorType::kMatrixSetDiagV3) {}
+};
+
 // Alloc's are used for transient arrays only. An Alloc specifies which interval
 // of the "transient_data" workspace buffer passed to inference functions, is to
 // be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc
index 5a448131820..f96c6b83025 100644
--- a/tensorflow/lite/toco/toco_tooling.cc
+++ b/tensorflow/lite/toco/toco_tooling.cc
@@ -54,8 +54,8 @@ void MakeGeneralGraphTransformationsSet(
     GraphTransformationsSet* transformations) {
   CHECK(transformations->empty());
   transformations->Add(new ConvertExpandDimsToReshape);
-  transformations->Add(new ConvertMatrixDiagV2ToV1);
-  transformations->Add(new ConvertMatrixSetDiagV2ToV1);
+  transformations->Add(new ConvertMatrixDiagV2OrV3ToV1);
+  transformations->Add(new ConvertMatrixSetDiagV2OrV3ToV1);
   transformations->Add(new ConvertSqueezeToReshape);
   transformations->Add(new ConvertTrivialAddNToAdd);
   transformations->Add(new ConvertTrivialPackToReshape);
diff --git a/tensorflow/lite/toco/tooling_util.cc b/tensorflow/lite/toco/tooling_util.cc
index 418361cd0e7..ebcb17599b1 100644
--- a/tensorflow/lite/toco/tooling_util.cc
+++ b/tensorflow/lite/toco/tooling_util.cc
@@ -449,6 +449,8 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiag)
     HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV2)
     HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV2)
+    HANDLE_OPERATORTYPENAME_CASE(MatrixDiagV3)
+    HANDLE_OPERATORTYPENAME_CASE(MatrixSetDiagV3)
     default:
       LOG(FATAL) << "Unhandled op type";
 #undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index afa0850b853..df8409f09e6 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -33,8 +33,80 @@ from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
 
 
+# LINT.IfChange
+matrix_diag_v3_forward_compat_date = (2019, 12, 6)
+# LINT.ThenChange(
+#   //tensorflow/compiler/tests/matrix_diag_ops_test.py,
+#   //tensorflow/python/ops/array_ops.py,
+#   //tensorflow/python/ops/parallel_for/array_test.py
+# )
+
+
+default_v2_alignment = "LEFT_LEFT"
+alignment_list = ["RIGHT_LEFT", "LEFT_RIGHT", "LEFT_LEFT", "RIGHT_RIGHT"]
+
+
+def zip_to_first_list_length(a, b):
+  if len(b) > len(a):
+    return zip(a, b[:len(a)])
+  return zip(a, b + [None] * (len(a) - len(b)))
+
+
+def repack_diagonals(packed_diagonals,
+                     diag_index,
+                     num_rows,
+                     num_cols,
+                     align=None):
+  # The original test cases are LEFT_LEFT aligned.
+  if align == default_v2_alignment or align is None:
+    return packed_diagonals
+
+  align = align.split("_")
+  d_lower, d_upper = diag_index
+  batch_dims = packed_diagonals.ndim - (2 if d_lower < d_upper else 1)
+  max_diag_len = packed_diagonals.shape[-1]
+  index = (slice(None),) * batch_dims
+  repacked_diagonals = np.zeros_like(packed_diagonals)
+
+  # Aligns each diagonal row-by-row.
+  for diag_index in range(d_lower, d_upper + 1):
+    diag_len = min(num_rows + min(0, diag_index), num_cols - max(0, diag_index))
+    row_index = d_upper - diag_index
+    padding_len = max_diag_len - diag_len
+    left_align = (diag_index >= 0 and
+                  align[0] == "LEFT") or (diag_index <= 0 and
+                                          align[1] == "LEFT")
+    # Prepares index tuples.
+    extra_dim = tuple() if d_lower == d_upper else (row_index,)
+    packed_last_dim = (slice(None),) if left_align else (slice(0, diag_len, 1),)
+    repacked_last_dim = (slice(None),) if left_align else (slice(
+        padding_len, max_diag_len, 1),)
+    packed_index = index + extra_dim + packed_last_dim
+    repacked_index = index + extra_dim + repacked_last_dim
+
+    # Repacks the diagonal.
+    repacked_diagonals[repacked_index] = packed_diagonals[packed_index]
+  return repacked_diagonals
+
+
+def repack_diagonals_in_tests(tests, align=None):
+  # The original test cases are LEFT_LEFT aligned.
+  if align == default_v2_alignment or align is None:
+    return tests
+
+  new_tests = dict()
+  # Loops through each case.
+  for diag_index, (packed_diagonals, padded_diagonals) in tests.items():
+    num_rows, num_cols = padded_diagonals.shape[-2:]
+    repacked_diagonals = repack_diagonals(
+        packed_diagonals, diag_index, num_rows, num_cols, align=align)
+    new_tests[diag_index] = (repacked_diagonals, padded_diagonals)
+
+  return new_tests
+
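A usage sketch, not part of the patch, reusing the repack_diagonals helper defined above: it re-justifies each diagonal of a LEFT_LEFT-packed test case, shown here for the k = (-1, 1) band of a 3x3 matrix.

import numpy as np

packed = np.array([[2, 6, 0],    # k = 1, LEFT-aligned as in the original cases
                   [1, 5, 9],    # k = 0
                   [4, 8, 0]])   # k = -1
repacked = repack_diagonals(
    packed, diag_index=(-1, 1), num_rows=3, num_cols=3, align="RIGHT_LEFT")
# repacked == [[0, 2, 6],   k = 1 is now right-aligned
#              [1, 5, 9],
#              [4, 8, 0]]   subdiagonals keep their left alignment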
+
 # Test cases shared by MatrixDiagV2, MatrixDiagPartV2, and MatrixSetDiagV2.
-def square_cases():
+def square_cases(align=None):
   # pyformat: disable
   mat = np.array([[[1, 2, 3, 4, 5],
                    [6, 7, 8, 9, 1],
@@ -47,71 +119,72 @@ def square_cases():
                    [6, 7, 8, 9, 1],
                    [2, 3, 4, 5, 6]]])
   tests = dict()
-  tests[(-1, -1)] = (np.array([[6, 4, 1, 7],
-                               [5, 2, 8, 5]]),
-                     np.array([[[0, 0, 0, 0, 0],
-                                [6, 0, 0, 0, 0],
-                                [0, 4, 0, 0, 0],
-                                [0, 0, 1, 0, 0],
-                                [0, 0, 0, 7, 0]],
-                               [[0, 0, 0, 0, 0],
-                                [5, 0, 0, 0, 0],
-                                [0, 2, 0, 0, 0],
-                                [0, 0, 8, 0, 0],
-                                [0, 0, 0, 5, 0]]]))
-  tests[(-4, -3)] = (np.array([[[8, 5],
-                                [4, 0]],
-                               [[6, 3],
-                                [2, 0]]]),
-                     np.array([[[0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0],
-                                [8, 0, 0, 0, 0],
-                                [4, 5, 0, 0, 0]],
-                               [[0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0],
-                                [6, 0, 0, 0, 0],
-                                [2, 3, 0, 0, 0]]]))
-  tests[(-2, 1)] = (np.array([[[2, 8, 6, 3, 0],
-                               [1, 7, 5, 2, 8],
-                               [6, 4, 1, 7, 0],
-                               [3, 9, 6, 0, 0]],
-                              [[1, 7, 4, 1, 0],
-                               [9, 6, 3, 9, 6],
-                               [5, 2, 8, 5, 0],
-                               [1, 7, 4, 0, 0]]]),
-                    np.array([[[1, 2, 0, 0, 0],
-                               [6, 7, 8, 0, 0],
-                               [3, 4, 5, 6, 0],
-                               [0, 9, 1, 2, 3],
-                               [0, 0, 6, 7, 8]],
-                              [[9, 1, 0, 0, 0],
-                               [5, 6, 7, 0, 0],
-                               [1, 2, 3, 4, 0],
-                               [0, 7, 8, 9, 1],
-                               [0, 0, 4, 5, 6]]]))
-  tests[(2, 4)] = (np.array([[[5, 0, 0],
-                              [4, 1, 0],
-                              [3, 9, 7]],
-                             [[4, 0, 0],
-                              [3, 9, 0],
-                              [2, 8, 5]]]),
-                   np.array([[[0, 0, 3, 4, 5],
-                              [0, 0, 0, 9, 1],
-                              [0, 0, 0, 0, 7],
+  # tests[d_lower, d_upper] = (packed_diagonals, padded_diagonals)
+  tests[-1, -1] = (np.array([[6, 4, 1, 7],
+                             [5, 2, 8, 5]]),
+                   np.array([[[0, 0, 0, 0, 0],
+                              [6, 0, 0, 0, 0],
+                              [0, 4, 0, 0, 0],
+                              [0, 0, 1, 0, 0],
+                              [0, 0, 0, 7, 0]],
+                             [[0, 0, 0, 0, 0],
+                              [5, 0, 0, 0, 0],
+                              [0, 2, 0, 0, 0],
+                              [0, 0, 8, 0, 0],
+                              [0, 0, 0, 5, 0]]]))
+  tests[-4, -3] = (np.array([[[8, 5],
+                              [4, 0]],
+                             [[6, 3],
+                              [2, 0]]]),
+                   np.array([[[0, 0, 0, 0, 0],
                               [0, 0, 0, 0, 0],
-                              [0, 0, 0, 0, 0]],
-                             [[0, 0, 2, 3, 4],
-                              [0, 0, 0, 8, 9],
-                              [0, 0, 0, 0, 5],
                               [0, 0, 0, 0, 0],
-                              [0, 0, 0, 0, 0]]]))
+                              [8, 0, 0, 0, 0],
+                              [4, 5, 0, 0, 0]],
+                             [[0, 0, 0, 0, 0],
+                              [0, 0, 0, 0, 0],
+                              [0, 0, 0, 0, 0],
+                              [6, 0, 0, 0, 0],
+                              [2, 3, 0, 0, 0]]]))
+  tests[-2, 1] = (np.array([[[2, 8, 6, 3, 0],
+                             [1, 7, 5, 2, 8],
+                             [6, 4, 1, 7, 0],
+                             [3, 9, 6, 0, 0]],
+                            [[1, 7, 4, 1, 0],
+                             [9, 6, 3, 9, 6],
+                             [5, 2, 8, 5, 0],
+                             [1, 7, 4, 0, 0]]]),
+                  np.array([[[1, 2, 0, 0, 0],
+                             [6, 7, 8, 0, 0],
+                             [3, 4, 5, 6, 0],
+                             [0, 9, 1, 2, 3],
+                             [0, 0, 6, 7, 8]],
+                            [[9, 1, 0, 0, 0],
+                             [5, 6, 7, 0, 0],
+                             [1, 2, 3, 4, 0],
+                             [0, 7, 8, 9, 1],
+                             [0, 0, 4, 5, 6]]]))
+  tests[2, 4] = (np.array([[[5, 0, 0],
+                            [4, 1, 0],
+                            [3, 9, 7]],
+                           [[4, 0, 0],
+                            [3, 9, 0],
+                            [2, 8, 5]]]),
+                 np.array([[[0, 0, 3, 4, 5],
+                            [0, 0, 0, 9, 1],
+                            [0, 0, 0, 0, 7],
+                            [0, 0, 0, 0, 0],
+                            [0, 0, 0, 0, 0]],
+                           [[0, 0, 2, 3, 4],
+                            [0, 0, 0, 8, 9],
+                            [0, 0, 0, 0, 5],
+                            [0, 0, 0, 0, 0],
+                            [0, 0, 0, 0, 0]]]))
   # pyformat: enable
-  return (mat, tests)
+  return (mat, repack_diagonals_in_tests(tests, align))
 
 
-def tall_cases():
+def tall_cases(align=None):
   # pyformat: disable
   mat = np.array([[[1, 2, 3],
                    [4, 5, 6],
@@ -124,81 +197,81 @@ def tall_cases():
                    [7, 8, 9],
                    [9, 8, 7]]])
   tests = dict()
-  tests[(0, 0)] = (np.array([[1, 5, 9],
-                             [3, 2, 6]]),
-                   np.array([[[1, 0, 0],
-                              [0, 5, 0],
-                              [0, 0, 9],
-                              [0, 0, 0]],
-                             [[3, 0, 0],
-                              [0, 2, 0],
-                              [0, 0, 6],
-                              [0, 0, 0]]]))
-  tests[(-4, -3)] = (np.array([[[9, 5],
-                                [6, 0]],
-                               [[7, 8],
-                                [9, 0]]]),
-                     np.array([[[0, 0, 0],
-                                [0, 0, 0],
-                                [0, 0, 0],
-                                [9, 0, 0],
-                                [6, 5, 0]],
-                               [[0, 0, 0],
-                                [0, 0, 0],
-                                [0, 0, 0],
-                                [7, 0, 0],
-                                [9, 8, 0]]]))
-  tests[(-2, -1)] = (np.array([[[4, 8, 7],
-                                [7, 8, 4]],
-                               [[1, 5, 9],
-                                [4, 8, 7]]]),
-                     np.array([[[0, 0, 0],
-                                [4, 0, 0],
-                                [7, 8, 0],
-                                [0, 8, 7],
-                                [0, 0, 4]],
-                               [[0, 0, 0],
-                                [1, 0, 0],
-                                [4, 5, 0],
-                                [0, 8, 9],
-                                [0, 0, 7]]]))
-  tests[(-2, 1)] = (np.array([[[2, 6, 0],
-                               [1, 5, 9],
-                               [4, 8, 7],
-                               [7, 8, 4]],
-                              [[2, 3, 0],
-                               [3, 2, 6],
-                               [1, 5, 9],
-                               [4, 8, 7]]]),
-                    np.array([[[1, 2, 0],
-                               [4, 5, 6],
-                               [7, 8, 9],
-                               [0, 8, 7],
-                               [0, 0, 4]],
-                              [[3, 2, 0],
-                               [1, 2, 3],
-                               [4, 5, 6],
-                               [0, 8, 9],
-                               [0, 0, 7]]]))
-  tests[(1, 2)] = (np.array([[[3, 0],
-                              [2, 6]],
-                             [[1, 0],
-                              [2, 3]]]),
-                   np.array([[[0, 2, 3],
-                              [0, 0, 6],
+  tests[0, 0] = (np.array([[1, 5, 9],
+                           [3, 2, 6]]),
+                 np.array([[[1, 0, 0],
+                            [0, 5, 0],
+                            [0, 0, 9],
+                            [0, 0, 0]],
+                           [[3, 0, 0],
+                            [0, 2, 0],
+                            [0, 0, 6],
+                            [0, 0, 0]]]))
+  tests[-4, -3] = (np.array([[[9, 5],
+                              [6, 0]],
+                             [[7, 8],
+                              [9, 0]]]),
+                   np.array([[[0, 0, 0],
                               [0, 0, 0],
                               [0, 0, 0],
-                              [0, 0, 0]],
-                             [[0, 2, 1],
-                              [0, 0, 3],
+                              [9, 0, 0],
+                              [6, 5, 0]],
+                             [[0, 0, 0],
                               [0, 0, 0],
                               [0, 0, 0],
-                              [0, 0, 0]]]))
+                              [7, 0, 0],
+                              [9, 8, 0]]]))
+  tests[-2, -1] = (np.array([[[4, 8, 7],
+                              [7, 8, 4]],
+                             [[1, 5, 9],
+                              [4, 8, 7]]]),
+                   np.array([[[0, 0, 0],
+                              [4, 0, 0],
+                              [7, 8, 0],
+                              [0, 8, 7],
+                              [0, 0, 4]],
+                             [[0, 0, 0],
+                              [1, 0, 0],
+                              [4, 5, 0],
+                              [0, 8, 9],
+                              [0, 0, 7]]]))
+  tests[-2, 1] = (np.array([[[2, 6, 0],
+                             [1, 5, 9],
+                             [4, 8, 7],
+                             [7, 8, 4]],
+                            [[2, 3, 0],
+                             [3, 2, 6],
+                             [1, 5, 9],
+                             [4, 8, 7]]]),
+                  np.array([[[1, 2, 0],
+                             [4, 5, 6],
+                             [7, 8, 9],
+                             [0, 8, 7],
+                             [0, 0, 4]],
+                            [[3, 2, 0],
+                             [1, 2, 3],
+                             [4, 5, 6],
+                             [0, 8, 9],
+                             [0, 0, 7]]]))
+  tests[1, 2] = (np.array([[[3, 0],
+                            [2, 6]],
+                           [[1, 0],
+                            [2, 3]]]),
+                 np.array([[[0, 2, 3],
+                            [0, 0, 6],
+                            [0, 0, 0],
+                            [0, 0, 0],
+                            [0, 0, 0]],
+                           [[0, 2, 1],
+                            [0, 0, 3],
+                            [0, 0, 0],
+                            [0, 0, 0],
+                            [0, 0, 0]]]))
   # pyformat: enable
-  return (mat, tests)
+  return (mat, repack_diagonals_in_tests(tests, align))
 
 
-def fat_cases():
+def fat_cases(align=None):
   # pyformat: disable
   mat = np.array([[[1, 2, 3, 4],
                    [5, 6, 7, 8],
@@ -207,59 +280,63 @@ def fat_cases():
                    [8, 9, 1, 2],
                    [3, 4, 5, 6]]])
   tests = dict()
-  tests[(2, 2)] = (np.array([[3, 8],
-                             [6, 2]]),
-                   np.array([[[0, 0, 3, 0],
-                              [0, 0, 0, 8],
-                              [0, 0, 0, 0]],
-                             [[0, 0, 6, 0],
-                              [0, 0, 0, 2],
-                              [0, 0, 0, 0]]]))
-  tests[(-2, 0)] = (np.array([[[1, 6, 2],
-                               [5, 1, 0],
-                               [9, 0, 0]],
-                              [[4, 9, 5],
-                               [8, 4, 0],
-                               [3, 0, 0]]]),
-                    np.array([[[1, 0, 0, 0],
-                               [5, 6, 0, 0],
-                               [9, 1, 2, 0]],
-                              [[4, 0, 0, 0],
-                               [8, 9, 0, 0],
-                               [3, 4, 5, 0]]]))
-  tests[(-1, 1)] = (np.array([[[2, 7, 3],
-                               [1, 6, 2],
-                               [5, 1, 0]],
-                              [[5, 1, 6],
-                               [4, 9, 5],
-                               [8, 4, 0]]]),
-                    np.array([[[1, 2, 0, 0],
-                               [5, 6, 7, 0],
-                               [0, 1, 2, 3]],
-                              [[4, 5, 0, 0],
-                               [8, 9, 1, 0],
-                               [0, 4, 5, 6]]]))
-  tests[(0, 3)] = (np.array([[[4, 0, 0],
-                              [3, 8, 0],
-                              [2, 7, 3],
-                              [1, 6, 2]],
-                             [[7, 0, 0],
-                              [6, 2, 0],
-                              [5, 1, 6],
-                              [4, 9, 5]]]),
-                   np.array([[[1, 2, 3, 4],
-                              [0, 6, 7, 8],
-                              [0, 0, 2, 3]],
-                             [[4, 5, 6, 7],
-                              [0, 9, 1, 2],
-                              [0, 0, 5, 6]]]))
+  tests[2, 2] = (np.array([[3, 8],
+                           [6, 2]]),
+                 np.array([[[0, 0, 3, 0],
+                            [0, 0, 0, 8],
+                            [0, 0, 0, 0]],
+                           [[0, 0, 6, 0],
+                            [0, 0, 0, 2],
+                            [0, 0, 0, 0]]]))
+  tests[-2, 0] = (np.array([[[1, 6, 2],
+                             [5, 1, 0],
+                             [9, 0, 0]],
+                            [[4, 9, 5],
+                             [8, 4, 0],
+                             [3, 0, 0]]]),
+                  np.array([[[1, 0, 0, 0],
+                             [5, 6, 0, 0],
+                             [9, 1, 2, 0]],
+                            [[4, 0, 0, 0],
+                             [8, 9, 0, 0],
+                             [3, 4, 5, 0]]]))
+  tests[-1, 1] = (np.array([[[2, 7, 3],
+                             [1, 6, 2],
+                             [5, 1, 0]],
+                            [[5, 1, 6],
+                             [4, 9, 5],
+                             [8, 4, 0]]]),
+                  np.array([[[1, 2, 0, 0],
+                             [5, 6, 7, 0],
+                             [0, 1, 2, 3]],
+                            [[4, 5, 0, 0],
+                             [8, 9, 1, 0],
+                             [0, 4, 5, 6]]]))
+  tests[0, 3] = (np.array([[[4, 0, 0],
+                            [3, 8, 0],
+                            [2, 7, 3],
+                            [1, 6, 2]],
+                           [[7, 0, 0],
+                            [6, 2, 0],
+                            [5, 1, 6],
+                            [4, 9, 5]]]),
+                 np.array([[[1, 2, 3, 4],
+                            [0, 6, 7, 8],
+                            [0, 0, 2, 3]],
+                           [[4, 5, 6, 7],
+                            [0, 9, 1, 2],
+                            [0, 0, 5, 6]]]))
   # pyformat: enable
-  return (mat, tests)
+  return (mat, repack_diagonals_in_tests(tests, align))
+
+
+def all_tests(align=None):
+  return [square_cases(align), tall_cases(align), fat_cases(align)]
 
 
 class MatrixDiagTest(test.TestCase):
 
-  def _moreCases(self):
+  def _moreCases(self, align=None):
     # Diagonal bands.
     # pyformat: disable
     vecs = np.array([[[1, 2, 3, 4],  # Input shape: (2, 3, 4)
@@ -269,41 +346,41 @@ class MatrixDiagTest(test.TestCase):
                       [1, 2, 3, 4],
                       [5, 6, 7, 8]]])
     tests = dict()
-    tests[(-3, -1)] = (vecs,
-                       np.array([[[0, 0, 0, 0, 0],
-                                  [1, 0, 0, 0, 0],
-                                  [5, 2, 0, 0, 0],
-                                  [9, 6, 3, 0, 0],
-                                  [0, 8, 7, 4, 0]],
-                                 [[0, 0, 0, 0, 0],
-                                  [5, 0, 0, 0, 0],
-                                  [1, 4, 0, 0, 0],
-                                  [5, 2, 3, 0, 0],
-                                  [0, 6, 3, 2, 0]]]))
-    tests[(-1, 1)] = (vecs,
-                      np.array([[[5, 1, 0, 0],
-                                 [9, 6, 2, 0],
-                                 [0, 8, 7, 3],
-                                 [0, 0, 7, 8]],
-                                [[1, 5, 0, 0],
-                                 [5, 2, 4, 0],
-                                 [0, 6, 3, 3],
-                                 [0, 0, 7, 4]]]))
-    tests[(2, 4)] = (vecs,
-                     np.array([[[0, 0, 9, 5, 1, 0],
-                                [0, 0, 0, 8, 6, 2],
-                                [0, 0, 0, 0, 7, 7],
-                                [0, 0, 0, 0, 0, 6],
-                                [0, 0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0, 0]],
-                               [[0, 0, 5, 1, 5, 0],
-                                [0, 0, 0, 6, 2, 4],
-                                [0, 0, 0, 0, 7, 3],
-                                [0, 0, 0, 0, 0, 8],
-                                [0, 0, 0, 0, 0, 0],
-                                [0, 0, 0, 0, 0, 0]]]))
+    tests[-3, -1] = (vecs,
+                     np.array([[[0, 0, 0, 0, 0],
+                                [1, 0, 0, 0, 0],
+                                [5, 2, 0, 0, 0],
+                                [9, 6, 3, 0, 0],
+                                [0, 8, 7, 4, 0]],
+                               [[0, 0, 0, 0, 0],
+                                [5, 0, 0, 0, 0],
+                                [1, 4, 0, 0, 0],
+                                [5, 2, 3, 0, 0],
+                                [0, 6, 3, 2, 0]]]))
+    tests[-1, 1] = (vecs,
+                    np.array([[[5, 1, 0, 0],
+                               [9, 6, 2, 0],
+                               [0, 8, 7, 3],
+                               [0, 0, 7, 8]],
+                              [[1, 5, 0, 0],
+                               [5, 2, 4, 0],
+                               [0, 6, 3, 3],
+                               [0, 0, 7, 4]]]))
+    tests[2, 4] = (vecs,
+                   np.array([[[0, 0, 9, 5, 1, 0],
+                              [0, 0, 0, 8, 6, 2],
+                              [0, 0, 0, 0, 7, 7],
+                              [0, 0, 0, 0, 0, 6],
+                              [0, 0, 0, 0, 0, 0],
+                              [0, 0, 0, 0, 0, 0]],
+                             [[0, 0, 5, 1, 5, 0],
+                              [0, 0, 0, 6, 2, 4],
+                              [0, 0, 0, 0, 7, 3],
+                              [0, 0, 0, 0, 0, 8],
+                              [0, 0, 0, 0, 0, 0],
+                              [0, 0, 0, 0, 0, 0]]]))
     # pyformat: enable
-    return (None, tests)
+    return (None, repack_diagonals_in_tests(tests, align))
 
   @test_util.run_deprecated_v1
   def testVector(self):
@@ -314,10 +391,7 @@ class MatrixDiagTest(test.TestCase):
       self.assertEqual((3, 3), v_diag.get_shape())
       self.assertAllEqual(v_diag.eval(), mat)
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # {Sub,Super}diagonals.
         for offset in [1, -2, 5]:
           mat = np.diag(v, offset)
@@ -326,13 +400,14 @@ class MatrixDiagTest(test.TestCase):
           self.assertAllEqual(v_diag.eval(), mat)
 
         # Diagonal bands.
-        for _, tests in [self._moreCases(), square_cases()]:
-          for diags, (vecs, solution) in tests.items():
-            v_diags = array_ops.matrix_diag(vecs[0], k=diags)
-            self.assertEqual(v_diags.get_shape(), solution[0].shape)
-            self.assertAllEqual(v_diags.eval(), solution[0])
+        for align in alignment_list:
+          for _, tests in [self._moreCases(align), square_cases(align)]:
+            for diags, (vecs, solution) in tests.items():
+              v_diags = array_ops.matrix_diag(vecs[0], k=diags, align=align)
+              self.assertEqual(v_diags.get_shape(), solution[0].shape)
+              self.assertAllEqual(v_diags.eval(), solution[0])
 
-  def _testBatchVector(self, dtype):
+  def _testVectorBatch(self, dtype):
     with self.cached_session(use_gpu=True):
       v_batch = np.array([[1.0, 0.0, 3.0], [4.0, 5.0, 6.0]]).astype(dtype)
       mat_batch = np.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 3.0]],
@@ -342,10 +417,7 @@ class MatrixDiagTest(test.TestCase):
       self.assertEqual((2, 3, 3), v_batch_diag.get_shape())
       self.assertAllEqual(v_batch_diag.eval(), mat_batch)
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # {Sub,Super}diagonals.
         for offset in [1, -2, 5]:
           v_batch_diag = array_ops.matrix_diag(v_batch, k=offset)
@@ -357,33 +429,34 @@ class MatrixDiagTest(test.TestCase):
           self.assertAllEqual(v_batch_diag.eval(), mat_batch)
 
         # Diagonal bands with padding_value.
-        for padding_value in [0, 555, -11]:
-          for _, tests in [self._moreCases(), square_cases()]:
+        for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                             alignment_list):
+          for _, tests in [self._moreCases(align), square_cases(align)]:
             for diags, (vecs, solution) in tests.items():
               v_diags = array_ops.matrix_diag(
-                  vecs.astype(dtype), k=diags, padding_value=padding_value)
+                  vecs.astype(dtype),
+                  k=diags,
+                  padding_value=padding_value,
+                  align=align)
               mask = solution == 0
               solution = (solution + padding_value * mask).astype(dtype)
               self.assertEqual(v_diags.get_shape(), solution.shape)
               self.assertAllEqual(v_diags.eval(), solution)
 
   @test_util.run_deprecated_v1
-  def testBatchVector(self):
-    self._testBatchVector(np.float32)
-    self._testBatchVector(np.float64)
-    self._testBatchVector(np.int32)
-    self._testBatchVector(np.int64)
-    self._testBatchVector(np.bool)
+  def testVectorBatch(self):
+    self._testVectorBatch(np.float32)
+    self._testVectorBatch(np.float64)
+    self._testVectorBatch(np.int32)
+    self._testVectorBatch(np.int64)
+    self._testVectorBatch(np.bool)
 
   @test_util.run_deprecated_v1
   def testRectangularBatch(self):
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
       with self.cached_session(use_gpu=True):
         # Stores expected num_rows and num_cols (when the other is given).
-        # expected[(d_lower, d_upper)] = (expected_num_rows, expected_num_cols)
+        # expected[d_lower, d_upper] = (expected_num_rows, expected_num_cols)
         test_list = list()
 
         # Square cases:
@@ -393,6 +466,8 @@ class MatrixDiagTest(test.TestCase):
             (-2, 1): (5, 5),
             (2, 4): (3, 5),
         }
+        # Do not change alignment yet. Re-alignment needs to happen after the
+        # solution shape is updated.
         test_list.append((expected, square_cases()))
 
         # More cases:
@@ -418,16 +493,18 @@ class MatrixDiagTest(test.TestCase):
         }
         test_list.append((expected, fat_cases()))
 
-        for padding_value in [0, 555, -11]:
+        for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                             alignment_list):
           # Giving both num_rows and num_cols
-          for _, tests in [tall_cases(), fat_cases()]:
+          for _, tests in [tall_cases(align), fat_cases(align)]:
             for diags, (vecs, solution) in tests.items():
               v_diags = array_ops.matrix_diag(
                   vecs,
                   k=diags,
                   num_rows=solution.shape[-2],
                   num_cols=solution.shape[-1],
-                  padding_value=padding_value)
+                  padding_value=padding_value,
+                  align=align)
               mask = solution == 0
               solution = solution + padding_value * mask
               self.assertEqual(v_diags.get_shape(), solution.shape)
@@ -438,11 +515,15 @@ class MatrixDiagTest(test.TestCase):
             for diags, (_, new_num_cols) in expected.items():
               vecs, solution = tests[diags]
               solution = solution.take(indices=range(new_num_cols), axis=-1)
+              # Repacks the diagonal input according to the new solution shape.
+              vecs = repack_diagonals(
+                  vecs, diags, solution.shape[-2], new_num_cols, align=align)
               v_diags = array_ops.matrix_diag(
                   vecs,
                   k=diags,
                   num_rows=solution.shape[-2],
-                  padding_value=padding_value)
+                  padding_value=padding_value,
+                  align=align)
               mask = solution == 0
               solution = solution + padding_value * mask
               self.assertEqual(v_diags.get_shape(), solution.shape)
@@ -453,11 +534,15 @@ class MatrixDiagTest(test.TestCase):
             for diags, (new_num_rows, _) in expected.items():
               vecs, solution = tests[diags]
               solution = solution.take(indices=range(new_num_rows), axis=-2)
+              # Repacks the diagonal input according to the new solution shape.
+              vecs = repack_diagonals(
+                  vecs, diags, new_num_rows, solution.shape[-1], align=align)
               v_diags = array_ops.matrix_diag(
                   vecs,
                   k=diags,
                   num_cols=solution.shape[-1],
-                  padding_value=padding_value)
+                  padding_value=padding_value,
+                  align=align)
               mask = solution == 0
               solution = solution + padding_value * mask
               self.assertEqual(v_diags.get_shape(), solution.shape)
@@ -489,10 +574,7 @@ class MatrixDiagTest(test.TestCase):
                                                         y.get_shape().as_list())
         self.assertLess(error, 1e-4)
 
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
       # {Sub,super}diagonals/band.
       tests = dict()  # tests[shape] = (d_lower, d_upper)
       tests[(3,)] = (-1, -1)
@@ -500,12 +582,13 @@ class MatrixDiagTest(test.TestCase):
       with self.session(use_gpu=True):
         for shape, diags in tests.items():
           x = constant_op.constant(np.random.rand(*shape), np.float32)
-          y = array_ops.matrix_diag(x, k=diags)
-          error = gradient_checker.compute_gradient_error(
-              x,
-              x.get_shape().as_list(), y,
-              y.get_shape().as_list())
-          self.assertLess(error, 1e-4)
+          for align in alignment_list:
+            y = array_ops.matrix_diag(x, k=diags, align=align)
+            error = gradient_checker.compute_gradient_error(
+                x,
+                x.get_shape().as_list(), y,
+                y.get_shape().as_list())
+            self.assertLess(error, 1e-4)
 
 
 class MatrixSetDiagTest(test.TestCase):
@@ -521,20 +604,18 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((3, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag, self.evaluate(output))
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # Diagonal bands.
-        _, tests = square_cases()
-        for diags, pair in tests.items():
-          vecs, banded_mat = pair
-          mask = banded_mat[0] == 0
-          input_mat = np.random.randint(10, size=mask.shape)
-          solution = input_mat * mask + banded_mat[0]
-          output = array_ops.matrix_set_diag(input_mat, vecs[0], k=diags)
-          self.assertEqual(output.get_shape(), solution.shape)
-          self.assertAllEqual(output.eval(), solution)
+        for align in alignment_list:
+          _, tests = square_cases(align)
+          for diags, (vecs, banded_mat) in tests.items():
+            mask = banded_mat[0] == 0
+            input_mat = np.random.randint(10, size=mask.shape)
+            solution = input_mat * mask + banded_mat[0]
+            output = array_ops.matrix_set_diag(
+                input_mat, vecs[0], k=diags, align=align)
+            self.assertEqual(output.get_shape(), solution.shape)
+            self.assertAllEqual(output.eval(), solution)
 
   @test_util.run_deprecated_v1
   def testRectangular(self):
@@ -553,20 +634,18 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((3, 2), output.get_shape())
       self.assertAllEqual(expected, self.evaluate(output))
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # Diagonal bands.
-        for _, tests in [tall_cases(), fat_cases()]:
-          for diags, pair in tests.items():
-            vecs, banded_mat = pair
-            mask = banded_mat[0] == 0
-            input_mat = np.random.randint(10, size=mask.shape)
-            solution = input_mat * mask + banded_mat[0]
-            output = array_ops.matrix_set_diag(input_mat, vecs[0], k=diags)
-            self.assertEqual(output.get_shape(), solution.shape)
-            self.assertAllEqual(output.eval(), solution)
+        for align in alignment_list:
+          for _, tests in [tall_cases(align), fat_cases(align)]:
+            for diags, (vecs, banded_mat) in tests.items():
+              mask = banded_mat[0] == 0
+              input_mat = np.random.randint(10, size=mask.shape)
+              solution = input_mat * mask + banded_mat[0]
+              output = array_ops.matrix_set_diag(
+                  input_mat, vecs[0], k=diags, align=align)
+              self.assertEqual(output.get_shape(), solution.shape)
+              self.assertAllEqual(output.eval(), solution)
 
   def _testSquareBatch(self, dtype):
     with self.cached_session(use_gpu=True):
@@ -584,21 +663,18 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((2, 3, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # Diagonal bands.
-        _, tests = square_cases()
-        for diags, pair in tests.items():
-          vecs, banded_mat = pair
-          mask = banded_mat == 0
-          input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
-          solution = (input_mat * mask + banded_mat).astype(dtype)
-          output = array_ops.matrix_set_diag(
-              input_mat, vecs.astype(dtype), k=diags)
-          self.assertEqual(output.get_shape(), solution.shape)
-          self.assertAllEqual(output.eval(), solution)
+        for align in alignment_list:
+          _, tests = square_cases(align)
+          for diags, (vecs, banded_mat) in tests.items():
+            mask = banded_mat == 0
+            input_mat = np.random.randint(10, size=mask.shape).astype(dtype)
+            solution = (input_mat * mask + banded_mat).astype(dtype)
+            output = array_ops.matrix_set_diag(
+                input_mat, vecs.astype(dtype), k=diags, align=align)
+            self.assertEqual(output.get_shape(), solution.shape)
+            self.assertAllEqual(output.eval(), solution)
 
   @test_util.run_deprecated_v1
   def testSquareBatch(self):
@@ -621,20 +697,19 @@ class MatrixSetDiagTest(test.TestCase):
       self.assertEqual((2, 2, 3), output.get_shape())
       self.assertAllEqual(mat_set_diag_batch, self.evaluate(output))
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # Diagonal bands.
-        for _, tests in [tall_cases(), fat_cases()]:
-          for diags, pair in tests.items():
-            vecs, banded_mat = pair
-            mask = banded_mat == 0
-            input_mat = np.random.randint(10, size=mask.shape)
-            solution = input_mat * mask + banded_mat
-            output = array_ops.matrix_set_diag(input_mat, vecs, k=diags)
-            self.assertEqual(output.get_shape(), solution.shape)
-            self.assertAllEqual(output.eval(), solution)
+        for align in alignment_list:
+          for _, tests in [tall_cases(align), fat_cases(align)]:
+            for diags, pair in tests.items():
+              vecs, banded_mat = pair
+              mask = banded_mat == 0
+              input_mat = np.random.randint(10, size=mask.shape)
+              solution = input_mat * mask + banded_mat
+              output = array_ops.matrix_set_diag(
+                  input_mat, vecs, k=diags, align=align)
+              self.assertEqual(output.get_shape(), solution.shape)
+              self.assertAllEqual(output.eval(), solution)
 
   @test_util.run_deprecated_v1
   def testInvalidShape(self):
@@ -643,10 +718,7 @@ class MatrixSetDiagTest(test.TestCase):
     with self.assertRaisesRegexp(ValueError, "must be at least rank 1"):
       array_ops.matrix_set_diag([[0]], 0)
 
-  # TODO(penporn): Un-skip the XLA test when XLA has MatrixSetDiagV2.
   @test_util.run_deprecated_v1
-  @test_util.disable_xla("XLA op hasn't supported new features in V2, which"
-                         "change the shape requirements.")
   def testInvalidShapeAtEval(self):
     with self.session(use_gpu=True):
       v = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -655,10 +727,7 @@ class MatrixSetDiagTest(test.TestCase):
       with self.assertRaisesOpError("diagonal must be at least 1-dim"):
         array_ops.matrix_set_diag([[v]], v).eval(feed_dict={v: 0.0})
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         d = array_ops.placeholder(dtype=dtypes_lib.float32)
         with self.assertRaisesOpError(
             "first dimensions of diagonal don't match"):
@@ -667,17 +736,15 @@ class MatrixSetDiagTest(test.TestCase):
               d: np.ones((2, 4))
           })
 
-  def _testGrad(self, input_shape, diag_shape, diags):
+  def _testGrad(self, input_shape, diag_shape, diags, align):
     with self.session(use_gpu=True):
       x = constant_op.constant(
           np.random.rand(*input_shape), dtype=dtypes_lib.float32)
       x_diag = constant_op.constant(
           np.random.rand(*diag_shape), dtype=dtypes_lib.float32)
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-        y = array_ops.matrix_set_diag(x, x_diag, k=diags)
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+        y = array_ops.matrix_set_diag(x, x_diag, k=diags, align=align)
       else:
         y = array_ops.matrix_set_diag(x, x_diag)
       error_x = gradient_checker.compute_gradient_error(x,
@@ -696,16 +763,15 @@ class MatrixSetDiagTest(test.TestCase):
     input_shapes = [(3, 4, 4), (3, 3, 4), (3, 4, 3), (7, 4, 8, 8)]
     diag_bands = [(0, 0)]
 
-    # LINT.IfChange
-    if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
       diag_bands.append((-1, 1))
-    for input_shape, diags in itertools.product(input_shapes, diag_bands):
+    for input_shape, diags, align in itertools.product(input_shapes, diag_bands,
+                                                       alignment_list):
       lower_diag_index, upper_diag_index = diags
       num_diags = upper_diag_index - lower_diag_index + 1
       num_diags_dim = () if num_diags == 1 else (num_diags,)
       diag_shape = input_shape[:-2] + num_diags_dim + (min(input_shape[-2:]),)
-      self._testGrad(input_shape, diag_shape, diags)
+      self._testGrad(input_shape, diag_shape, diags, align)
 
   @test_util.run_deprecated_v1
   def testGradWithNoShapeInformation(self):
@@ -739,9 +805,7 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertEqual((3,), mat_diag.get_shape())
       self.assertAllEqual(mat_diag.eval(), v)
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         for offset in [-2, 3]:
           mat = np.diag(v, offset)
           mat_diag = array_ops.matrix_diag_part(mat, k=offset)
@@ -749,12 +813,13 @@ class MatrixDiagPartTest(test.TestCase):
           self.assertAllEqual(mat_diag.eval(), v)
 
         # Diagonal bands.
-        mat, tests = square_cases()
-        for diags, pair in tests.items():
-          solution, _ = pair
-          mat_diag = array_ops.matrix_diag_part(mat[0], k=diags)
-          self.assertEqual(mat_diag.get_shape(), solution[0].shape)
-          self.assertAllEqual(mat_diag.eval(), solution[0])
+        for align in alignment_list:
+          mat, tests = square_cases(align)
+          for diags, pair in tests.items():
+            solution, _ = pair
+            mat_diag = array_ops.matrix_diag_part(mat[0], k=diags, align=align)
+            self.assertEqual(mat_diag.get_shape(), solution[0].shape)
+            self.assertAllEqual(mat_diag.eval(), solution[0])
 
   @test_util.run_deprecated_v1
   def testRectangular(self):
@@ -766,17 +831,16 @@ class MatrixDiagPartTest(test.TestCase):
       mat_diag = array_ops.matrix_diag_part(mat)
       self.assertAllEqual(mat_diag.eval(), np.array([1.0, 4.0]))
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # Diagonal bands.
-        for mat, tests in [tall_cases(), fat_cases()]:
-          for diags, pair in tests.items():
-            solution, _ = pair
-            mat_diag = array_ops.matrix_diag_part(mat[0], k=diags)
-            self.assertEqual(mat_diag.get_shape(), solution[0].shape)
-            self.assertAllEqual(mat_diag.eval(), solution[0])
+        for align in alignment_list:
+          for mat, tests in [tall_cases(align), fat_cases(align)]:
+            for diags, pair in tests.items():
+              solution, _ = pair
+              mat_diag = array_ops.matrix_diag_part(
+                  mat[0], k=diags, align=align)
+              self.assertEqual(mat_diag.get_shape(), solution[0].shape)
+              self.assertAllEqual(mat_diag.eval(), solution[0])
 
   def _testSquareBatch(self, dtype):
     with self.cached_session(use_gpu=True):
@@ -789,17 +853,18 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertEqual((2, 3), mat_batch_diag.get_shape())
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         # Diagonal bands with padding_value.
-        mat, tests = square_cases()
-        for padding_value in [0, 555, -11]:
+        for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                             alignment_list):
+          mat, tests = square_cases(align)
           for diags, pair in tests.items():
             solution, _ = pair
             mat_batch_diag = array_ops.matrix_diag_part(
-                mat.astype(dtype), k=diags, padding_value=padding_value)
+                mat.astype(dtype),
+                k=diags,
+                padding_value=padding_value,
+                align=align)
             mask = solution == 0
             solution = (solution + padding_value * mask).astype(dtype)
             self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
@@ -824,17 +889,15 @@ class MatrixDiagPartTest(test.TestCase):
       self.assertEqual((2, 2), mat_batch_diag.get_shape())
       self.assertAllEqual(mat_batch_diag.eval(), v_batch)
 
-      # LINT.IfChange
-      if compat.forward_compatible(2019, 11, 30):
-      # LINT.ThenChange(//tensorflow/python/ops/array_ops.py)
-
-        # Diagonal bands with padding_value.
-        for padding_value in [0, 555, -11]:
-          for mat, tests in [tall_cases(), fat_cases()]:
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+        # Diagonal bands with padding_value and align.
+        for padding_value, align in zip_to_first_list_length([0, 555, -11],
+                                                             alignment_list):
+          for mat, tests in [tall_cases(align), fat_cases(align)]:
             for diags, pair in tests.items():
               solution, _ = pair
               mat_batch_diag = array_ops.matrix_diag_part(
-                  mat, k=diags, padding_value=padding_value)
+                  mat, k=diags, padding_value=padding_value, align=align)
               mask = solution == 0
               solution = solution + padding_value * mask
               self.assertEqual(mat_batch_diag.get_shape(), solution.shape)
@@ -866,6 +929,22 @@ class MatrixDiagPartTest(test.TestCase):
                                                         y.get_shape().as_list())
         self.assertLess(error, 1e-4)
 
+    if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+      # {Sub,super}diagonals/band.
+      tests = dict()  # tests[shape] = (d_lower, d_upper)
+      tests[(3, 3)] = (-1, -1)
+      tests[(7, 3, 4)] = (-1, 1)
+      with self.session(use_gpu=True):
+        for align in alignment_list:
+          for shape, diags in tests.items():
+            x = constant_op.constant(np.random.rand(*shape), np.float32)
+            y = array_ops.matrix_diag_part(input=x, k=diags, align=align)
+            error = gradient_checker.compute_gradient_error(
+                x,
+                x.get_shape().as_list(), y,
+                y.get_shape().as_list())
+            self.assertLess(error, 1e-4)
+
 
 class DiagTest(test.TestCase):
 
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 2c1bd445e54..6dd853846db 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -341,6 +341,12 @@ def _MatrixDiagV2Grad(op, grad):
       grad, k=op.inputs[1]), None, None, None, None
 
 
+@ops.RegisterGradient("MatrixDiagV3")
+def _MatrixDiagV3Grad(op, grad):
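+  """Gradient for MatrixDiagV3."""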
+  return array_ops.matrix_diag_part(
+      grad, k=op.inputs[1], align=op.get_attr("align")), None, None, None, None
+
+
 @ops.RegisterGradient("MatrixDiagPart")
 def _MatrixDiagPartGrad(op, grad):
   matrix_shape = op.inputs[0].get_shape()[-2:]
@@ -362,8 +368,25 @@ def _MatrixDiagPartV2Grad(op, grad):
         num_cols=matrix_shape[1]), None, None
   else:
     return array_ops.matrix_set_diag(
-        array_ops.zeros_like(op.inputs[0]), grad,
-        k=op.inputs[1]), None, None
+        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1]), None, None
+
+
+@ops.RegisterGradient("MatrixDiagPartV3")
+def _MatrixDiagPartV3Grad(op, grad):
+  """Gradient for MatrixDiagPartV3."""
+  matrix_shape = op.inputs[0].get_shape()[-2:]
+  align = op.get_attr("align")
+  if matrix_shape.is_fully_defined():
+    return array_ops.matrix_diag(
+        grad,
+        k=op.inputs[1],
+        num_rows=matrix_shape[0],
+        num_cols=matrix_shape[1],
+        align=align), None, None
+  else:
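+    # The matrix shape is not fully known, so scatter the gradient into an
+    # all-zero matrix of the same dynamic shape instead of passing num_rows
+    # and num_cols to matrix_diag.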
+    return array_ops.matrix_set_diag(
+        array_ops.zeros_like(op.inputs[0]), grad, k=op.inputs[1],
+        align=align), None, None
 
 
 @ops.RegisterGradient("MatrixSetDiag")
@@ -392,7 +415,7 @@ def _MatrixSetDiagGrad(op, grad):
 
 @ops.RegisterGradient("MatrixSetDiagV2")
 def _MatrixSetDiagGradV2(op, grad):
-  """Gradient for MatrixSetDiag."""
+  """Gradient for MatrixSetDiagV2."""
   diag_shape = op.inputs[1].get_shape()
   if not diag_shape.is_fully_defined():
     # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
@@ -426,6 +449,46 @@ def _MatrixSetDiagGradV2(op, grad):
   return (grad_input, grad_diag, None)
 
 
+@ops.RegisterGradient("MatrixSetDiagV3")
+def _MatrixSetDiagGradV3(op, grad):
+  """Gradient for MatrixSetDiagV3."""
+  diag_shape = op.inputs[1].get_shape()
+  align = op.get_attr("align")
+  if not diag_shape.is_fully_defined():
+    # Need to know the values of `d_lower` and `d_upper` to infer diag_shape.
+    grad_shape = array_ops.shape(grad)
+    batch_shape = grad_shape[:-2]
+    matrix_shape = grad_shape[-2:]
+    diag_index = array_ops.reshape(op.inputs[2], [-1])  # Converts to vector.
+    d_lower = diag_index[0]
+    d_upper = diag_index[-1]  # Works both when len(diag_index) is 1 and 2.
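+    # y_offset/x_offset shorten max_diag_len when the whole band lies strictly
+    # below (d_upper < 0) or strictly above (d_lower > 0) the main diagonal.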
+    y_offset = control_flow_ops.cond(
+        math_ops.less(d_upper, 0), lambda: d_upper, lambda: 0)
+    x_offset = control_flow_ops.cond(
+        math_ops.greater(d_lower, 0), lambda: -d_lower, lambda: 0)
+
+    max_diag_len = math_ops.minimum(matrix_shape[0] + y_offset,
+                                    matrix_shape[1] + x_offset)
+    # pylint: disable=g-long-lambda
+    # pyformat: disable
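+    # diag_shape is batch_shape + [num_diags, max_diag_len], with the
+    # num_diags dimension omitted when the band is a single diagonal.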
+    postfix = control_flow_ops.cond(
+        math_ops.equal(d_lower, d_upper),
+        lambda: ops.convert_to_tensor([max_diag_len]),
+        lambda: ops.convert_to_tensor([d_upper - d_lower + 1,
+                                       max_diag_len]))
+    # pyformat: enable
+    # pylint: enable=g-long-lambda
+    diag_shape = array_ops.concat([batch_shape, postfix], 0)
+
+  grad_input = array_ops.matrix_set_diag(
+      grad,
+      array_ops.zeros(diag_shape, dtype=grad.dtype),
+      k=op.inputs[2],
+      align=align)
+  grad_diag = array_ops.matrix_diag_part(grad, k=op.inputs[2], align=align)
+  return (grad_input, grad_diag, None)
+
+
 @ops.RegisterGradient("MatrixBandPart")
 def _MatrixBandPartGrad(op, grad):
   num_lower = op.inputs[1]
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index c6234754d20..488a9bafca9 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -54,6 +54,13 @@ tf_export("newaxis").export_constant(__name__, "newaxis")
 # existing 'slice' for later use in this module.
 _BaseSlice = slice
 
+# LINT.IfChange
+matrix_diag_v3_forward_compat_date = (2019, 12, 6)
+# LINT.ThenChange(
+#   //tensorflow/compiler/tests/matrix_diag_ops_test.py,
+#   //tensorflow/python/kernel_tests/diag_op_test.py,
+#   //tensorflow/python/ops/parallel_for/array_test.py
+# )
 
 @tf_export("reshape", v1=["reshape", "manip.reshape"])
 def reshape(tensor, shape, name=None):  # pylint: disable=redefined-outer-name
@@ -2043,7 +2050,8 @@ def matrix_diag(diagonal,
                 k=0,
                 num_rows=-1,
                 num_cols=-1,
-                padding_value=0):
+                padding_value=0,
+                align="RIGHT_LEFT"):
   """Returns a batched diagonal tensor with given batched diagonal values.
 
   Returns a tensor with the contents in `diagonal` as `k[0]`-th to `k[1]`-th
@@ -2078,7 +2086,17 @@ def matrix_diag(diagonal,
       padding_value                                     ; otherwise
   ```
   where `d = n - m`, `diag_index = k[1] - d`, and
-  `index_in_diag = n - max(d, 0)`.
+  `index_in_diag = n - max(d, 0) + offset`.
+
+  `offset` is zero except when the alignment of the diagonal is to the right.
+  ```
+  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
+                                             and `d >= 0`) or
+                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
+                                             and `d <= 0`)
+           0                          ; otherwise
+  ```
+  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
 
   For example:
 
@@ -2108,17 +2126,34 @@ def matrix_diag(diagonal,
           [0, 0, 0, 6],
           [0, 0, 0, 0]]]
 
-  # A band of diagonals.
-  diagonals = np.array([[[1, 2, 3],  # Input shape: (2, 2, 3)
-                         [4, 5, 0]],
-                        [[6, 7, 9],
-                         [9, 1, 0]]])
-  tf.matrix_diag(diagonals, k = (-1, 0))
-    ==> [[[1, 0, 0],  # Output shape: (2, 3, 3)
-          [4, 2, 0],
+  # A tridiagonal band (per batch), LEFT_RIGHT-aligned.
+  diagonals = np.array([[[8, 9, 0],  # Input shape: (2, 3, 3)
+                         [1, 2, 3],
+                         [0, 4, 5]],
+                        [[2, 3, 0],
+                         [6, 7, 9],
+                         [0, 9, 1]]])
+  tf.matrix_diag(diagonals, k = (-1, 1), align="LEFT_RIGHT")
+    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
+          [4, 2, 9],
           [0, 5, 3]],
-         [[6, 0, 0],
-          [9, 7, 0],
+         [[6, 2, 0],
+          [9, 7, 3],
+          [0, 1, 9]]]
+
+  # RIGHT_LEFT alignment (the default).
+  diagonals = np.array([[[0, 8, 9],  # Input shape: (2, 3, 3)
+                         [1, 2, 3],
+                         [4, 5, 0]],
+                        [[0, 2, 3],
+                         [6, 7, 9],
+                         [9, 1, 0]]])
+  tf.matrix_diag(diagonals, k = (-1, 1), align="RIGHT_LEFT")
+    ==> [[[1, 8, 0],  # Output shape: (2, 3, 3)
+          [4, 2, 9],
+          [0, 5, 3]],
+         [[6, 2, 0],
+          [9, 7, 3],
           [0, 1, 9]]]
 
   # Rectangular matrix.
@@ -2150,27 +2185,34 @@ def matrix_diag(diagonal,
       size from `d_lower`, `d_upper`, and the innermost dimension of `diagonal`.
     padding_value: The value to fill the area outside the specified diagonal
       band with. Default is 0.
+    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
+      `align` is a string specifying how superdiagonals and subdiagonals should
+      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
+      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
+      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
+      the left (right-pads the row). It is the packing format LAPACK uses.
+      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
 
   Returns:
     A Tensor. Has the same type as `diagonal`.
   """
-  # LINT.IfChange
-  if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
-
+  if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
     # Special case to sidestep the tf.constant conversion error:
     # TypeError: Expected bool, got 0 of type 'int' instead.
     if hasattr(diagonal, "dtype") and diagonal.dtype == "bool":
       padding_value = bool(padding_value)
-    return gen_array_ops.matrix_diag_v2(
+
+    return gen_array_ops.matrix_diag_v3(
         diagonal=diagonal,
         k=k,
         num_rows=num_rows,
         num_cols=num_cols,
         padding_value=padding_value,
+        align=align,
         name=name)
 
   # Call v1 to maintain forward compatibility.
+  # (We skip v2 because its alignment conflicts with v3's default alignment.)
   return gen_array_ops.matrix_diag(diagonal=diagonal, name=name)
 
 
@@ -2181,7 +2223,8 @@ def matrix_diag_part(
     input,  # pylint:disable=redefined-builtin
     name="diag_part",
     k=0,
-    padding_value=0):
+    padding_value=0,
+    align="RIGHT_LEFT"):
   """Returns the batched diagonal part of a batched tensor.
 
   Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched
@@ -2211,7 +2254,17 @@ def matrix_diag_part(
     = input[i, j, ..., l, n+y, n+x] ; if 0 <= n+y < M and 0 <= n+x < N,
       padding_value                 ; otherwise.
   ```
-  where `d = k[1] - m`, `y = max(-d, 0)`, and `x = max(d, 0)`.
+  where `d = k[1] - m`, `y = max(-d, 0) - offset`, and `x = max(d, 0) - offset`.
+
+  `offset` is zero except when the alignment of the diagonal is to the right.
+  ```
+  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
+                                             and `d >= 0`) or
+                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
+                                             and `d <= 0`)
+           0                          ; otherwise
+  ```
+  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
 
   The input must be at least a matrix.
 
@@ -2234,16 +2287,36 @@ def matrix_diag_part(
     ==> [[2, 7, 6],  # Output shape: (2, 3)
          [4, 3, 8]]
 
-  # A tridiagonal band from each batch.
-  tf.matrix_diag_part(input, k = (-1, 1))
-    ==> [[[2, 7, 6],  # Output shape: (2, 3, 3)
+  # A band from each batch, LEFT_RIGHT-aligned.
+  tf.matrix_diag_part(input, k = (-1, 2), align="LEFT_RIGHT")
+    ==> [[[3, 8, 0],  # Output shape: (2, 4, 3)
+          [2, 7, 6],
+          [1, 6, 7],
+          [0, 5, 8]],
+         [[3, 4, 0],
+          [4, 3, 8],
+          [5, 2, 7],
+          [0, 1, 6]]]
+
+  # RIGHT_LEFT alignment (the default).
+  tf.matrix_diag_part(input, k = (-1, 2), align="RIGHT_LEFT")
+    ==> [[[0, 3, 8],  # Output shape: (2, 4, 3)
+          [2, 7, 6],
           [1, 6, 7],
           [5, 8, 0]],
-         [[4, 3, 8],
+         [[0, 3, 4],
+          [4, 3, 8],
           [5, 2, 7],
           [1, 6, 0]]]
 
-  # Padding value = 9
+  # max_diag_len can be shorter than the main diagonal.
+  tf.matrix_diag_part(input, k = (-2, -1))
+    ==> [[[5, 8],
+          [9, 0]],
+         [[1, 6],
+          [5, 0]]]
+
+  # padding_value = 9
-  tf.matrix_diag_part(input, k = (1, 3), padding_value = 9)
+  tf.matrix_diag_part(input, k = (1, 3), padding_value = 9, align="LEFT_RIGHT")
     ==> [[[4, 9, 9],  # Output shape: (2, 3, 3)
           [3, 8, 9],
@@ -2251,6 +2324,7 @@ def matrix_diag_part(
          [[2, 9, 9],
           [3, 4, 9],
           [4, 3, 8]]]
+
   ```
 
   Args:
@@ -2262,23 +2336,28 @@ def matrix_diag_part(
       and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
     padding_value: The value to fill the area outside the specified diagonal
       band with. Default is 0.
+    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
+      `align` is a string specifying how superdiagonals and subdiagonals should
+      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
+      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
+      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
+      the left (right-pads the row). It is the packing format LAPACK uses.
+      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
 
   Returns:
     A Tensor containing diagonals of `input`. Has the same type as `input`.
   """
-  # LINT.IfChange
-  if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
-
+  if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
     # Special case to sidestep the tf.constant conversion error:
     # TypeError: Expected bool, got 0 of type 'int' instead.
     if hasattr(input, "dtype") and input.dtype == "bool":
       padding_value = bool(padding_value)
 
-    return gen_array_ops.matrix_diag_part_v2(
-        input=input, k=k, padding_value=padding_value, name=name)
+    return gen_array_ops.matrix_diag_part_v3(
+        input=input, k=k, padding_value=padding_value, align=align, name=name)
 
   # Call v1 to maintain forward compatibility.
+  # (We skip v2 because its alignment conflicts with v3's default alignment.)
   return gen_array_ops.matrix_diag_part(input=input, name=name)
 
 
@@ -2288,7 +2367,8 @@ def matrix_set_diag(
     input,  # pylint:disable=redefined-builtin
     diagonal,
     name="set_diag",
-    k=0):
+    k=0,
+    align="RIGHT_LEFT"):
   """Returns a batched matrix tensor with new batched diagonal values.
 
   Given `input` and `diagonal`, this operation returns a tensor with the
@@ -2319,7 +2399,17 @@ def matrix_set_diag(
       input[i, j, ..., l, m, n]                         ; otherwise
   ```
   where `d = n - m`, `diag_index = k[1] - d`, and
-  `index_in_diag = n - max(d, 0)`.
+  `index_in_diag = n - max(d, 0) + offset`.
+
+  `offset` is zero except when the alignment of the diagonal is to the right.
+  ```
+  offset = max_diag_len - diag_len(d) ; if (`align` in {RIGHT_LEFT, RIGHT_RIGHT}
+                                             and `d >= 0`) or
+                                           (`align` in {LEFT_RIGHT, RIGHT_RIGHT}
+                                             and `d <= 0`)
+           0                          ; otherwise
+  ```
+  where `diag_len(d) = min(cols - max(d, 0), rows + min(d, 0))`.
 
   For example:
 
@@ -2333,15 +2423,16 @@ def matrix_set_diag(
                      [7, 7, 7, 7]]])
   diagonal = np.array([[1, 2, 3],               # Diagonal shape: (2, 3)
                        [4, 5, 6]])
-  tf.matrix_set_diag(diagonal) ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
-                                     [7, 2, 7, 7],
-                                     [7, 7, 3, 7]],
-                                    [[4, 7, 7, 7],
-                                     [7, 5, 7, 7],
-                                     [7, 7, 6, 7]]]
+  tf.matrix_set_diag(input, diagonal)
+    ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
+          [7, 2, 7, 7],
+          [7, 7, 3, 7]],
+         [[4, 7, 7, 7],
+          [7, 5, 7, 7],
+          [7, 7, 6, 7]]]
 
   # A superdiagonal (per batch).
-  tf.matrix_set_diag(diagonal, k = 1)
+  tf.matrix_set_diag(input, diagonal, k = 1)
     ==> [[[7, 1, 7, 7],  # Output shape: (2, 3, 4)
           [7, 7, 2, 7],
           [7, 7, 7, 3]],
@@ -2350,17 +2441,38 @@ def matrix_set_diag(
           [7, 7, 7, 6]]]
 
   # A band of diagonals.
-  diagonals = np.array([[[1, 2, 3],  # Diagonal shape: (2, 2, 3)
+  diagonals = np.array([[[9, 1, 0],  # Diagonal shape: (2, 4, 3)
+                         [6, 5, 8],
+                         [1, 2, 3],
+                         [0, 4, 5]],
+                        [[1, 2, 0],
+                         [5, 6, 4],
+                         [6, 1, 2],
+                         [0, 3, 4]]])
+  tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT")
+    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
+          [4, 2, 5, 1],
+          [7, 5, 3, 8]],
+         [[6, 5, 1, 7],
+          [3, 1, 6, 2],
+          [7, 4, 2, 4]]]
+
+  # RIGHT_LEFT alignment (the default).
+  diagonals = np.array([[[0, 9, 1],  # Diagonal shape: (2, 4, 3)
+                         [6, 5, 8],
+                         [1, 2, 3],
                          [4, 5, 0]],
-                        [[6, 1, 2],
+                        [[0, 1, 2],
+                         [5, 6, 4],
+                         [6, 1, 2],
                          [3, 4, 0]]])
-  tf.matrix_set_diag(diagonals, k = (-1, 0))
-    ==> [[[1, 7, 7, 7],  # Output shape: (2, 3, 4)
-          [4, 2, 7, 7],
-          [0, 5, 3, 7]],
-         [[6, 7, 7, 7],
-          [3, 1, 7, 7],
-          [7, 4, 2, 7]]]
+  tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="RIGHT_LEFT")
+    ==> [[[1, 6, 9, 7],  # Output shape: (2, 3, 4)
+          [4, 2, 5, 1],
+          [7, 5, 3, 8]],
+         [[6, 5, 1, 7],
+          [3, 1, 6, 2],
+          [7, 4, 2, 4]]]
 
   ```
 
@@ -2373,14 +2485,20 @@ def matrix_set_diag(
       main diagonal, and negative value means subdiagonals. `k` can be a single
       integer (for a single diagonal) or a pair of integers specifying the low
       and high ends of a matrix band. `k[0]` must not be larger than `k[1]`.
+    align: Some diagonals are shorter than `max_diag_len` and need to be padded.
+      `align` is a string specifying how superdiagonals and subdiagonals should
+      be aligned, respectively. There are four possible alignments: "RIGHT_LEFT"
+      (default), "LEFT_RIGHT", "LEFT_LEFT", and "RIGHT_RIGHT". "RIGHT_LEFT"
+      aligns superdiagonals to the right (left-pads the row) and subdiagonals to
+      the left (right-pads the row). It is the packing format LAPACK uses.
+      cuSPARSE uses "LEFT_RIGHT", which is the opposite alignment.
   """
-  # LINT.IfChange
-  if compat.forward_compatible(2019, 11, 30):
-    # LINT.ThenChange(//tensorflow/python/kernel_tests/diag_op_test.py)
-    return gen_array_ops.matrix_set_diag_v2(
-        input=input, diagonal=diagonal, k=k, name=name)
+  if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+    return gen_array_ops.matrix_set_diag_v3(
+        input=input, diagonal=diagonal, k=k, align=align, name=name)
 
   # Call v1 to maintain forward compatibility.
+  # (We skip v2 because its alignment conflicts with v3's default alignment.)
   return gen_array_ops.matrix_set_diag(
       input=input, diagonal=diagonal, name=name)
 
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 205c16e5197..2efa16ce62a 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -1111,10 +1111,17 @@ def _cholesky(input, name=None):   # pylint:disable=redefined-builtin
 
 
 # The signature has to match with the one in python/op/array_ops.py,
-# so we have k and padding_value even though we don't use them here.
+# so we have k, padding_value, and align even though we don't use them here.
+# pylint:disable=unused-argument
 @dispatch.dispatch_for_types(linalg.diag_part, LinearOperator)
-def _diag_part(input, name="diag_part", k=0, padding_value=0):  # pylint:disable=redefined-builtin, unused-argument
+def _diag_part(
+    input,  # pylint:disable=redefined-builtin
+    name="diag_part",
+    k=0,
+    padding_value=0,
+    align="RIGHT_LEFT"):
   return input.diag_part(name)
+# pylint:enable=unused-argument
 
 
 @dispatch.dispatch_for_types(linalg.det, LinearOperator)
diff --git a/tensorflow/python/ops/parallel_for/array_test.py b/tensorflow/python/ops/parallel_for/array_test.py
index 874e59926af..7986b74ed3b 100644
--- a/tensorflow/python/ops/parallel_for/array_test.py
+++ b/tensorflow/python/ops/parallel_for/array_test.py
@@ -33,6 +33,15 @@ from tensorflow.python.ops.parallel_for.test_util import PForTestCase
 from tensorflow.python.platform import test
 
 
+# LINT.IfChange
+matrix_diag_v3_forward_compat_date = (2019, 12, 6)
+# LINT.ThenChange(
+#   //tensorflow/compiler/tests/matrix_diag_ops_test.py,
+#   //tensorflow/python/kernel_tests/diag_op_test.py,
+#   //tensorflow/python/ops/array_ops.py
+# )
+
+
 @test_util.run_all_in_graph_and_eager_modes
 class ArrayTest(PForTestCase):
 
@@ -336,8 +345,9 @@ class ArrayTest(PForTestCase):
 
     def loop_fn(i):
       diagonal = array_ops.gather(x, i)
-      if compat.forward_compatible(2019, 11, 30):
-        return array_ops.matrix_diag(diagonal, k=(0, 1), num_rows=4, num_cols=5)
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+        return array_ops.matrix_diag(
+            diagonal, k=(0, 1), num_rows=4, num_cols=5, align="RIGHT_LEFT")
       return array_ops.matrix_diag(diagonal)
 
     self._test_loop_fn(loop_fn, 3)
@@ -347,8 +357,9 @@ class ArrayTest(PForTestCase):
 
     def loop_fn(i):
       input = array_ops.gather(x, i)  # pylint: disable=redefined-builtin
-      if compat.forward_compatible(2019, 11, 30):
-        return array_ops.matrix_diag_part(input, k=(-2, 0), padding_value=3)
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
+        return array_ops.matrix_diag_part(
+            input, k=(-2, 0), padding_value=3, align="RIGHT_LEFT")
       return array_ops.matrix_diag_part(input)
 
     self._test_loop_fn(loop_fn, 3)
@@ -356,8 +367,7 @@ class ArrayTest(PForTestCase):
   def test_matrix_set_diag(self):
     matrices = random_ops.random_uniform([3, 4, 4])
     diags = random_ops.random_uniform([3, 4])
-    if compat.forward_compatible(2019, 11, 30):
-      bands = random_ops.random_uniform([3, 3, 4])
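+    # Bands of three diagonals each, used with k = (-1, 1) below.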
+    bands = random_ops.random_uniform([3, 3, 4])
 
     def loop_fn(i):
       matrix_i = array_ops.gather(matrices, i)
@@ -365,16 +375,20 @@ class ArrayTest(PForTestCase):
       results = [
           array_ops.matrix_set_diag(matrix_i, diag_i),
           array_ops.matrix_set_diag(matrices[0, ...], diag_i),
-          array_ops.matrix_set_diag(matrix_i, diags[0, ...])
+          array_ops.matrix_set_diag(matrix_i, diags[0, ...]),
       ]
-      if compat.forward_compatible(2019, 11, 30):
+
+      if compat.forward_compatible(*matrix_diag_v3_forward_compat_date):
         k = (-1, 1)
         band_i = array_ops.gather(bands, i)
-        results.extend([
-            array_ops.matrix_set_diag(matrix_i, band_i, k=k),
-            array_ops.matrix_set_diag(matrices[0, ...], band_i, k=k),
-            array_ops.matrix_set_diag(matrix_i, bands[0, ...], k=k)
-        ])
+        for align in ["RIGHT_LEFT", "LEFT_RIGHT"]:
+          results.extend([
+              array_ops.matrix_set_diag(matrix_i, band_i, k=k, align=align),
+              array_ops.matrix_set_diag(
+                  matrices[0, ...], band_i, k=k, align=align),
+              array_ops.matrix_set_diag(
+                  matrix_i, bands[0, ...], k=k, align=align)
+          ])
       return results
 
     self._test_loop_fn(loop_fn, 3)
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index a59f6af4d1f..0e861ffe2ab 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1902,43 +1902,55 @@ def _convert_matrix_set_diag(pfor_input):
   return wrap(array_ops.matrix_set_diag(t, diag), True)
 
 
-# Registrations for MatrixDiagV2, MatrixDiagPartv2, and MatrixSetDiagV2.
+# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3.
 # The input orders defined in the OpKernel and the actual python API are
 # different (for compatibility with V1), so we cannot use _convert_identity.
+# V2 is not compatible with V3 and is never exposed in the public API.
 @RegisterPFor("MatrixDiagV2")
+@RegisterPFor("MatrixDiagV3")
 def _convert_matrix_diag_v2(pfor_input):
-  diagonal = pfor_input.stacked_input(0)
-  k = pfor_input.unstacked_input(1)
-  num_rows = pfor_input.unstacked_input(2)
-  num_cols = pfor_input.unstacked_input(3)
-  padding_value = pfor_input.unstacked_input(4)
-  return wrap(
-      array_ops.matrix_diag(
-          diagonal,
-          k=k,
-          num_rows=num_rows,
-          num_cols=num_cols,
-          padding_value=padding_value), True)
+  params = {
+      "diagonal": pfor_input.stacked_input(0),
+      "k": pfor_input.unstacked_input(1),
+      "num_rows": pfor_input.unstacked_input(2),
+      "num_cols": pfor_input.unstacked_input(3),
+      "padding_value": pfor_input.unstacked_input(4)
+  }
+  if pfor_input.op_type == "MatrixDiagV2":
+    return wrap(array_ops.matrix_diag_v2(**params), True)
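+  # Only the V3 op has an align attr; forward it to the public wrapper.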
+  params["align"] = pfor_input.get_attr("align")
+  return wrap(array_ops.matrix_diag(**params), True)
 
 
 # See notes for MatrixDiagV2
 @RegisterPFor("MatrixDiagPartV2")
+@RegisterPFor("MatrixDiagPartV3")
 def _convert_matrix_diag_part_v2(pfor_input):
-  input = pfor_input.stacked_input(0)  # pylint:disable=redefined-builtin
-  k = pfor_input.unstacked_input(1)
-  padding_value = pfor_input.unstacked_input(2)
-  return wrap(
-      array_ops.matrix_diag_part(input, k=k, padding_value=padding_value), True)
+  params = {
+      "input": pfor_input.stacked_input(0),
+      "k": pfor_input.unstacked_input(1),
+      "padding_value": pfor_input.unstacked_input(2)
+  }
+  if pfor_input.op_type == "MatrixDiagPartV2":
+    return wrap(array_ops.matrix_diag_part_v2(**params), True)
+  params["align"] = pfor_input.get_attr("align")
+  return wrap(array_ops.matrix_diag_part(**params), True)
 
 
 # See notes for MatrixDiagV2
 @RegisterPFor("MatrixSetDiagV2")
+@RegisterPFor("MatrixSetDiagV3")
 def _convert_matrix_set_diag_v2(pfor_input):
   pfor_input.stack_inputs([0, 1])
-  input = pfor_input.stacked_input(0)  # pylint:disable=redefined-builtin
-  diagonal = pfor_input.stacked_input(1)
-  k = pfor_input.unstacked_input(2)
-  return wrap(array_ops.matrix_set_diag(input, diagonal, k=k), True)
+  params = {
+      "input": pfor_input.stacked_input(0),
+      "diagonal": pfor_input.stacked_input(1),
+      "k": pfor_input.unstacked_input(2)
+  }
+  if pfor_input.op_type == "MatrixSetDiagV2":
+    return wrap(array_ops.matrix_set_diag_v2(**params), True)
+  params["align"] = pfor_input.get_attr("align")
+  return wrap(array_ops.matrix_set_diag(**params), True)
 
 
 @RegisterPFor("OneHot")
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
index 283fd9c35d6..f645db2f310 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
@@ -102,11 +102,11 @@ tf_module {
   }
   member_method {
     name: "diag"
-    argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], "
+    argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "diag_part"
-    argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
+    argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "eigh"
@@ -202,7 +202,7 @@ tf_module {
   }
   member_method {
     name: "set_diag"
-    argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], "
+    argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "slogdet"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 6c75ecb5fbf..2883f6d8166 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1646,11 +1646,11 @@ tf_module {
   }
   member_method {
     name: "matrix_diag"
-    argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], "
+    argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "matrix_diag_part"
-    argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
+    argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "matrix_inverse"
@@ -1658,7 +1658,7 @@ tf_module {
   }
   member_method {
     name: "matrix_set_diag"
-    argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], "
+    argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "matrix_solve"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 57065a6cfc2..edeb3ba0d56 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -2172,10 +2172,18 @@ tf_module {
     name: "MatrixDiagPartV2"
     argspec: "args=[\'input\', \'k\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "MatrixDiagPartV3"
+    argspec: "args=[\'input\', \'k\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
+  }
   member_method {
     name: "MatrixDiagV2"
     argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "MatrixDiagV3"
+    argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
+  }
   member_method {
     name: "MatrixExponential"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -2196,6 +2204,10 @@ tf_module {
     name: "MatrixSetDiagV2"
     argspec: "args=[\'input\', \'diagonal\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "MatrixSetDiagV3"
+    argspec: "args=[\'input\', \'diagonal\', \'k\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
+  }
   member_method {
     name: "MatrixSolve"
     argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index a25583d7fdd..a58c988577a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -102,11 +102,11 @@ tf_module {
   }
   member_method {
     name: "diag"
-    argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\'], "
+    argspec: "args=[\'diagonal\', \'name\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag\', \'0\', \'-1\', \'-1\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "diag_part"
-    argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\'], "
+    argspec: "args=[\'input\', \'name\', \'k\', \'padding_value\', \'align\'], varargs=None, keywords=None, defaults=[\'diag_part\', \'0\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "eig"
@@ -210,7 +210,7 @@ tf_module {
   }
   member_method {
     name: "set_diag"
-    argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\'], "
+    argspec: "args=[\'input\', \'diagonal\', \'name\', \'k\', \'align\'], varargs=None, keywords=None, defaults=[\'set_diag\', \'0\', \'RIGHT_LEFT\'], "
   }
   member_method {
     name: "slogdet"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 57065a6cfc2..edeb3ba0d56 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -2172,10 +2172,18 @@ tf_module {
     name: "MatrixDiagPartV2"
     argspec: "args=[\'input\', \'k\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "MatrixDiagPartV3"
+    argspec: "args=[\'input\', \'k\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
+  }
   member_method {
     name: "MatrixDiagV2"
     argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "MatrixDiagV3"
+    argspec: "args=[\'diagonal\', \'k\', \'num_rows\', \'num_cols\', \'padding_value\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
+  }
   member_method {
     name: "MatrixExponential"
     argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -2196,6 +2204,10 @@ tf_module {
     name: "MatrixSetDiagV2"
     argspec: "args=[\'input\', \'diagonal\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "MatrixSetDiagV3"
+    argspec: "args=[\'input\', \'diagonal\', \'k\', \'align\', \'name\'], varargs=None, keywords=None, defaults=[\'RIGHT_LEFT\', \'None\'], "
+  }
   member_method {
     name: "MatrixSolve"
     argspec: "args=[\'matrix\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "