diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 5c63ee7a97b..7cba7937f1b 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -27,6 +27,7 @@ import numpy as np
 
 from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -71,7 +72,8 @@ def RunLSTM(sess,
             is_training=True,
             dropout=0.,
             num_dirs=True,
-            dtype=dtypes.float32):
+            dtype=dtypes.float32,
+            num_proj=None):
   # TODO(jamesqin): add multi-layer tests.
   # TODO(jamesqin): add multi-dir tests
   assert num_layers == 1
@@ -91,10 +93,12 @@ def RunLSTM(sess,
   inputs_dynamic = array_ops.placeholder(
       dtype, shape=[None, None, None], name="inputs")
   inputs = inputs_dynamic if dynamic_shape_input else inputs_static
+  unified_num_units = num_proj if num_proj else num_units
+  unified_num_proj = num_proj if num_proj else None
   initial_h_op = variable_scope.get_variable(
       "initial_h_op",
-      initializer=np.random.rand(batch_size,
-                                 num_units).astype(dtype.as_numpy_dtype),
+      initializer=np.random.rand(batch_size, unified_num_units).astype(
+          dtype.as_numpy_dtype),
       dtype=dtype)
   initial_c_op = variable_scope.get_variable(
       "initial_c_op",
@@ -115,13 +119,19 @@ def RunLSTM(sess,
   with variable_scope.variable_scope("test", initializer=initializer):
     w = variable_scope.get_variable(
         "rnn/lstm_cell/kernel",
-        shape=[input_size + num_units, num_units * 4],
+        shape=[input_size + unified_num_units, num_units * 4],
         dtype=dtype)
     b = variable_scope.get_variable(
         "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype)
+    if num_proj:
+      pw = variable_scope.get_variable(
+          "rnn/lstm_cell/projection/kernel",
+          shape=[num_units, num_proj],
+          dtype=dtype)
 
     # canonical lstm. must set forget_bias to 0. to align with cudnn lstm.
-    cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True)
+    cell = rnn_cell_impl.LSTMCell(
+        num_units, forget_bias=0., reuse=True, num_proj=unified_num_proj)
     outputs_op, state_tuple_op = rnn.dynamic_rnn(
         cell,
         inputs_static,
@@ -134,8 +144,13 @@ def RunLSTM(sess,
 
   # Convert to cudnn opaque param.
   format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
-      num_layers, num_units, input_size)
-  opaque_params = format_converter.tf_canonical_to_opaque([w, b])
+      num_layers, num_units, input_size, num_proj=unified_num_proj)
+  if num_proj:
+    opaque_params = format_converter.tf_canonical_to_opaque([w, b], [
+        pw,
+    ])
+  else:
+    opaque_params = format_converter.tf_canonical_to_opaque([w, b])
 
   cu_initial_h_op = array_ops.expand_dims(
       initial_h_op, axis=(0 if time_major else 1))
@@ -150,16 +165,22 @@ def RunLSTM(sess,
       time_major=time_major,
       dropout=dropout,
       is_training=is_training,
-      rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)
+      rnn_mode=cudnn_rnn_ops.CUDNN_LSTM,
+      num_proj=unified_num_proj)
   # Remove the trivial 1st dimension.
   cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple(
       c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1),
       h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1))
 
   if is_training:
-    (inp_grad_op, hgrad_op,
-     cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
-         outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
+    if num_proj:
+      (inp_grad_op, hgrad_op, cgrad_op,
+       wgrad_op, bgrad_op, pwgrad_op) = gradients_impl.gradients(
+           outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b, pw])
+    else:
+      (inp_grad_op, hgrad_op,
+       cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
+           outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
 
     (cu_inp_grad_op, cu_hgrad_op,
      cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients(
@@ -170,10 +191,16 @@ def RunLSTM(sess,
     # Remove the trivial 1st dimension
     cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1)
 
-    cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
-        opaque_grad_op)
+    if num_proj:
+      (cu_wgrad_op, cu_bgrad_op,
+       cu_pwgrad_op) = format_converter.opaque_to_tf_canonical(opaque_grad_op)
+    else:
+      cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
+          opaque_grad_op)
     cu_wgrad_op = cu_wgrad_op[0]
     cu_bgrad_op = cu_bgrad_op[0]
+    if num_proj:
+      cu_pwgrad_op = cu_pwgrad_op[0]
     # cudnn lstm has 2 biases each gate. When converting to tf canonical format,
     # the two biases are summed into one. Thus here bias gradient should be
     # halved when comparing with tf lstm.
@@ -183,17 +210,32 @@ def RunLSTM(sess,
   sess.run(init_op)
 
   if is_training:
-    outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([
-        outputs_op, state_tuple_op, inp_grad_op,
-        (hgrad_op, cgrad_op), wgrad_op, bgrad_op
-    ])
-    (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
-     cu_bgrad) = sess.run(
-         [
-             cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
-             (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
-         ],
-         feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
+    if num_proj:
+      (outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad,
+       pwgrad) = sess.run([
+           outputs_op, state_tuple_op, inp_grad_op, (hgrad_op, cgrad_op),
+           wgrad_op, bgrad_op, pwgrad_op
+       ])
+      (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
+       cu_bgrad, cu_pwgrad) = sess.run(
+           [
+               cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
+               (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op,
+               cu_pwgrad_op
+           ],
+           feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
+    else:
+      outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([
+          outputs_op, state_tuple_op, inp_grad_op, (hgrad_op, cgrad_op),
+          wgrad_op, bgrad_op
+      ])
+      (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
+       cu_bgrad) = sess.run(
+           [
+               cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
+               (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
+           ],
+           feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
 
     logging.vlog(1, "outputs: %s" % outputs)
     logging.vlog(1, "cu_outputs: %s" % cu_outputs)
@@ -205,11 +247,20 @@ def RunLSTM(sess,
     logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad))
     logging.vlog(1, "wgrad: %s" % str(wgrad))
     logging.vlog(1, "bgrad: %s" % str(bgrad))
+    if num_proj:
+      logging.vlog(1, "pwgrad: %s" % str(bgrad))
     logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad))
     logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad))
-    return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
-            cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
-            cu_bgrad)
+    if num_proj:
+      logging.vlog(1, "cu_pwgrad: %s" % str(cu_bgrad))
+    if num_proj:
+      return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+              cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad,
+              cu_wgrad, cu_bgrad, cu_pwgrad)
+    else:
+      return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+              cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
+              cu_bgrad)
   else:
     outputs, state_tuple = sess.run([outputs_op, state_tuple_op])
     cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op],
@@ -256,7 +307,6 @@ NAMED_RNN_TESTCASES = ({
     "num_layers": 1,
 })
 
-
 def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs):
   """Expands testcase with new config dimensions.
 
@@ -349,19 +399,35 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
                             time_major,
                             dynamic_shape_input=False,
                             rtol=3e-6,
-                            atol=3e-6):
+                            atol=3e-6,
+                            num_proj=None):
     with self.session(use_gpu=True) as sess:
-      (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad,
-       state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM(
-           sess,
-           num_units,
-           input_size,
-           batch_size,
-           time,
-           num_layers,
-           variable_seq_lengths=variable_seq_lengths,
-           time_major=time_major,
-           dynamic_shape_input=dynamic_shape_input)
+      if num_proj is not None and num_proj != 0:
+        (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+         cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, cu_wgrad,
+         cu_bgrad, cu_pwgrad) = RunLSTM(
+             sess,
+             num_units,
+             input_size,
+             batch_size,
+             time,
+             num_layers,
+             variable_seq_lengths=variable_seq_lengths,
+             time_major=time_major,
+             dynamic_shape_input=dynamic_shape_input,
+             num_proj=num_proj)
+      else:
+        (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+         cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
+         cu_bgrad) = RunLSTM(
+             sess,
+             num_units,
+             input_size,
+             batch_size,
+             time,
+             num_layers,
+             variable_seq_lengths=variable_seq_lengths,
+             time_major=time_major,
+             dynamic_shape_input=dynamic_shape_input,
+             num_proj=num_proj)
 
       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
       for s, cu_s in zip(state_tuple, cu_state_tuple):
@@ -371,6 +437,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol)
       self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol)
       self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol)
+      if num_proj is not None and num_proj != 0:
+        self.assertAllClose(pwgrad, cu_pwgrad, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -378,20 +446,27 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
               "variable_seq_lengths": [True, False],
               "time_major": [True, False],
               "dynamic_shape_input": [True, False],
+              "use_proj": [True, False],
           }))
   @test_util.run_gpu_only
   def test_training(self, num_units, input_size, batch_size, time, num_layers,
-                    variable_seq_lengths, time_major, dynamic_shape_input):
-    self._test_training_helper(
-        num_units,
-        input_size,
-        batch_size,
-        time,
-        num_layers,
-        dtypes.float32,
-        variable_seq_lengths=variable_seq_lengths,
-        time_major=time_major,
-        dynamic_shape_input=dynamic_shape_input)
+                    variable_seq_lengths, time_major, dynamic_shape_input,
+                    use_proj):
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if use_proj and num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      self._test_training_helper(
+          num_units,
+          input_size,
+          batch_size,
+          time,
+          num_layers,
+          dtypes.float32,
+          variable_seq_lengths=variable_seq_lengths,
+          time_major=time_major,
+          dynamic_shape_input=dynamic_shape_input,
+          num_proj=num_proj if use_proj else None)
 
   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -399,52 +474,29 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
               "variable_seq_lengths": [True, False],
               "time_major": [True, False],
               "dynamic_shape_input": [True, False],
+              "use_proj": [True, False],
           }))
   @test_util.run_gpu_only
   def test_training_fp16(self, num_units, input_size, batch_size, time,
                          num_layers, variable_seq_lengths, time_major,
-                         dynamic_shape_input):
-    self._test_training_helper(
-        num_units,
-        input_size,
-        batch_size,
-        time,
-        num_layers,
-        dtypes.float16,
-        rtol=5e-3,
-        atol=5e-4,
-        variable_seq_lengths=variable_seq_lengths,
-        time_major=time_major,
-        dynamic_shape_input=dynamic_shape_input)
-
-  @parameterized.named_parameters(
-      ExpandNamedTestCases(
-          NAMED_RNN_TESTCASES, **{
-              "variable_seq_lengths": [True, False],
-              "time_major": [True, False],
-              "dynamic_shape_input": [True, False],
-          }))
-  @test_util.run_gpu_only
-  def test_inference(self, num_units, input_size, batch_size, time, num_layers,
-                     variable_seq_lengths, time_major, dynamic_shape_input):
-    with self.session(use_gpu=True) as sess:
-      (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
-          sess,
+                         dynamic_shape_input, use_proj):
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if use_proj and num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      self._test_training_helper(
           num_units,
           input_size,
           batch_size,
           time,
           num_layers,
-          is_training=False,
+          dtypes.float16,
+          rtol=5e-3,
+          atol=5e-4,
           variable_seq_lengths=variable_seq_lengths,
           time_major=time_major,
-          dynamic_shape_input=dynamic_shape_input)
-
-      self.assertAllClose(outputs, cu_outputs)
-      # h
-      self.assertAllClose(state_tuple.h, cu_state_tuple.h)
-      # c
-      self.assertAllClose(state_tuple.c, cu_state_tuple.c)
+          dynamic_shape_input=dynamic_shape_input,
+          num_proj=num_proj if use_proj else None)
 
   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -452,33 +504,75 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
               "variable_seq_lengths": [True, False],
               "time_major": [True, False],
               "dynamic_shape_input": [True, False],
+              "use_proj": [True, False],
+          }))
+  @test_util.run_gpu_only
+  def test_inference(self, num_units, input_size, batch_size, time, num_layers,
+                     variable_seq_lengths, time_major, dynamic_shape_input,
+                     use_proj):
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if use_proj and num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      with self.session(use_gpu=True) as sess:
+        (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
+            sess,
+            num_units,
+            input_size,
+            batch_size,
+            time,
+            num_layers,
+            is_training=False,
+            variable_seq_lengths=variable_seq_lengths,
+            time_major=time_major,
+            dynamic_shape_input=dynamic_shape_input,
+            num_proj=num_proj if use_proj else None)
+
+        self.assertAllClose(outputs, cu_outputs)
+        # h
+        self.assertAllClose(state_tuple.h, cu_state_tuple.h)
+        # c
+        self.assertAllClose(state_tuple.c, cu_state_tuple.c)
+
+  @parameterized.named_parameters(
+      ExpandNamedTestCases(
+          NAMED_RNN_TESTCASES, **{
+              "variable_seq_lengths": [True, False],
+              "time_major": [True, False],
+              "dynamic_shape_input": [True, False],
+              "use_proj": [True, False],
           }))
   @test_util.run_gpu_only
   def test_inference_fp16(self, num_units, input_size, batch_size, time,
                           num_layers, variable_seq_lengths, time_major,
-                          dynamic_shape_input):
-    with self.session(use_gpu=True) as sess:
-      (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
-          sess,
-          num_units,
-          input_size,
-          batch_size,
-          time,
-          num_layers,
-          is_training=False,
-          dtype=dtypes.float16,
-          variable_seq_lengths=variable_seq_lengths,
-          time_major=time_major,
-          dynamic_shape_input=dynamic_shape_input)
+                          dynamic_shape_input, use_proj):
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if use_proj and num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      with self.session(use_gpu=True) as sess:
+        (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
+            sess,
+            num_units,
+            input_size,
+            batch_size,
+            time,
+            num_layers,
+            is_training=False,
+            dtype=dtypes.float16,
+            variable_seq_lengths=variable_seq_lengths,
+            time_major=time_major,
+            dynamic_shape_input=dynamic_shape_input,
+            num_proj=num_proj if use_proj else None)
 
-      rtol, atol = 5e-3, 5e-4
-      self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
-      # h
-      self.assertAllClose(
-          state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol)
-      # c
-      self.assertAllClose(
-          state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol)
+        rtol, atol = 5e-3, 5e-4
+        self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
+        # h
+        self.assertAllClose(
+            state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol)
+        # c
+        self.assertAllClose(
+            state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol)
 
   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -486,49 +580,56 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
               "variable_seq_lengths": [True, False],
               "time_major": [True, False],
               "dynamic_shape_input": [True, False],
+              "use_proj": [True, False],
           }))
   @test_util.run_gpu_only
   def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
                                   num_layers, variable_seq_lengths, time_major,
-                                  dynamic_shape_input):
+                                  dynamic_shape_input, use_proj):
     """Validates that dropout does not affect Cudnn Rnn inference."""
-    # Hand-picked dropouts are used below (0. and 1.)
-    with ops.Graph().as_default() as g:
-      with self.session(use_gpu=True, graph=g) as sess:
-        # 1st time w/o dropout.
-        (_, cu_outputs, _, cu_state_tuple) = RunLSTM(
-            sess,
-            num_units,
-            input_size,
-            batch_size,
-            time,
-            num_layers,
-            is_training=False,
-            dropout=0.,
-            variable_seq_lengths=variable_seq_lengths,
-            time_major=time_major,
-            dynamic_shape_input=dynamic_shape_input)
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if use_proj and num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      # Hand-picked dropouts are used below (0. and 1.)
+      with ops.Graph().as_default() as g:
+        with self.session(use_gpu=True, graph=g) as sess:
+          # 1st time w/o dropout.
+          (_, cu_outputs, _, cu_state_tuple) = RunLSTM(
+              sess,
+              num_units,
+              input_size,
+              batch_size,
+              time,
+              num_layers,
+              is_training=False,
+              dropout=0.,
+              variable_seq_lengths=variable_seq_lengths,
+              time_major=time_major,
+              dynamic_shape_input=dynamic_shape_input,
+              num_proj=num_proj if use_proj else None)
 
-    with ops.Graph().as_default() as g:
-      with self.session(use_gpu=True, graph=g) as sess:
-        (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM(
-            sess,
-            num_units,
-            input_size,
-            batch_size,
-            time,
-            num_layers,
-            is_training=False,
-            dropout=1.,
-            variable_seq_lengths=variable_seq_lengths,
-            time_major=time_major,
-            dynamic_shape_input=dynamic_shape_input)
+      with ops.Graph().as_default() as g:
+        with self.session(use_gpu=True, graph=g) as sess:
+          (_, cu_outputs2, _, cu_state_tuple2) = RunLSTM(
+              sess,
+              num_units,
+              input_size,
+              batch_size,
+              time,
+              num_layers,
+              is_training=False,
+              dropout=1.,
+              variable_seq_lengths=variable_seq_lengths,
+              time_major=time_major,
+              dynamic_shape_input=dynamic_shape_input,
+              num_proj=num_proj if use_proj else None)
 
-    self.assertAllClose(cu_outputs, cu_outputs2)
-    # h
-    self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h)
-    # c
-    self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c)
+      self.assertAllClose(cu_outputs, cu_outputs2)
+      # h
+      self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h)
+      # c
+      self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c)
 
 
 def RunGRU(sess,
@@ -890,40 +991,68 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
                                      parameterized.TestCase):
   """Class for testing various format converters."""
 
-  def _test_lstm_helper(self, num_units, input_size, num_layers, direction):
+  def _test_lstm_helper(self,
+                        num_units,
+                        input_size,
+                        num_layers,
+                        direction,
+                        num_proj=None):
     with self.session(use_gpu=True) as sess:
       random_seed.set_random_seed(0)
       np.random.seed(0)
 
       num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2
       format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
-          num_layers, num_units, input_size, direction=direction)
+          num_layers,
+          num_units,
+          input_size,
+          direction=direction,
+          num_proj=num_proj if num_proj else None)
 
-      ws, bs = [], []
+      ws, bs, pws = [], [], []
       for _ in range(num_layers * num_dirs):
         w = constant_op.constant(
-            np.random.rand(input_size + num_units, 4 * num_units),
+            np.random.rand(input_size + (num_proj if num_proj else num_units),
+                           4 * num_units),
             dtype=dtypes.float32)
         b = constant_op.constant(
             np.random.rand(4 * num_units), dtype=dtypes.float32)
         ws.append(w)
         bs.append(b)
+        if num_proj:
+          pw = constant_op.constant(
+              np.random.rand(num_units, num_proj), dtype=dtypes.float32)
+          pws.append(pw)
+
+      if num_proj:
+        opaque_params = format_converter.tf_canonical_to_opaque(ws + bs, pws)
+      else:
+        opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
 
-      opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
       opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size(
           cudnn_rnn_ops.CUDNN_LSTM,
           num_layers,
           num_units,
           input_size,
-          direction=direction)
+          direction=direction,
+          num_proj=num_proj if num_proj else None)
 
-      ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params)
+      if num_proj:
+        ws_r, bs_r, pws_r = format_converter.opaque_to_tf_canonical(
+            opaque_params)
+        ws, ws_r, pws, bs, bs_r, pws_r = sess.run(
+            [ws, ws_r, pws, bs, bs_r, pws_r])
+      else:
+        ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params)
+        ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r])
 
       # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical()
       # returns the original input.
-      ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r])
       for w, w_r in zip(ws, ws_r):
         self.assertAllClose(w, w_r)
+      if num_proj:
+        for pw, pw_r in zip(pws, pws_r):
+          self.assertAllClose(pw, pw_r)
       for b, b_r in zip(bs, bs_r):
         self.assertAllClose(b, b_r)
 
@@ -942,6 +1071,22 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
     self._test_lstm_helper(num_units, input_size, num_layers,
                            cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
 
+  @parameterized.named_parameters(
+      (c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"])
+      for c in NAMED_RNN_TESTCASES)
+  @test_util.run_gpu_only
+  def test_lstmp(self, num_units, input_size, num_layers):
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      self._test_lstm_helper(
+          num_units,
+          input_size,
+          num_layers,
+          cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
+          num_proj=num_proj)
+
   @parameterized.named_parameters((c["testcase_name"], c["num_units"],
                                    c["input_size"], c["num_layers"])
                                   for c in NAMED_RNN_TESTCASES)
@@ -950,6 +1095,22 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
     self._test_lstm_helper(num_units, input_size, num_layers,
                            cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
 
+  @parameterized.named_parameters(
+      (c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"])
+      for c in NAMED_RNN_TESTCASES)
+  @test_util.run_gpu_only
+  def test_lstmp_bidi(self, num_units, input_size, num_layers):
+    with compat.forward_compatibility_horizon(2019, 6, 27):
+      num_proj = num_units // 2
+      if num_proj == 0:
+        self.skipTest("num_proj cannot be 0")
+      self._test_lstm_helper(
+          num_units,
+          input_size,
+          num_layers,
+          cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION,
+          num_proj=num_proj)
+
   def _test_gru_helper(self, num_units, input_size, num_layers, direction):
     with self.session(use_gpu=True) as sess:
       random_seed.set_random_seed(0)
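
For orientation, here is a standalone sketch (not part of the patch; sizes are made up) of the
TF-canonical parameter shapes that RunLSTM and _test_lstm_helper above construct when num_proj
is set: the kernel's recurrent slice and the hidden state h shrink to num_proj, the cell state
c and the gate count stay at num_units, and a separate [num_units, num_proj] projection kernel
is added.

import numpy as np

# Hypothetical sizes, for illustration only.
num_units, input_size, num_proj = 8, 4, 4
# With projection, h (and thus the recurrent input and the outputs) has width
# num_proj, while the cell state c keeps width num_units.
hidden_width = num_proj if num_proj else num_units

kernel = np.zeros([input_size + hidden_width, 4 * num_units])  # rnn/lstm_cell/kernel
bias = np.zeros([4 * num_units])                               # rnn/lstm_cell/bias
proj_kernel = np.zeros([num_units, num_proj])                  # rnn/lstm_cell/projection/kernel

assert kernel.shape == (input_size + num_proj, 4 * num_units)
assert proj_kernel.shape == (num_units, num_proj)
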
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 3694d112ce4..7ee00ade6d2 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 import os
 from tensorflow.contrib.checkpoint.python import split_dependency
 from tensorflow.contrib.rnn.python.ops import lstm_ops
+from tensorflow.python.compat import compat
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -186,6 +187,7 @@ class CudnnParamsFormatConverter(object):
                num_layers,
                num_units,
                input_size,
+               num_proj=None,
                input_mode=CUDNN_INPUT_LINEAR_MODE,
                direction=CUDNN_RNN_UNIDIRECTION):
     """Constructor.
@@ -195,6 +197,8 @@ class CudnnParamsFormatConverter(object):
       num_units: the number of units within the RNN model.
       input_size: the size of the input, it could be different from the
         num_units.
+      num_proj: The output dimensionality for the projection matrices.
+        If None or 0, no projection is performed.
       input_mode: indicate whether there is a linear projection between the
         input and the actual computation before the first layer. It could be one
         of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input'
@@ -209,14 +213,16 @@ class CudnnParamsFormatConverter(object):
     self._input_size = input_size
     self._num_units = num_units
     self._input_mode = input_mode
+    self._num_proj = num_proj
     self._direction = direction
     self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2
     self._num_params = (
         self._num_params_per_layer * self._num_layers * self._num_dirs)
 
-  def tf_canonical_to_opaque(self, tf_canonicals):
+  def tf_canonical_to_opaque(self, tf_canonicals, weights_proj=None):
     r"""Converts tf canonical weights to cudnn opaque param."""
-    cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals)
+    cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(
+        tf_canonicals, weights_proj)
     cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights]
     opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases)
     return opaque_params
@@ -224,8 +230,14 @@ class CudnnParamsFormatConverter(object):
   def opaque_to_tf_canonical(self, opaque_param):
     r"""Converts cudnn opaque param to tf canonical weights."""
     cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param)
-    weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases)
-    return weights, biases
+    if self._num_proj:
+      weights, biases, weights_proj = self._cu_canonical_to_tf_canonical(
+          cu_weights, cu_biases)
+      return weights, biases, weights_proj
+    else:
+      weights, biases = self._cu_canonical_to_tf_canonical(
+          cu_weights, cu_biases)
+      return weights, biases
 
   def _opaque_to_cu_canonical(self, opaque_param):
     """Converts opaque params to Cudnn canonical format.
@@ -238,15 +250,31 @@ class CudnnParamsFormatConverter(object):
       2 list for weights and biases respectively.
     """
     with ops.device("/gpu:0"):
-      weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
-          num_layers=self._num_layers,
-          num_units=self._num_units,
-          input_size=self._input_size,
-          params=opaque_param,
-          num_params=self._num_params,
-          rnn_mode=self._rnn_mode,
-          input_mode=self._input_mode,
-          direction=self._direction)
+      if compat.forward_compatible(2019, 6, 26) and self._num_proj:
+        num_params_weights = (
+            self._num_params + 1 * self._num_layers * self._num_dirs)
+        num_params_biases = self._num_params
+        weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            params=opaque_param,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            direction=self._direction,
+            num_params_weights=num_params_weights,
+            num_params_biases=num_params_biases,
+            num_proj=self._num_proj)
+      else:
+        weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            params=opaque_param,
+            num_params=self._num_params,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            direction=self._direction)
       return (weights, biases)
 
   def _cu_canonical_to_opaque(self, cu_weights, cu_biases):
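
The counts passed to the V2 op above follow from cuDNN's canonical LSTM layout: 8 gate weight
matrices and 8 gate bias vectors per layer per direction, plus one extra projection matrix per
layer per direction when num_proj is set. A worked example with made-up sizes:

# Worked example of the count arithmetic above (illustrative sizes only).
num_params_per_layer = 8       # LSTM: w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o
num_layers, num_dirs = 2, 2    # e.g. a two-layer bidirectional model

num_params = num_params_per_layer * num_layers * num_dirs    # 32
num_params_weights = num_params + 1 * num_layers * num_dirs  # 36, one pw per layer/dir
num_params_biases = num_params                               # 32, biases unchanged
assert (num_params, num_params_weights, num_params_biases) == (32, 36, 32)
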
@@ -260,15 +288,27 @@ class CudnnParamsFormatConverter(object):
       a single opaque tensor.
     """
     with ops.device("/gpu:0"):
-      return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
-          num_layers=self._num_layers,
-          num_units=self._num_units,
-          input_size=self._input_size,
-          weights=cu_weights,
-          biases=cu_biases,
-          rnn_mode=self._rnn_mode,
-          input_mode=self._input_mode,
-          direction=self._direction)
+      if compat.forward_compatible(2019, 6, 26) and self._num_proj:
+        return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            weights=cu_weights,
+            biases=cu_biases,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            num_proj=self._num_proj,
+            direction=self._direction)
+      else:
+        return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            weights=cu_weights,
+            biases=cu_biases,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            direction=self._direction)
 
   def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases):
     r"""Transform from Cudnn canonical to tf canonical.
@@ -294,9 +334,11 @@ class CudnnParamsFormatConverter(object):
       1 tuple, tf canonical weights and biases.
     """
     tf_weights, tf_biases = [], []
+    tf_weights_proj = []
 
     layer_weights_num = self._num_params_per_layer * self._num_dirs
     layer_biases_num = layer_weights_num
+    layer_weights_num += (1 * self._num_dirs) if self._num_proj else 0
 
     for i in range(self._num_layers):
       layer_weights = cu_weights[i * layer_weights_num:(i + 1) *
@@ -305,7 +347,8 @@ class CudnnParamsFormatConverter(object):
       if self._direction == CUDNN_RNN_UNIDIRECTION:
         self._cu_canonical_to_tf_canonical_single_layer(layer_weights,
                                                         layer_biases,
-                                                        tf_weights, tf_biases)
+                                                        tf_weights, tf_biases,
+                                                        tf_weights_proj)
       else:
         fw_weights = layer_weights[:len(layer_weights) // 2]
         bw_weights = layer_weights[len(layer_weights) // 2:]
@@ -317,6 +360,7 @@ class CudnnParamsFormatConverter(object):
             fw_biases,
             tf_weights,
             tf_biases,
+            tf_weights_proj,
         )
 
         self._cu_canonical_to_tf_canonical_single_layer(
@@ -324,11 +368,19 @@ class CudnnParamsFormatConverter(object):
             bw_biases,
             tf_weights,
             tf_biases,
+            tf_weights_proj,
         )
-    return (tf_weights, tf_biases)
+    if self._num_proj:
+      return (tf_weights, tf_biases, tf_weights_proj)
+    else:
+      return (tf_weights, tf_biases)
 
-  def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
-                                                 tf_weights, tf_biases):
+  def _cu_canonical_to_tf_canonical_single_layer(self,
+                                                 cu_weights,
+                                                 cu_biases,
+                                                 tf_weights,
+                                                 tf_biases,
+                                                 tf_weights_proj=None):
     r"""Transform single layer Cudnn canonicals to tf canonicals.
 
     The elements of cu_weights, cu_biases are laid out in the following format:
@@ -343,7 +395,7 @@ class CudnnParamsFormatConverter(object):
     """
     raise NotImplementedError("Abstract method")
 
-  def _tf_canonical_to_cu_canonical(self, tf_canonicals):
+  def _tf_canonical_to_cu_canonical(self, tf_canonicals, weights_proj):
     r"""Transform from tf canonical to Cudnn canonical.
 
     This is the reverse routine of _TransformCanonical().
@@ -362,6 +414,7 @@ class CudnnParamsFormatConverter(object):
            ---------------
            |fwd   |bak   |
            ---------------
+      weights_proj: (optional) weight matrices for the projection.
     Returns:
       2 lists: the recovered cudnn canonical weights and biases.
     """
@@ -376,6 +429,9 @@ class CudnnParamsFormatConverter(object):
       layer_biases = biases[i * layer_biases_num:(i + 1) * layer_biases_num]
       if self._direction == CUDNN_RNN_UNIDIRECTION:
         cu_weights.extend(self._tf_to_cudnn_weights(i, *layer_weights))
+        if weights_proj is not None:
+          pw = array_ops.transpose(weights_proj[i])
+          cu_weights.append(pw)
         cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases))
       else:
         fw_weights, bw_weights = layer_weights[:len(layer_weights) //
@@ -385,9 +441,15 @@ class CudnnParamsFormatConverter(object):
                                             2], layer_biases[len(layer_biases
                                                                 ) // 2:]
         cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights))
+        if weights_proj is not None:
+          pw0 = array_ops.transpose(weights_proj[2 * i + 0])
+          cu_weights.append(pw0)
         cu_biases.extend(self._tf_to_cudnn_biases(*fw_biases))
 
         cu_weights.extend(self._tf_to_cudnn_weights(i, *bw_weights))
+        if weights_proj is not None:
+          pw1 = array_ops.transpose(weights_proj[2 * i + 1])
+          cu_weights.append(pw1)
         cu_biases.extend(self._tf_to_cudnn_biases(*bw_biases))
     return cu_weights, cu_biases
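
Spelling out the ordering produced above: for each layer (and each direction), the transposed
projection kernel is appended after the eight cuDNN gate weight matrices, while the bias list
keeps its original layout. A sketch of the per-layer, per-direction weight order consumed by
the LSTM converter below:

# Per layer and per direction, cuDNN canonical weights with projection are
#   [w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o, transpose(pw)]
# while the biases remain
#   [b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro].
cu_weight_order = ["w_i", "w_f", "w_c", "w_o",
                   "r_i", "r_f", "r_c", "r_o",
                   "pw_transposed"]
assert len(cu_weight_order) == 8 + 1  # matches the per-layer/dir weight count
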
 
@@ -423,7 +485,10 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
 
   def _cudnn_to_tf_weights(self, *cu_weights):
     r"""Stitching cudnn canonical weights to generate tf canonical weights."""
-    w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights
+    if self._num_proj:
+      w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o, pw = cu_weights
+    else:
+      w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights
 
     # pylint: disable=invalid-name
     W_i = array_ops.concat([w_i, r_i], axis=1)
@@ -433,7 +498,11 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     # pylint: enable=invalid-name
     # Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order.
     reordered = self._cudnn_to_tf_gate_params(*[W_i, W_f, W_c, W_o])
-    return (array_ops.transpose(array_ops.concat(reordered, axis=0)),)
+    if self._num_proj:
+      return (array_ops.transpose(array_ops.concat(reordered, axis=0)),
+              array_ops.transpose(pw))
+    else:
+      return (array_ops.transpose(array_ops.concat(reordered, axis=0)),)
 
   def _tf_to_cudnn_weights(self, layer, *tf_weights):
     r"""Reverse the operations in StitchWeights()."""
@@ -442,7 +511,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     if layer == 0:
       input_weight_width = input_size
     else:
-      input_weight_width = num_units
+      input_weight_width = self._num_proj if self._num_proj else num_units
       if self._direction == CUDNN_RNN_BIDIRECTION:
         input_weight_width *= 2
 
@@ -452,10 +521,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params(
         *array_ops.split(w, 4, axis=0))
 
-    w_i, r_i = array_ops.split(W_i, [input_weight_width, num_units], axis=1)
-    w_c, r_c = array_ops.split(W_c, [input_weight_width, num_units], axis=1)
-    w_f, r_f = array_ops.split(W_f, [input_weight_width, num_units], axis=1)
-    w_o, r_o = array_ops.split(W_o, [input_weight_width, num_units], axis=1)
+    hidden_state_width = self._num_proj if self._num_proj else num_units
+    w_i, r_i = array_ops.split(
+        W_i, [input_weight_width, hidden_state_width], axis=1)
+    w_c, r_c = array_ops.split(
+        W_c, [input_weight_width, hidden_state_width], axis=1)
+    w_f, r_f = array_ops.split(
+        W_f, [input_weight_width, hidden_state_width], axis=1)
+    w_o, r_o = array_ops.split(
+        W_o, [input_weight_width, hidden_state_width], axis=1)
     return w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o
     # pylint: enable=invalid-name
 
@@ -490,11 +564,20 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     # Return ifco order for Cudnn LSTM.
     return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro
 
-  def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
-                                                 tf_weights, tf_biases):
-    (w,) = self._cudnn_to_tf_weights(*cu_weights)
+  def _cu_canonical_to_tf_canonical_single_layer(self,
+                                                 cu_weights,
+                                                 cu_biases,
+                                                 tf_weights,
+                                                 tf_biases,
+                                                 tf_weights_proj=None):
+    if self._num_proj:
+      (w, pw) = self._cudnn_to_tf_weights(*cu_weights)
+      tf_weights.append(w)
+      tf_weights_proj.append(pw)
+    else:
+      (w,) = self._cudnn_to_tf_weights(*cu_weights)
+      tf_weights.append(w)
     (b,) = self._cudnn_to_tf_biases(*cu_biases)
-    tf_weights.append(w)
     tf_biases.append(b)
 
 
@@ -561,8 +644,12 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter):
     b_ri, b_rr = array_ops.split(br, 2, axis=0)
     return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh
 
-  def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
-                                                 tf_weights, tf_biases):
+  def _cu_canonical_to_tf_canonical_single_layer(self,
+                                                 cu_weights,
+                                                 cu_biases,
+                                                 tf_weights,
+                                                 tf_biases,
+                                                 tf_weights_proj=None):
     # pylint: disable=invalid-name
     W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights)
     b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases)
@@ -735,8 +822,11 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
   def format_converter(self):
     if self._format_converter is None:
       self._format_converter = self._format_converter_cls(
-          self._num_layers, self._num_units, self._input_size, self._input_mode,
-          self._direction)
+          self._num_layers,
+          self._num_units,
+          self._input_size,
+          input_mode=self._input_mode,
+          direction=self._direction)
     return self._format_converter
 
   def restore(self, restored_tensors, restored_shapes):
@@ -970,6 +1060,7 @@ def _cudnn_rnn(inputs,
                direction=CUDNN_RNN_UNIDIRECTION,
                dropout=0.,
                seed=0,
+               num_proj=None,
                name=None):
   """Cudnn RNN.
 
@@ -1006,6 +1097,8 @@ def _cudnn_rnn(inputs,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See
       `tf.compat.v1.set_random_seed` for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
 
   Returns:
@@ -1035,13 +1128,16 @@ def _cudnn_rnn(inputs,
   if sequence_lengths is not None:
     args["sequence_lengths"] = sequence_lengths
     args["time_major"] = time_major
+    args["num_proj"] = 0 if num_proj is None else num_proj
     outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
-  elif time_major is False:
-    batch_size = array_ops.shape(inputs)[0]
-    max_time = array_ops.shape(inputs)[1]
+  elif time_major is False or num_proj:
+    batch_id, time_id = (1, 0) if time_major else (0, 1)
+    batch_size = array_ops.shape(inputs)[batch_id]
+    max_time = array_ops.shape(inputs)[time_id]
     sequence_lengths = array_ops.fill([batch_size], max_time)
     args["sequence_lengths"] = sequence_lengths
     args["time_major"] = time_major
+    args["num_proj"] = 0 if num_proj is None else num_proj
     outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
   elif use_cudnn_v2 != "1":
     outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
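
Because projection changes the width of h (and of the outputs) but not of c, a projected model
always goes through cudnn_rnnv3 above, with sequence_lengths filled in when the caller did not
supply them. A shape sketch with illustrative sizes:

# Illustrative shapes for a single-layer, unidirectional, time-major call.
time_len, batch_size, input_size = 5, 3, 4
num_units, num_proj = 8, 4

inputs_shape = (time_len, batch_size, input_size)
input_h_shape = (1, batch_size, num_proj)   # h uses the projected width
input_c_shape = (1, batch_size, num_units)  # c keeps the full cell width
outputs_shape = (time_len, batch_size, num_proj)
assert input_h_shape[-1] == num_proj and input_c_shape[-1] == num_units
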
@@ -1061,6 +1157,7 @@ def cudnn_lstm(inputs,
                direction=CUDNN_RNN_UNIDIRECTION,
                dropout=0.,
                seed=0,
+               num_proj=None,
                name=None):
   """Cudnn LSTM.
 
@@ -1096,6 +1193,8 @@ def cudnn_lstm(inputs,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See
       `tf.compat.v1.set_random_seed` for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
 
   Returns:
@@ -1103,7 +1202,7 @@ def cudnn_lstm(inputs,
   """
   return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
                     sequence_lengths, time_major, input_mode, direction,
-                    dropout, seed, name)
+                    dropout, seed, num_proj, name)
 
 
 def _cudnn_rnn_no_input_c(inputs,
@@ -1160,7 +1259,7 @@ def _cudnn_rnn_no_input_c(inputs,
   outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
                                     is_training, rnn_mode, sequence_lengths,
                                     time_major, input_mode, direction, dropout,
-                                    seed, name)
+                                    seed, None, name)
   return outputs, output_h
 
 
@@ -1331,6 +1430,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
                                          direction=CUDNN_RNN_UNIDIRECTION,
                                          dropout=0,
                                          seed=0,
+                                         num_proj=None,
                                          name=None):
   """Convert cudnn opaque params to canonical.
 
@@ -1353,6 +1453,8 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See
       `tf.compat.v1.set_random_seed` for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
 
   Returns:
@@ -1366,19 +1468,39 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
   check_input_mode(input_mode)
   num_params = _get_num_params(rnn_mode, num_layers, direction)
   seed, seed2 = random_seed.get_seed(seed)
-  weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
-      rnn_mode=rnn_mode,
-      num_layers=num_layers,
-      num_units=num_units,
-      input_size=input_size,
-      params=params,
-      input_mode=input_mode,
-      direction=direction,
-      dropout=dropout,
-      seed=seed,
-      seed2=seed2,
-      num_params=num_params,
-      name=name)
+  num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2
+  if num_proj is not None and num_proj != 0:
+    num_params_weights = (num_params + 1 * num_layers * num_dirs)
+    num_params_biases = num_params
+    weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2(
+        rnn_mode=rnn_mode,
+        num_layers=num_layers,
+        num_units=num_units,
+        input_size=input_size,
+        params=params,
+        input_mode=input_mode,
+        direction=direction,
+        dropout=dropout,
+        seed=seed,
+        seed2=seed2,
+        num_params_weights=num_params_weights,
+        num_params_biases=num_params_biases,
+        num_proj=num_proj,
+        name=name)
+  else:
+    weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
+        rnn_mode=rnn_mode,
+        num_layers=num_layers,
+        num_units=num_units,
+        input_size=input_size,
+        params=params,
+        input_mode=input_mode,
+        direction=direction,
+        dropout=dropout,
+        seed=seed,
+        seed2=seed2,
+        num_params=num_params,
+        name=name)
   return weights, biases
 
 
@@ -1392,6 +1514,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
                                          direction=CUDNN_RNN_UNIDIRECTION,
                                          dropout=0,
                                          seed=0,
+                                         num_proj=None,
                                          name=None):
   """Converts params from the canonical format to a specific format of cuDNN.
 
@@ -1415,6 +1538,8 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See
       `tf.compat.v1.set_random_seed` for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
 
   Returns:
@@ -1426,19 +1551,35 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
   check_direction(direction)
   check_input_mode(input_mode)
   seed, seed2 = random_seed.get_seed(seed)
-  return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
-      rnn_mode=rnn_mode,
-      num_layers=num_layers,
-      num_units=num_units,
-      input_size=input_size,
-      weights=weights,
-      biases=biases,
-      input_mode=input_mode,
-      direction=direction,
-      dropout=dropout,
-      seed=seed,
-      seed2=seed2,
-      name=name)
+  if num_proj is not None and num_proj != 0:
+    return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2(
+        rnn_mode=rnn_mode,
+        num_layers=num_layers,
+        num_units=num_units,
+        input_size=input_size,
+        weights=weights,
+        biases=biases,
+        input_mode=input_mode,
+        direction=direction,
+        dropout=dropout,
+        seed=seed,
+        seed2=seed2,
+        num_proj=num_proj,
+        name=name)
+  else:
+    return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
+        rnn_mode=rnn_mode,
+        num_layers=num_layers,
+        num_units=num_units,
+        input_size=input_size,
+        weights=weights,
+        biases=biases,
+        input_mode=input_mode,
+        direction=direction,
+        dropout=dropout,
+        seed=seed,
+        seed2=seed2,
+        name=name)
 
 
 def cudnn_rnn_opaque_params_size(rnn_mode,
@@ -1450,6 +1591,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
                                  dtype=dtypes.float32,
                                  dropout=0,
                                  seed=0,
+                                 num_proj=None,
                                  name=None):
   """Returns opaque params size for specific Cudnn config.
 
@@ -1472,6 +1614,8 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See
       `tf.compat.v1.set_random_seed` for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
 
   Returns:
@@ -1488,6 +1632,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
       num_layers=num_layers,
       num_units=num_units,
       input_size=input_size,
+      num_proj=num_proj,
       T=dtype,
       S=dtypes.int32,
       dropout=dropout,
@@ -1516,7 +1661,8 @@ class _CudnnRNN(object):
                direction=CUDNN_RNN_UNIDIRECTION,
                dtype=dtypes.float32,
                dropout=0.,
-               seed=0):
+               seed=0,
+               num_proj=None):
     """Creates a CudnnRNN model from model spec.
 
     Args:
@@ -1539,6 +1685,8 @@ class _CudnnRNN(object):
       dropout: whether to enable dropout. With it is 0, dropout is disabled.
       seed: the op seed used for initializing dropout. See
         `tf.compat.v1.set_random_seed` for behavior.
+      num_proj: The output dimensionality for the projection matrices.
+        If None or 0, no projection is performed.
 
     Raises:
       ValueError: if direction is invalid.
@@ -1552,6 +1700,7 @@ class _CudnnRNN(object):
     self._dtype = dtype
     self._dropout = dropout
     self._seed = seed
+    self._num_proj = num_proj
 
   @property
   def input_mode(self):
@@ -1577,6 +1726,10 @@ class _CudnnRNN(object):
   def direction(self):
     return self._direction
 
+  @property
+  def num_proj(self):
+    return self._num_proj
+
   def params_size(self):
     """Calculates the size of the opaque parameter buffer needed for this model.
 
@@ -1588,6 +1741,7 @@ class _CudnnRNN(object):
         num_layers=self._num_layers,
         num_units=self._num_units,
         input_size=self._input_size,
+        num_proj=self._num_proj,
         dtype=self._dtype,
         dropout=self._dropout,
         seed=self._seed,
@@ -1643,7 +1797,8 @@ class _CudnnRNN(object):
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
-        seed=self._seed)
+        seed=self._seed,
+        num_proj=self._num_proj)
 
   def params_to_canonical(self, params):
     """Converts params from a specific format of cuDNN to the canonical format.
@@ -1663,7 +1818,8 @@ class _CudnnRNN(object):
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
-        seed=self._seed)
+        seed=self._seed,
+        num_proj=self._num_proj)
 
   def canonical_to_params(self, weights, biases):
     """Converts params from the canonical format to a specific format of cuDNN.
@@ -1685,7 +1841,8 @@ class _CudnnRNN(object):
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
-        seed=self._seed)
+        seed=self._seed,
+        num_proj=self._num_proj)
 
 
 class CudnnLSTM(_CudnnRNN):
@@ -1703,7 +1860,8 @@ class CudnnLSTM(_CudnnRNN):
                direction=CUDNN_RNN_UNIDIRECTION,
                dtype=dtypes.float32,
                dropout=0.,
-               seed=0):
+               seed=0,
+               num_proj=None):
     """Creates a Cudnn LSTM model from model spec.
 
     Args:
@@ -1721,6 +1879,8 @@ class CudnnLSTM(_CudnnRNN):
       dtype: dtype of params, tf.float32 or tf.float64.
       dropout: whether to enable dropout. With it is 0, dropout is disabled.
       seed: the seed used for initializing dropout.
+      num_proj: The output dimensionality for the projection matrices.
+        If None or 0, no projection is performed.
     """
     super(CudnnLSTM, self).__init__(
         CUDNN_LSTM,
@@ -1731,7 +1891,8 @@ class CudnnLSTM(_CudnnRNN):
         direction=direction,
         dtype=dtype,
         dropout=dropout,
-        seed=seed)
+        seed=seed,
+        num_proj=num_proj)
 
   def __call__(self,
                input_data,
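
To close out the Python-side changes, a construction sketch (not taken from the patch; sizes
are illustrative and a GPU is needed to evaluate the resulting ops) for the contrib CudnnLSTM
wrapper with projection enabled; only members whose signatures appear above are used.

from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.python.framework import dtypes

model = cudnn_rnn_ops.CudnnLSTM(
    num_layers=1,
    num_units=8,
    input_size=4,
    direction=cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
    dtype=dtypes.float32,
    num_proj=4)

# Scalar int32 tensor; the size accounts for the extra projection matrices.
params_size_t = model.params_size()
assert model.num_proj == 4
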
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt
new file mode 100644
index 00000000000..cac3674c1ae
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt
@@ -0,0 +1,36 @@
+op {
+  graph_op_name: "CudnnRNNCanonicalToParamsV2"
+  summary: "Converts CudnnRNN params from canonical form to usable form. It supports the projection in LSTM."
+  description: <<END
+Writes a set of weights into the opaque params buffer so they can be used in
+upcoming training or inferences.
+
+Note that the params buffer may not be compatible across different GPUs. So any
+save and restoration should be converted to and from the canonical weights and
+biases.
+
+num_layers: Specifies the number of layers in the RNN model.
+num_units: Specifies the size of the hidden state.
+input_size: Specifies the size of the input state.
+weights: the canonical form of weights that can be used for saving
+    and restoration. They are more likely to be compatible across different
+    generations.
+biases: the canonical form of biases that can be used for saving
+    and restoration. They are more likely to be compatible across different
+    generations.
+num_params_weights: the number of weight parameter matrices for all layers.
+num_params_biases: the number of bias parameter vectors for all layers.
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input and
+    the actual computation before the first layer. 'skip_input' is only allowed
+    when input_size == num_units; 'auto_select' implies 'skip_input' when
+    input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+    dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+num_proj: The output dimensionality for the projection matrices. If None or 0,
+    no projection is performed.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt
new file mode 100644
index 00000000000..aa51414ba23
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt
@@ -0,0 +1,36 @@
+op {
+  graph_op_name: "CudnnRNNParamsToCanonicalV2"
+  summary: "Retrieves CudnnRNN params in canonical form. It supports the projection in LSTM."
+  description: <<END
+Retrieves a set of weights from the opaque params buffer that can be saved and
+restored in a way compatible with future runs.
+
+Note that the params buffer may not be compatible across different GPUs. So any
+save and restoration should be converted to and from the canonical weights and
+biases.
+
+num_layers: Specifies the number of layers in the RNN model.
+num_units: Specifies the size of the hidden state.
+input_size: Specifies the size of the input state.
+num_params_weights: the number of weight parameter matrices for all layers.
+num_params_biases: the number of bias parameter vectors for all layers.
+weights: the canonical form of weights that can be used for saving
+    and restoration. They are more likely to be compatible across different
+    generations.
+biases: the canonical form of biases that can be used for saving
+    and restoration. They are more likely to be compatible across different
+    generations.
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input and
+    the actual computation before the first layer. 'skip_input' is only allowed
+    when input_size == num_units; 'auto_select' implies 'skip_input' when
+    input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used.
+    dir = (direction == bidirectional) ? 2 : 1
+dropout: dropout probability. When set to 0., dropout is disabled.
+seed: the 1st part of a seed to initialize dropout.
+seed2: the 2nd part of a seed to initialize dropout.
+num_proj: The output dimensionality for the projection matrices. If None or 0,
+    no projection is performed.
+END
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt
new file mode 100644
index 00000000000..d953e7ccbae
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNCanonicalToParamsV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "CudnnRNNCanonicalToParamsV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt
new file mode 100644
index 00000000000..a3ca3dfbca3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_CudnnRNNParamsToCanonicalV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "CudnnRNNParamsToCanonicalV2"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index ed6d81d3a0a..09826f57ce5 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -502,20 +502,23 @@ struct CudnnRnnModelShapes {
   int dir_count;
   int max_seq_length;
   int batch_size;
+  int cell_num_units = 0;
   TensorShape input_shape;
   TensorShape output_shape;
   TensorShape hidden_state_shape;
+  TensorShape cell_state_shape;
   // At present only fields related to cached RnnDescriptor are concerned.
   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
-           num_units == rhs.num_units && dir_count == rhs.dir_count;
+           num_units == rhs.num_units && dir_count == rhs.dir_count &&
+           cell_num_units == rhs.cell_num_units;
   }
   string DebugString() const {
     return strings::Printf(
         "[num_layers, input_size, num_units, dir_count, max_seq_length, "
-        "batch_size]: [%d, %d, %d, %d, %d, %d] ",
+        "batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
         num_layers, input_size, num_units, dir_count, max_seq_length,
-        batch_size);
+        batch_size, cell_num_units);
   }
 };
 
@@ -562,6 +565,7 @@ Status ExtractForwardInput(OpKernelContext* context,
                            const CudnnModelTypes& model_types, bool time_major,
                            const Tensor** input, const Tensor** input_h,
                            const Tensor** input_c, const Tensor** params,
+                           const int num_proj,
                            CudnnRnnModelShapes* model_shapes) {
   TF_RETURN_IF_ERROR(context->input("input", input));
   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
@@ -615,12 +619,48 @@ Status ExtractForwardInput(OpKernelContext* context,
         model_shapes->hidden_state_shape.DebugString());
   }
   if (model_types.HasInputC()) {
-    if ((*input_h)->shape() != (*input_c)->shape()) {
-      return errors::InvalidArgument(
-          "input_h and input_c must have the same shape: ",
-          (*input_h)->shape().DebugString(), " ",
-          (*input_c)->shape().DebugString());
+    model_shapes->cell_num_units = (*input_c)->dim_size(2);
+    if (time_major) {
+      model_shapes->cell_state_shape =
+          TensorShape({model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->batch_size, model_shapes->cell_num_units});
+    } else {
+      model_shapes->cell_state_shape =
+          TensorShape({model_shapes->batch_size,
+                       model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->cell_num_units});
     }
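+    // Without projection, input_h and input_c must match exactly. With
+    // projection, input_h carries the projected output (last dim == num_proj)
+    // while input_c keeps the full cell size, so only the first two dims must
+    // match and num_proj may be at most the cell size.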
+    if (num_proj == 0) {
+      if ((*input_h)->shape() != (*input_c)->shape()) {
+        return errors::InvalidArgument(
+            "input_h and input_c must have the same shape w/o projection: ",
+            (*input_h)->shape().DebugString(), " ",
+            (*input_c)->shape().DebugString());
+      }
+    } else {
+      if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
+          num_proj != (*input_h)->dim_size(2) ||
+          (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
+          (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
+        return errors::InvalidArgument(
+            "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
+            (*input_h)->shape().DebugString(), " ",
+            (*input_c)->shape().DebugString());
+      }
+    }
+  } else {
+    // Create a dummy cell_state_shape. TODO(kaixih): remove the time_major branch.
+    if (time_major) {
+      model_shapes->cell_state_shape =
+          TensorShape({model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->batch_size, model_shapes->num_units});
+    } else {
+      model_shapes->cell_state_shape =
+          TensorShape({model_shapes->batch_size,
+                       model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->num_units});
+    }
+    model_shapes->cell_num_units = 0;
   }
   if (time_major) {
     model_shapes->output_shape =
@@ -639,18 +679,19 @@ Status ExtractForwardInput(OpKernelContext* context,
                            const CudnnModelTypes& model_types, bool time_major,
                            const Tensor** input, const Tensor** input_h,
                            const Tensor** input_c, const Tensor** params,
-                           const Tensor** sequence_lengths,
+                           const Tensor** sequence_lengths, const int num_proj,
                            CudnnRnnModelShapes* model_shapes) {
   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
   return ExtractForwardInput(context, model_types, time_major, input, input_h,
-                             input_c, params, model_shapes);
+                             input_c, params, num_proj, model_shapes);
 }
 
 template <typename T>
 Status CreateForwardAndBackwardIODescriptors(
     OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
-    std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
+    std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
+    std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
     const absl::Span<const int>& seq_lengths, bool time_major) {
   StreamExecutor* executor = context->op_device_context()->stream()->parent();
@@ -658,6 +699,7 @@ Status CreateForwardAndBackwardIODescriptors(
 
   const TensorShape& input_shape = model_shapes.input_shape;
   const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+  const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
   const TensorShape& output_shape = model_shapes.output_shape;
 
   DCHECK_EQ(input_shape.dims(), 3);
@@ -689,13 +731,28 @@ Status CreateForwardAndBackwardIODescriptors(
         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
         hidden_state_shape.dim_size(2), data_type);
     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
-    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+    *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
   } else {
     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
         hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
         hidden_state_shape.dim_size(2), data_type);
     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
-    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+    *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+  }
+
+  DCHECK_EQ(cell_state_shape.dims(), 3);
+  if (time_major) {
+    auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
+        cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
+        cell_state_shape.dim_size(2), data_type);
+    TF_RETURN_IF_ERROR(cell_state_desc_s.status());
+    *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
+  } else {
+    auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
+        cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
+        cell_state_shape.dim_size(2), data_type);
+    TF_RETURN_IF_ERROR(cell_state_desc_s.status());
+    *c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
   }
 
   DCHECK_EQ(output_shape.dims(), 3);
@@ -739,7 +796,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
                  ScratchAllocator* workspace_allocator,
                  ProfileResult* output_profile_result) {
   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
-  std::unique_ptr<RnnStateTensorDescriptor> state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
 
   absl::Span<const int> seq_lengths;
@@ -748,8 +806,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
   }
   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
-      context, model_shapes, &input_desc, &state_desc, &output_desc,
-      seq_lengths, time_major));
+      context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
+      &output_desc, seq_lengths, time_major));
 
   auto input_data = AsDeviceMemory<T>(input);
   auto input_h_data = AsDeviceMemory<T>(input_h);
@@ -769,11 +827,11 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
   Stream* stream = context->op_device_context()->stream();
   bool launch_success =
       stream
-          ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
-                           input_h_data, *state_desc, input_c_data, params_data,
-                           *output_desc, &output_data, *state_desc,
-                           &output_h_data, *state_desc, &output_c_data,
-                           is_training, reserve_space_allocator,
+          ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc,
+                           input_h_data, *c_state_desc, input_c_data,
+                           params_data, *output_desc, &output_data,
+                           *h_state_desc, &output_h_data, *c_state_desc,
+                           &output_c_data, is_training, reserve_space_allocator,
                            workspace_allocator, output_profile_result)
           .ok();
   return launch_success
@@ -801,7 +859,8 @@ Status DoBackward(
     ScratchAllocator* workspace_allocator,
     ProfileResult* output_profile_result) {
   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
-  std::unique_ptr<RnnStateTensorDescriptor> state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
 
   absl::Span<const int> seq_lengths;
@@ -810,8 +869,8 @@ Status DoBackward(
         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
   }
   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
-      context, model_shapes, &input_desc, &state_desc, &output_desc,
-      seq_lengths, time_major));
+      context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
+      &output_desc, seq_lengths, time_major));
 
   auto input_data = AsDeviceMemory<T>(input);
   auto input_h_data = AsDeviceMemory<T>(input_h);
@@ -847,15 +906,15 @@ Status DoBackward(
   Stream* stream = context->op_device_context()->stream();
   bool launch_success =
       stream
-          ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc,
-                            input_h_data, *state_desc, input_c_data,
-                            params_data, *output_desc, output_data, *state_desc,
-                            output_h_data, *state_desc, output_c_data,
-                            output_backprop_data, output_h_backprop_data,
-                            output_c_backprop_data, &input_backprop_data,
-                            &input_h_backprop_data, &input_c_backprop_data,
-                            &params_backprop_data, &reserve_space_uint8,
-                            workspace_allocator, output_profile_result)
+          ->ThenRnnBackward(
+              rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data,
+              *c_state_desc, input_c_data, params_data, *output_desc,
+              output_data, *h_state_desc, output_h_data, *c_state_desc,
+              output_c_data, output_backprop_data, output_h_backprop_data,
+              output_c_backprop_data, &input_backprop_data,
+              &input_h_backprop_data, &input_c_backprop_data,
+              &params_backprop_data, &reserve_space_uint8, workspace_allocator,
+              output_profile_result)
           .ok();
   return launch_success
              ? Status::OK()
@@ -932,7 +991,7 @@ class CudnnRNNKernelCommon : public OpKernel {
   bool ResetRndGenState() { return reset_rnd_gen_state_; }
 
   template <typename T>
-  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
+  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
     const Tensor* num_layers_t = nullptr;
     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
@@ -953,6 +1012,9 @@ class CudnnRNNKernelCommon : public OpKernel {
     }
     int input_size = input_size_t->scalar<int>()();
 
+    int h_num_units = (num_proj == 0 ? num_units : num_proj);
+    int c_num_units = (num_proj == 0 ? 0 : num_units);
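+    // With projection, num_proj (h_num_units) is passed below as the
+    // descriptor's hidden size and num_units (c_num_units) as its cell size;
+    // CudnnRnnDescriptor::Create maps these onto cudnnSetRNNDescriptor and
+    // cudnnSetRNNProjectionLayers.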
+
     RnnInputMode input_mode;
     TF_RETURN_IF_ERROR(
         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
@@ -962,9 +1024,10 @@ class CudnnRNNKernelCommon : public OpKernel {
     // random number generator, therefore set state_allocator to nullptr.
     const AlgorithmConfig algo_config;
     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
-        num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
-        rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
-        dropout(), seed(), /* state_allocator=*/nullptr);
+        num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
+        /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
+        ToDataType<T>::value, algo_config, dropout(), seed(),
+        /* state_allocator=*/nullptr);
     if (!rnn_desc_s.ok()) {
       return FromExecutorStatus(rnn_desc_s);
     }
@@ -983,9 +1046,9 @@ class CudnnRNNKernelCommon : public OpKernel {
     se::dnn::DataType data_type = ToDataType<T>::value;
     auto rnn_desc_s = executor->createRnnDescriptor(
         model_shapes.num_layers, model_shapes.num_units,
-        model_shapes.input_size, model_shapes.batch_size, input_mode,
-        rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
-        seed(), dropout_state_allocator);
+        model_shapes.input_size, model_shapes.cell_num_units,
+        model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
+        data_type, algo_config, dropout(), seed(), dropout_state_allocator);
     TF_RETURN_IF_ERROR(rnn_desc_s.status());
 
     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
@@ -1035,11 +1098,18 @@ template <typename T, typename Index>
 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
  public:
   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
-      : CudnnRNNKernelCommon(context) {}
+      : CudnnRNNKernelCommon(context) {
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     std::unique_ptr<RnnDescriptor> rnn_desc;
-    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+    OP_REQUIRES_OK(context,
+                   ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
     CHECK(params_size_in_bytes % sizeof(T) == 0)
         << "params_size_in_bytes must be multiple of element size";
@@ -1049,6 +1119,9 @@ class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
     *output_t->template flat<Index>().data() = params_size;
   }
+
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T)                                    \
@@ -1074,7 +1147,32 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
  public:
   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
       : CudnnRNNKernelCommon(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
+    if (context->HasAttr("num_params")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
+    } else {
+      num_params_ = 0;
+    }
+    if (context->HasAttr("num_params_weights")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
+                                               &num_params_weights_));
+    } else {
+      num_params_weights_ = 0;
+    }
+    if (context->HasAttr("num_params_biases")) {
+      OP_REQUIRES_OK(
+          context, context->GetAttr("num_params_biases", &num_params_biases_));
+    } else {
+      num_params_biases_ = 0;
+    }
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
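+    // The original op only carries num_params, since weight and bias counts
+    // match when there is no projection. The V2 op splits them because
+    // projection adds one weight matrix per layer/direction without a matching
+    // bias (e.g. 9 weight regions vs. 8 bias regions for a single-layer,
+    // unidirectional LSTM with projection).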
+    if (num_proj_ == 0) {
+      num_params_weights_ = num_params_;
+      num_params_biases_ = num_params_;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -1083,7 +1181,8 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
     Stream* stream = context->op_device_context()->stream();
 
     std::unique_ptr<RnnDescriptor> rnn_desc;
-    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+    OP_REQUIRES_OK(context,
+                   ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
     CHECK(params_size_in_bytes % sizeof(T) == 0)
         << "params_size_in_bytes must be multiple of element size";
@@ -1109,25 +1208,46 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
       num_dirs = 2;
     }
-    const int num_params_per_layer = num_params_ / num_layers / num_dirs;
+    const int num_params_weights_per_layer =
+        num_params_weights_ / num_layers / num_dirs;
     // Number of params applied on inputs. The rest are applied on recurrent
     // hidden states.
-    const int num_params_input_state = num_params_per_layer / 2;
-    CHECK(num_params_ % (num_layers * num_dirs) == 0)
-        << "Number of params is not a multiple of num_layers * num_dirs.";
-    CHECK(num_params_per_layer % 2 == 0)
-        << "Number of params per layer is not a even number.";
+    const int num_params_input_state = num_params_weights_per_layer / 2;
+    OP_REQUIRES(
+        context, num_params_weights_ % (num_layers * num_dirs) == 0,
+        errors::InvalidArgument("Number of params (weights) is not a multiple"
+                                "of num_layers * num_dirs."));
+    OP_REQUIRES(
+        context, num_params_biases_ % (num_layers * num_dirs) == 0,
+        errors::InvalidArgument("Number of params (biases) is not a multiple"
+                                "of num_layers * num_dirs."));
+    if (num_proj_ == 0) {
+      OP_REQUIRES(
+          context, num_params_weights_per_layer % 2 == 0,
+          errors::InvalidArgument("Number of params (weights) per layer is not"
+                                  "an even number with no projection."));
+    } else {
+      OP_REQUIRES(
+          context, num_params_weights_per_layer % 2 != 0,
+          errors::InvalidArgument("Number of params (weights) per layer is not"
+                                  "an odl number with projection."));
+    }
 
-    CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
-        << "Number of params mismatch. Expected " << num_params_ << ", got "
-        << rnn_desc->ParamsWeightRegions().size();
+    OP_REQUIRES(
+        context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
+        errors::InvalidArgument("C Number of params mismatch. Expected ",
+                                num_params_weights_, ", got ",
+                                rnn_desc->ParamsWeightRegions().size()));
+    int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
+    int c_num_units = (num_proj_ == 0 ? 0 : num_units);
     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
       int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
       int64 size = size_in_bytes / sizeof(T);
-      const int layer_idx = i / num_params_per_layer;
-      const int index_within_layer = i % num_params_per_layer;
-      int width = 0, height = num_units;
-      // In CuDNN layout, each layer has num_params_per_layer params, with the
+      const int layer_idx = i / num_params_weights_per_layer;
+      const int index_within_layer = i % num_params_weights_per_layer;
+      int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
+      // In CuDNN layout, each layer has num_params_weights_per_layer params,
+      // with the
       // first half a.k.a num_params_input_state params applied on the inputs,
       // and the second half on the recurrent hidden states.
       bool apply_on_input_state = index_within_layer < num_params_input_state;
@@ -1135,7 +1255,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
         if (layer_idx == 0 && apply_on_input_state) {
           width = input_size;
         } else {
-          width = num_units;
+          width = h_num_units;
         }
       } else {
         if (apply_on_input_state) {
@@ -1145,15 +1265,19 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
           } else {
             // Following layers, cell inputs are concatenated outputs of
             // its prior layer.
-            width = 2 * num_units;
+            width = 2 * h_num_units;
           }
         } else {
-          width = num_units;
+          width = h_num_units;
         }
       }
       CHECK(size == width * height) << "Params size mismatch. Expected "
                                     << width * height << ", got " << size;
       Tensor* output = nullptr;
+      int id_in_layer = i % num_params_weights_per_layer;
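+      // When projection is enabled, the last weight region of each layer is
+      // the projection matrix; emit it as [num_proj, num_units] instead of the
+      // [num_units, num_proj] orientation computed above.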
+      if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
+        std::swap(height, width);
+      }
       OP_REQUIRES_OK(context, context->allocate_output(
                                   i, TensorShape({height, width}), &output));
       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
@@ -1162,10 +1286,11 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
       stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
     }
 
-    OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
-                errors::InvalidArgument("Number of params mismatch. Expected ",
-                                        num_params_, ", got ",
-                                        rnn_desc->ParamsBiasRegions().size()));
+    OP_REQUIRES(
+        context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
+        errors::InvalidArgument("A Number of params mismatch. Expected ",
+                                num_params_biases_, ", got ",
+                                rnn_desc->ParamsBiasRegions().size()));
     for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
       int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
       int64 size = size_in_bytes / sizeof(T);
@@ -1175,7 +1300,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
 
       Tensor* output = nullptr;
       OP_REQUIRES_OK(context,
-                     context->allocate_output(num_params_ + i,
+                     context->allocate_output(num_params_weights_ + i,
                                               TensorShape({size}), &output));
       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
@@ -1186,6 +1311,9 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
 
  private:
   int num_params_;
+  int num_params_weights_;
+  int num_params_biases_;
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T)                                     \
@@ -1201,17 +1329,37 @@ TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
 #undef REGISTER_GPU
 
+#define REGISTER_GPU(T)                                       \
+  REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
+                              .Device(DEVICE_GPU)             \
+                              .HostMemory("num_layers")       \
+                              .HostMemory("num_units")        \
+                              .HostMemory("input_size")       \
+                              .TypeConstraint<T>("T"),        \
+                          CudnnRNNParamsToCanonical<GPUDevice, T>);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
 // Convert weight and bias params from the canonical form to a
 // platform-specific layout.
 template <typename T>
 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
  public:
   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
-      : CudnnRNNKernelCommon(context) {}
+      : CudnnRNNKernelCommon(context) {
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     std::unique_ptr<RnnDescriptor> rnn_desc;
-    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+    OP_REQUIRES_OK(context,
+                   ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
     CHECK(params_size_in_bytes % sizeof(T) == 0)
         << "params_size_in_bytes must be multiple of element size";
@@ -1232,6 +1380,9 @@ class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
                      stream);
   }
+
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T)                                     \
@@ -1247,6 +1398,19 @@ TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
 #undef REGISTER_GPU
 
+#define REGISTER_GPU(T)                                       \
+  REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
+                              .Device(DEVICE_GPU)             \
+                              .HostMemory("num_layers")       \
+                              .HostMemory("num_units")        \
+                              .HostMemory("input_size")       \
+                              .TypeConstraint<T>("T"),        \
+                          CudnnRNNCanonicalToParams<GPUDevice, T>);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
 // Run the forward operation of the RNN model.
 template <typename T>
 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
@@ -1264,14 +1428,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig algo_config;
     ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
-                              /*time_major=*/true);
+                              /*time_major=*/true, /*num_proj=*/0);
   }
 
  protected:
   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
                                          AlgorithmConfig* output_algo_config,
-                                         bool var_seq_lengths,
-                                         bool time_major) {
+                                         bool var_seq_lengths, bool time_major,
+                                         int num_proj) {
     CHECK_NE(output_algo_config, nullptr);
 
     const Tensor* input = nullptr;
@@ -1281,14 +1445,15 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     const Tensor* sequence_lengths = nullptr;
     CudnnRnnModelShapes model_shapes;
     if (var_seq_lengths) {
+      OP_REQUIRES_OK(context, ExtractForwardInput(
+                                  context, model_types(), time_major, &input,
+                                  &input_h, &input_c, &params,
+                                  &sequence_lengths, num_proj, &model_shapes));
+    } else {
       OP_REQUIRES_OK(context,
                      ExtractForwardInput(context, model_types(), time_major,
                                          &input, &input_h, &input_c, &params,
-                                         &sequence_lengths, &model_shapes));
-    } else {
-      OP_REQUIRES_OK(context, ExtractForwardInput(
-                                  context, model_types(), time_major, &input,
-                                  &input_h, &input_c, &params, &model_shapes));
+                                         num_proj, &model_shapes));
     }
     RnnInputMode input_mode;
     OP_REQUIRES_OK(context,
@@ -1362,13 +1527,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
                          Tensor** output_c) {
     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
     const TensorShape& output_shape = model_shapes.output_shape;
+    const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
 
     TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
     TF_RETURN_IF_ERROR(
         context->allocate_output(1, hidden_state_shape, output_h));
     if (HasInputC()) {
       TF_RETURN_IF_ERROR(
-          context->allocate_output(2, hidden_state_shape, output_c));
+          context->allocate_output(2, cell_state_shape, output_c));
     } else {
       // Only LSTM uses input_c and output_c. So for all other models, we only
       // need to create dummy outputs.
@@ -1414,7 +1580,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
     AlgorithmConfig best_algo_config;
     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
         context, &best_algo_config, /*var_seq_lengths=*/false,
-        /*time_major=*/true);
+        /*time_major=*/true, /*num_proj=*/0);
     if (!context->status().ok()) {
       return;
     }
@@ -1613,13 +1779,18 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
       : CudnnRNNForwardOp<GPUDevice, T>(context) {
     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig best_algo_config;
     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
         context, &best_algo_config, /*var_seq_lengths=*/true,
-        /*time_major=*/time_major());
+        /*time_major=*/time_major(), num_proj_);
     if (!context->status().ok()) {
       return;
     }
@@ -1631,6 +1802,9 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
     OP_REQUIRES_OK(context,
                    context->allocate_output(4, {}, &output_host_reserved));
   }
+
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T)                                       \
@@ -1654,12 +1828,12 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
       : CudnnRNNKernelCommon(context) {}
 
   void Compute(OpKernelContext* context) override {
-    ComputeImpl(context, false, true);
+    ComputeImpl(context, false, true, 0);
   }
 
  protected:
   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
-                           bool time_major) {
+                           bool time_major, int num_proj) {
     const Tensor* input = nullptr;
     const Tensor* input_h = nullptr;
     const Tensor* input_c = nullptr;
@@ -1667,14 +1841,15 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     const Tensor* sequence_lengths = nullptr;
     CudnnRnnModelShapes model_shapes;
     if (var_seq_lengths) {
+      OP_REQUIRES_OK(context, ExtractForwardInput(
+                                  context, model_types(), time_major, &input,
+                                  &input_h, &input_c, &params,
+                                  &sequence_lengths, num_proj, &model_shapes));
+    } else {
       OP_REQUIRES_OK(context,
                      ExtractForwardInput(context, model_types(), time_major,
                                          &input, &input_h, &input_c, &params,
-                                         &sequence_lengths, &model_shapes));
-    } else {
-      OP_REQUIRES_OK(context, ExtractForwardInput(
-                                  context, model_types(), time_major, &input,
-                                  &input_h, &input_c, &params, &model_shapes));
+                                         num_proj, &model_shapes));
     }
     RnnInputMode input_mode;
     OP_REQUIRES_OK(context,
@@ -1757,6 +1932,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
     const TensorShape& output_shape = model_shapes.output_shape;
+    const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
 
     if (output_shape != (*output)->shape()) {
       return errors::InvalidArgument(
@@ -1782,16 +1958,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     }
 
     if (model_types.HasInputC()) {
-      if (hidden_state_shape != (*output_c)->shape()) {
+      if (cell_state_shape != (*output_c)->shape()) {
         return errors::InvalidArgument(
             "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
-            hidden_state_shape.DebugString());
+            cell_state_shape.DebugString());
       }
-      if (hidden_state_shape != (*output_c_backprop)->shape()) {
+      if (cell_state_shape != (*output_c_backprop)->shape()) {
         return errors::InvalidArgument(
             "Invalid output_c_backprop shape: ",
             (*output_c_backprop)->shape().DebugString(), " ",
-            hidden_state_shape.DebugString());
+            cell_state_shape.DebugString());
       }
     }
     return Status::OK();
@@ -1804,6 +1980,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
                          Tensor** input_c_backprop, Tensor** params_backprop) {
     const TensorShape& input_shape = model_shapes.input_shape;
     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+    const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
 
     TF_RETURN_IF_ERROR(
         context->allocate_output(0, input_shape, input_backprop));
@@ -1811,7 +1988,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
         context->allocate_output(1, hidden_state_shape, input_h_backprop));
     if (HasInputC()) {
       TF_RETURN_IF_ERROR(
-          context->allocate_output(2, hidden_state_shape, input_c_backprop));
+          context->allocate_output(2, cell_state_shape, input_c_backprop));
     } else {
       // Only LSTM uses input_c and output_c. So for all other models, we only
       // need to create dummy outputs.
@@ -1879,11 +2056,20 @@ class CudnnRNNBackwardOpV3<GPUDevice, T>
   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
       : CudnnRNNBackwardOp<GPUDevice, T>(context) {
     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
-    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
+    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
+                                                  num_proj_);
   }
+
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T)                                       \
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index 9b22ccdeeec..1dd7659e137 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -49,6 +49,7 @@ REGISTER_OP("CudnnRNNParamsSize")
     .Attr("dropout: float = 0.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
     .Output("params_size: S")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle unused;
@@ -166,11 +167,13 @@ REGISTER_OP("CudnnRNNV3")
     .Attr("dropout: float = 0.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
     .Attr("is_training: bool = true")
     .Attr("time_major: bool = true")
     .SetShapeFn([](InferenceContext* c) {
       auto input_shape = c->input(0);
       auto input_h_shape = c->input(1);
+      auto input_c_shape = c->input(2);
       auto max_seq_length = c->Dim(input_shape, 0);
       auto batch_size = c->Dim(input_shape, 1);
       auto num_units = c->Dim(input_h_shape, 2);
@@ -185,7 +188,7 @@ REGISTER_OP("CudnnRNNV3")
           c->MakeShape({max_seq_length, batch_size, output_size});
       auto output_h_shape = input_h_shape;
       auto output_c_shape TF_ATTRIBUTE_UNUSED =
-          (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
+          (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({});
       c->set_output(0, output_shape);
       c->set_output(1, output_h_shape);
       c->set_output(2, output_c_shape);
@@ -293,6 +296,7 @@ REGISTER_OP("CudnnRNNBackpropV3")
     .Attr("dropout: float = 0.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
     .Attr("time_major: bool = true")
     .SetShapeFn([](InferenceContext* c) {
       auto input_shape = c->input(0);
@@ -338,6 +342,43 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
       return Status::OK();
     });
 
+REGISTER_OP("CudnnRNNParamsToCanonicalV2")
+    .Input("num_layers: int32")
+    .Input("num_units: int32")
+    .Input("input_size: int32")
+    .Input("params: T")
+    .Output("weights: num_params_weights * T")
+    .Output("biases: num_params_biases * T")
+    .Attr("T: {float16, float32, float64}")
+    .Attr("num_params_weights: int")
+    .Attr("num_params_biases: int")
+    .Attr(kRNNModeAttrs)
+    .Attr(kRNNInputModeAttrs)
+    .Attr(kRNNDirectionAttrs)
+    .Attr("dropout: float = 0.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
+      int num_params_weights;
+      int num_params_biases;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights));
+      TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases));
+      // Set shape for weight matrices
+      for (int i = 0; i < num_params_weights; i++) {
+        c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
+                                   InferenceContext::kUnknownDim));
+      }
+      // Set shape for bias vectors
+      for (int i = 0; i < num_params_biases; i++) {
+        c->set_output(num_params_weights + i,
+                      c->Vector(InferenceContext::kUnknownDim));
+      }
+      return Status::OK();
+    });
+
 REGISTER_OP("CudnnRNNCanonicalToParams")
     .Input("num_layers: int32")
     .Input("num_units: int32")
@@ -358,4 +399,26 @@ REGISTER_OP("CudnnRNNCanonicalToParams")
       return Status::OK();
     });
 
+REGISTER_OP("CudnnRNNCanonicalToParamsV2")
+    .Input("num_layers: int32")
+    .Input("num_units: int32")
+    .Input("input_size: int32")
+    .Input("weights: num_params_weights * T")
+    .Input("biases: num_params_biases * T")
+    .Output("params: T")
+    .Attr("T: {float16, float32, float64}")
+    .Attr("num_params_weights: int")
+    .Attr("num_params_biases: int")
+    .Attr(kRNNModeAttrs)
+    .Attr(kRNNInputModeAttrs)
+    .Attr(kRNNDirectionAttrs)
+    .Attr("dropout: float = 0.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
+      return Status::OK();
+    });
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 788a8aceaa8..8e8c8193a14 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -111,6 +111,8 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) {
   std::vector<int> input_shape = {max_seq_length, batch_size, num_units};
   std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
                                     num_units};
+  std::vector<int> input_c_shape = {num_layers * dir_count, batch_size,
+                                    num_units};
   std::vector<int> output_shape = {max_seq_length, batch_size,
                                    num_units * dir_count};
   std::vector<int> seq_lengths_shape = {batch_size};
@@ -119,9 +121,9 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) {
   };
   string input_shapes_desc = strings::StrCat(
       shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
-      shape_to_str(input_h_shape), ";", "[?]", ";",
+      shape_to_str(input_c_shape), ";", "[?]", ";",
       shape_to_str(seq_lengths_shape));
-  string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?;?";
+  string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in2;?;?";
 
   ShapeInferenceTestOp op("CudnnRNNV3");
   TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNNV3")
diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py
index 9ce906121f2..3c93e0b8ec9 100644
--- a/tensorflow/python/ops/cudnn_rnn_grad.py
+++ b/tensorflow/python/ops/cudnn_rnn_grad.py
@@ -98,6 +98,7 @@ def _cudnn_rnn_backwardv3(op, *grads):
       seed=op.get_attr("seed"),
       seed2=op.get_attr("seed2"),
       time_major=op.get_attr("time_major"),
+      num_proj=op.get_attr("num_proj"),
       rnn_mode=op.get_attr("rnn_mode"),
       input_mode=op.get_attr("input_mode"),
       direction=op.get_attr("direction")) + (None,)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 8331be22893..e0a6dce5115 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1002,8 +1002,8 @@ class CudnnRnnParamsDescriptor {
 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
                      PersistentRnnPlan rnn_plan, int num_layers,
-                     int hidden_size, int input_size, int batch_size,
-                     cudnnRNNInputMode_t input_mode,
+                     int hidden_size, int input_size, int cell_size,
+                     int batch_size, cudnnRNNInputMode_t input_mode,
                      cudnnDirectionMode_t direction_mode,
                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
                      cudnnDataType_t compute_type,
@@ -1015,6 +1015,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
         num_layers_(num_layers),
         hidden_size_(hidden_size),
         input_size_(input_size),
+        cell_size_(cell_size),
         batch_size_(batch_size),
         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
         input_mode_(input_mode),
@@ -1031,7 +1032,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
 
   static port::StatusOr<CudnnRnnDescriptor> Create(
       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
-      int batch_size, cudnnRNNInputMode_t input_mode,
+      int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
       cudnnDataType_t data_type, cudnnDataType_t compute_type,
       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
@@ -1044,12 +1045,28 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
     cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
 
     // TODO: allow the user to choose an algorithm.
+    int unified_size = hidden_size;
+    bool use_projection = cell_size != 0 && hidden_size < cell_size;
+    if (use_projection) {
+      unified_size = cell_size;
+    }
     RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
-        cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
-        /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
-        /*inputMode=*/input_mode, /*direction=*/direction_mode,
-        /*mode=*/rnn_mode, /*algo=*/rnn_algo,
+        cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
+        /*hiddenSize=*/unified_size, /*numLayers=*/num_layers,
+        /*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
+        /*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
         /*dataType=*/compute_type));
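+    // When projecting, the hiddenSize set above is the (larger) cell size;
+    // cudnnSetRNNProjectionLayers then restricts the recurrent output to
+    // hidden_size, and outProjSize is left at 0 (no extra output projection).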
+    if (use_projection) {
+#if CUDNN_VERSION >= 7101
+      RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
+          cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
+          /*recProjSize=*/hidden_size, /*outProjSize=*/0));
+#else
+      return port::Status(port::error::INVALID_ARGUMENT,
+                          "No supported cudnnSetRNNProjectionLayers when "
+                          "CUDNN_VERSION < 7.1.1");
+#endif
+    }
 
     // TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
     // But in the future if these APIs are used to process full length arrays,
@@ -1106,9 +1123,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
 #endif  // CUDNN_VERSION >= 7000
 
     return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
-                              num_layers, hidden_size, input_size, batch_size,
-                              input_mode, direction_mode, rnn_mode, data_type,
-                              compute_type, algorithm_config,
+                              num_layers, hidden_size, input_size, cell_size,
+                              batch_size, input_mode, direction_mode, rnn_mode,
+                              data_type, compute_type, algorithm_config,
                               std::move(dropout_desc), std::move(params_desc));
   }
 
@@ -1116,6 +1133,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   int num_layers() const { return num_layers_; }
   int hidden_size() const { return hidden_size_; }
   int input_size() const { return input_size_; }
+  int cell_size() const { return cell_size_; }
   int batch_size() const { return batch_size_; }
   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
@@ -1144,6 +1162,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   int num_layers_;
   int hidden_size_;
   int input_size_;
+  // cell_size_ is the size of the cell state, which differs from hidden_size_
+  // when projection is used.
+  int cell_size_;
   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
   // algorithm.
   int batch_size_;
@@ -1161,6 +1182,69 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
 
 namespace {
 
+// Checks whether the LSTM projection is used. If so, the additional weight
+// matrix (the projection matrix) is appended to 'weights'. Otherwise, nothing
+// is done.
+port::Status CheckAndFetchProjectionWeights(
+    const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
+    const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
+    const FilterDescriptor& region_desc_handle,
+    dnn::RnnDescriptor::ParamsRegions* weights) {
+#if CUDNN_VERSION >= 7101
+  int hidden_size_v;
+  int num_layers_v;
+  cudnnDropoutDescriptor_t dropout_desc;
+  cudnnRNNInputMode_t input_mode;
+  cudnnDirectionMode_t direction;
+  cudnnRNNMode_t mode;
+  cudnnRNNAlgo_t algo;
+  cudnnDataType_t data_type;
+  RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
+      /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+      /*hiddenSize=*/&hidden_size_v,
+      /*numLayers=*/&num_layers_v,
+      /*dropoutDesc=*/&dropout_desc,
+      /*inputMode=*/&input_mode,
+      /*direction=*/&direction,
+      /*mode=*/&mode,
+      /*algo=*/&algo,
+      /*dataType=*/&data_type));
+  int rec_proj_size_v;
+  int out_proj_size_v;
+  RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
+      /*handle=*/cudnn.handle(),
+      /*rnnDesc=*/rnn_desc,
+      /*recProjSize*/ &rec_proj_size_v,
+      /*outProjSize*/ &out_proj_size_v));
+  if (rec_proj_size_v != hidden_size_v) {
+    void* offset = nullptr;
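+    // For an LSTM with recurrent projection, linLayerIDs 0-7 address the gate
+    // weight matrices and linLayerID 8 addresses the projection matrix.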
+    int region_id = 8;
+    RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
+        /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+        /*layer=*/layer, /*xDesc=*/input_desc.get(),
+        /*wDesc=*/filter_desc.get(),
+        /*w=*/nullptr, /*linLayerID=*/region_id,
+        /*linLayerMatDesc=*/region_desc_handle.get(),
+        /*linLayerMat or linLayerBias=*/&offset));
+    int dims[] = {1, 1, 1};
+    cudnnDataType_t data_type;
+    cudnnTensorFormat_t tensor_format;
+    int n_dims;
+    RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
+        /*filterDesc=*/region_desc_handle.get(),
+        /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
+        /*dataType=*/&data_type, /*format=*/&tensor_format,
+        /*nbDims=*/&n_dims, /*filterDimA=*/dims));
+    int64 size =
+        dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
+    dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
+                                               size};
+    weights->push_back(region);
+  }
+#endif  // CUDNN_VERSION >= 7101
+  return port::Status::OK();
+}
+
 port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
     const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
     cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
@@ -1248,6 +1332,9 @@ port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
         (type == 0 ? weights : biases).push_back(region);
       }
     }
+    TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
+        cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
+        &weights));
   }
 
   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
@@ -1412,6 +1499,7 @@ struct RnnModelDims {
   int max_seq_length = 0;
   int hidden_size = 0;
   int input_size = 0;
+  int cell_size = 0;
   int dir_count = 0;
 };
 
@@ -1437,6 +1525,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
   model_dims.max_seq_length = input_desc.max_seq_length();
   model_dims.hidden_size = rnn_desc.hidden_size();
   model_dims.input_size = input_desc.data_size();
+  model_dims.cell_size = rnn_desc.cell_size();
   model_dims.dir_count =
       (rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
 
@@ -1447,9 +1536,11 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
         input_h_desc.data_size() == model_dims.hidden_size)) {
     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
   }
+  // The LSTM projection will be used if input_h_desc.data_size() <
+  // input_c_desc.data_size().
   if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
         input_h_desc.batch_size() == input_c_desc.batch_size() &&
-        input_h_desc.data_size() == input_c_desc.data_size())) {
+        input_h_desc.data_size() <= input_c_desc.data_size())) {
     return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
   }
   if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
@@ -1466,7 +1557,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
   }
   if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
         input_h_desc.batch_size() == output_c_desc.batch_size() &&
-        input_h_desc.data_size() == output_c_desc.data_size())) {
+        input_h_desc.data_size() <= output_c_desc.data_size())) {
     return port::Status(port::error::INVALID_ARGUMENT,
                         "Invalid output_c shape");
   }
@@ -1872,18 +1963,18 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
 
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 CudnnSupport::createRnnDescriptor(
-    int num_layers, int hidden_size, int input_size, int batch_size,
-    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-    dnn::RnnMode rnn_mode, dnn::DataType data_type,
-    const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
-    ScratchAllocator* state_allocator) {
+    int num_layers, int hidden_size, int input_size, int cell_size,
+    int batch_size, dnn::RnnInputMode input_mode,
+    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
+    dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
+    float dropout, uint64 seed, ScratchAllocator* state_allocator) {
   // Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
   // not enqueueing anything into a stream, we pass in the null stream.
   auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
   SE_ASSIGN_OR_RETURN(
       CudnnRnnDescriptor rnn_desc,
       CudnnRnnDescriptor::Create(
-          cudnn, num_layers, hidden_size, input_size, batch_size,
+          cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
           ToCudnnRnnInputMode(input_mode),
           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
           ToCudnnDataType(data_type), GetRnnComputeType(data_type),
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 8dd66d5e0c5..e3742c07a56 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -47,11 +47,11 @@ class CudnnSupport : public dnn::DnnSupport {
   port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
 
   port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
-      int num_layers, int hidden_size, int input_size, int batch_size,
-      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-      dnn::RnnMode rnn_mode, dnn::DataType data_type,
-      const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
-      ScratchAllocator* state_allocator) override;
+      int num_layers, int hidden_size, int input_size, int cell_size,
+      int batch_size, dnn::RnnInputMode input_mode,
+      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
+      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
+      float dropout, uint64 seed, ScratchAllocator* state_allocator) override;
 
   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 58fd7baa580..f410875e77e 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -2065,6 +2065,7 @@ class DnnSupport {
   //  num_layers: the number of layers for a RNN model.
   //  hidden_size: the size of the hidden state.
   //  input_size: the size of the input state.
+  //  cell_size: the size of the cell state; may differ from hidden_size.
   //  input_mode: an enum to specify whether a linear transformation is added
   //    after the input state. If input_size is different from hidden_size, this
   //    is required.
@@ -2080,7 +2081,8 @@ class DnnSupport {
   //    is no longer in use.
   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
-                      int batch_size, dnn::RnnInputMode input_mode,
+                      int cell_size, int batch_size,
+                      dnn::RnnInputMode input_mode,
                       dnn::RnnDirectionMode direction_mode,
                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
                       const dnn::AlgorithmConfig& algorithm_config,
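To make the hidden_size / cell_size distinction concrete, the sizes correspond to the standard LSTM recurrence with a recurrent projection (a sketch; the symbol names below are not taken from the code): the cell state has dimension cell_size, the projected hidden state fed back into the recurrence has dimension hidden_size, and the two coincide when no projection is used.

```latex
\begin{aligned}
  i_t &= \sigma(W_i x_t + R_i h_{t-1} + b_i), &
  f_t &= \sigma(W_f x_t + R_f h_{t-1} + b_f), \\
  g_t &= \tanh(W_g x_t + R_g h_{t-1} + b_g), &
  o_t &= \sigma(W_o x_t + R_o h_{t-1} + b_o), \\
  c_t &= f_t \odot c_{t-1} + i_t \odot g_t,
  & c_t &\in \mathbb{R}^{\text{cell\_size}}, \\
  h_t &= P\,\bigl(o_t \odot \tanh(c_t)\bigr),
  & P &\in \mathbb{R}^{\text{hidden\_size} \times \text{cell\_size}},\;
      h_t \in \mathbb{R}^{\text{hidden\_size}}.
\end{aligned}
```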
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index b8aca1b56c4..839f1cd20be 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -336,18 +336,18 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
 
 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 StreamExecutor::createRnnDescriptor(
-    int num_layers, int hidden_size, int input_size, int batch_size,
-    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-    dnn::RnnMode rnn_mode, dnn::DataType data_type,
-    const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
-    ScratchAllocator *state_allocator) {
+    int num_layers, int hidden_size, int input_size, int cell_size,
+    int batch_size, dnn::RnnInputMode input_mode,
+    dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
+    dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
+    float dropout, uint64 seed, ScratchAllocator *state_allocator) {
   dnn::DnnSupport *dnn_support = AsDnn();
   if (!dnn_support) {
     return port::Status(port::error::UNKNOWN,
                         "Fail to find the dnn implementation.");
   }
   return dnn_support->createRnnDescriptor(
-      num_layers, hidden_size, input_size, batch_size, input_mode,
+      num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
       direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
       state_allocator);
 }
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 587f36c06de..d2f2f591e2a 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -394,11 +394,11 @@ class StreamExecutor {
   // Create an RNN descriptor based on model shapes and configurations.
   // The caller retains the ownership of the descriptor.
   port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
-      int num_layers, int hidden_size, int input_size, int batch_size,
-      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-      dnn::RnnMode rnn_mode, dnn::DataType data_type,
-      const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
-      ScratchAllocator *state_allocator);
+      int num_layers, int hidden_size, int input_size, int cell_size,
+      int batch_size, dnn::RnnInputMode input_mode,
+      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
+      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
+      float dropout, uint64 seed, ScratchAllocator *state_allocator);
 
   // Create a RNN sequence descriptor that specifies either the input or output
   // sequence. The caller retains the ownership of the returned descriptor.
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 650a42505f4..1d4220ec4b3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -770,27 +770,35 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNBackpropV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNCanonicalToParams"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNCanonicalToParamsV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNParamsSize"
-    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNParamsToCanonical"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNParamsToCanonicalV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNV2"
     argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
   }
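The new num_proj attribute is also visible at the raw-op level, per the argspecs above. A minimal usage sketch with hypothetical sizes (requires a CUDA-enabled TensorFlow build; in graph mode the returned tensor still has to be evaluated):

```python
import tensorflow as tf

# Query the opaque parameter buffer size for an LSTM with a recurrent
# projection, mirroring the updated CudnnRNNParamsSize argspec.
params_size = tf.raw_ops.CudnnRNNParamsSize(
    num_layers=1,
    num_units=512,     # cell size
    input_size=128,
    T=tf.float32,      # dtype of the parameters
    S=tf.int32,        # dtype of the returned size
    rnn_mode="lstm",
    num_proj=256)      # 0 (the default) means no projection
```

The V2 conversion ops follow the same pattern: CudnnRNNCanonicalToParamsV2 and CudnnRNNParamsToCanonicalV2 accept num_proj, and the latter splits the old num_params count into separate num_params_weights and num_params_biases counts so the extra projection weight matrices can be accounted for.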
   member_method {
     name: "Cumprod"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 650a42505f4..1d4220ec4b3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -770,27 +770,35 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNBackpropV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNCanonicalToParams"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNCanonicalToParamsV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNParamsSize"
-    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNParamsToCanonical"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNParamsToCanonicalV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNV2"
     argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
   }
   member_method {
     name: "Cumprod"