diff --git a/tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt
new file mode 100644
index 00000000000..638d1549804
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt
@@ -0,0 +1,171 @@
+op {
+  graph_op_name: "BlockLSTMGradV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "seq_len_max"
+    description: <<END
+Maximum time length actually used by this input. Outputs are padded
+with zeros beyond this length.
+END
+  }
+  in_arg {
+    name: "x"
+    description: <<END
+The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
+END
+  }
+  in_arg {
+    name: "cs_prev"
+    description: <<END
+Value of the initial cell state.
+END
+  }
+  in_arg {
+    name: "h_prev"
+    description: <<END
+Initial output of cell (to be used for peephole).
+END
+  }
+  in_arg {
+    name: "w"
+    description: <<END
+The weight matrix.
+END
+  }
+  in_arg {
+    name: "wci"
+    description: <<END
+The weight matrix for input gate peephole connection.
+END
+  }
+  in_arg {
+    name: "wcf"
+    description: <<END
+The weight matrix for forget gate peephole connection.
+END
+  }
+  in_arg {
+    name: "wco"
+    description: <<END
+The weight matrix for output gate peephole connection.
+END
+  }
+  in_arg {
+    name: "b"
+    description: <<END
+The bias vector.
+END
+  }
+  in_arg {
+    name: "i"
+    description: <<END
+The input gate over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "cs"
+    description: <<END
+The cell state before the tanh over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "f"
+    description: <<END
+The forget gate over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "o"
+    description: <<END
+The output gate over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "ci"
+    description: <<END
+The cell input over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "co"
+    description: <<END
+The cell after the tanh over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "h"
+    description: <<END
+The output h vector over the whole time sequence.
+END
+  }
+  in_arg {
+    name: "cs_grad"
+    description: <<END
+The current gradient of cs.
+END
+  }
+  in_arg {
+    name: "h_grad"
+    description: <<END
+The gradient of h vector.
+END
+  }
+  out_arg {
+    name: "x_grad"
+    description: <<END
+The gradient of x to be back-propped.
+END
+  }
+  out_arg {
+    name: "cs_prev_grad"
+    description: <<END
+The gradient of cs_prev to be back-propped.
+END
+  }
+  out_arg {
+    name: "h_prev_grad"
+    description: <<END
+The gradient of h_prev to be back-propped.
+END
+  }
+  out_arg {
+    name: "w_grad"
+    description: <<END
+The gradient for w to be back-propped.
+END
+  }
+  out_arg {
+    name: "wci_grad"
+    description: <<END
+The gradient for wci to be back-propped.
+END
+  }
+  out_arg {
+    name: "wcf_grad"
+    description: <<END
+The gradient for wcf to be back-propped.
+END
+  }
+  out_arg {
+    name: "wco_grad"
+    description: <<END
+The gradient for wco to be back-propped.
+END
+  }
+  out_arg {
+    name: "b_grad"
+    description: <<END
+The gradient for b to be back-propped.
+END
+  }
+  attr {
+    name: "use_peephole"
+    description: <<END
+Whether to use peephole weights.
+END
+  }
+  summary: "Computes the LSTM cell backward propagation for the entire time sequence."
+  description: <<END
+This implementation is to be used in conjunction with BlockLSTMV2.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt
new file mode 100644
index 00000000000..4da9ebaf863
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt
@@ -0,0 +1,137 @@
+op {
+  graph_op_name: "BlockLSTMV2"
+  visibility: HIDDEN
+  in_arg {
+    name: "seq_len_max"
+    description: <<END
+Maximum time length actually used by this input. Outputs are padded
+with zeros beyond this length.
+END
+  }
+  in_arg {
+    name: "x"
+    description: <<END
+The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
+END
+  }
+  in_arg {
+    name: "cs_prev"
+    description: <<END
+Value of the initial cell state.
+END
+  }
+  in_arg {
+    name: "h_prev"
+    description: <<END
+Initial output of cell (to be used for peephole).
+END
+  }
+  in_arg {
+    name: "w"
+    description: <<END
+The weight matrix.
+END
+  }
+  in_arg {
+    name: "wci"
+    description: <<END
+The weight matrix for input gate peephole connection.
+END
+  }
+  in_arg {
+    name: "wcf"
+    description: <<END
+The weight matrix for forget gate peephole connection.
+END
+  }
+  in_arg {
+    name: "wco"
+    description: <<END
+The weight matrix for output gate peephole connection.
+END
+  }
+  in_arg {
+    name: "b"
+    description: <<END
+The bias vector.
+END
+  }
+  out_arg {
+    name: "i"
+    description: <<END
+The input gate over the whole time sequence.
+END
+  }
+  out_arg {
+    name: "cs"
+    description: <<END
+The cell state before the tanh over the whole time sequence.
+END
+  }
+  out_arg {
+    name: "f"
+    description: <<END
+The forget gate over the whole time sequence.
+END
+  }
+  out_arg {
+    name: "o"
+    description: <<END
+The output gate over the whole time sequence.
+END
+  }
+  out_arg {
+    name: "ci"
+    description: <<END
+The cell input over the whole time sequence.
+END
+  }
+  out_arg {
+    name: "co"
+    description: <<END
+The cell after the tanh over the whole time sequence.
+END
+  }
+  out_arg {
+    name: "h"
+    description: <<END
+The output h vector over the whole time sequence.
+END
+  }
+  attr {
+    name: "cell_clip"
+    description: <<END
+Value to clip the 'cs' value to.
+END
+  }
+  attr {
+    name: "use_peephole"
+    description: <<END
+Whether to use peephole weights.
+END
+  }
+  summary: "Computes the LSTM cell forward propagation for all the time steps."
+  description: <<END
+This is equivalent to applying LSTMBlockCell in a loop, like so:
+
+```python
+for x1 in unpack(x):
+  i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlockCell(
+    x1, cs_prev, h_prev, w, wci, wcf, wco, b)
+  cs_prev = cs1
+  h_prev = h1
+  i.append(i1)
+  cs.append(cs1)
+  f.append(f1)
+  o.append(o1)
+  ci.append(ci1)
+  co.append(co1)
+  h.append(h1)
+return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(co), pack(h)
+```
+
+Note that unlike LSTMBlockCell (and BlockLSTM), which use the ICFO gate
+layout, this op uses IFCO. So in order for the above snippet to be
+equivalent to this op, all gate-related outputs should be reordered.
+END
+}
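Reviewer note: since the op is HIDDEN, it is only reachable through the raw-op surface (`tf.raw_ops.BlockLSTMV2`, per the golden API files at the end of this change). A minimal sketch of calling the new forward op, assuming this patch is applied; shapes follow the pbtxt above, and all values are illustrative:

```python
# Minimal sketch, assuming this patch is applied. The op is HIDDEN, so it
# is reached via tf.raw_ops; shapes follow the pbtxt, values are illustrative.
import numpy as np
import tensorflow as tf

timelen, batch_size, num_inputs, cell_size = 3, 2, 4, 8
x = tf.constant(np.random.rand(timelen, batch_size, num_inputs), tf.float32)
cs_prev = tf.zeros([batch_size, cell_size])  # initial cell state
h_prev = tf.zeros([batch_size, cell_size])   # initial output
w = tf.constant(  # gate weights, IFCO column order
    np.random.rand(num_inputs + cell_size, 4 * cell_size), tf.float32)
b = tf.zeros([4 * cell_size])                # gate biases, IFCO order
w_peephole = tf.zeros([cell_size])           # ignored: use_peephole=False

i, cs, f, o, ci, co, h = tf.raw_ops.BlockLSTMV2(
    seq_len_max=tf.constant(timelen, tf.int64),
    x=x, cs_prev=cs_prev, h_prev=h_prev, w=w,
    wci=w_peephole, wcf=w_peephole, wco=w_peephole, b=b,
    cell_clip=0.0, use_peephole=False)
print(h.shape)  # (timelen, batch_size, cell_size)
```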
diff --git a/tensorflow/core/kernels/rnn/lstm_ops.cc b/tensorflow/core/kernels/rnn/lstm_ops.cc
index 57d3e9b1323..dfd77f9ca5f 100644
--- a/tensorflow/core/kernels/rnn/lstm_ops.cc
+++ b/tensorflow/core/kernels/rnn/lstm_ops.cc
@@ -235,7 +235,9 @@ void LSTMBlockCellBpropWithEigen(
   template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */,     \
                                      GATE_LAYOUT>;
 
-#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO);
+#define DECLARE_CPU_SPECS(T)   \
+  DECLARE_CPU_FBPROP(T, ICFO); \
+  DECLARE_CPU_FBPROP(T, IFCO);
 
 DECLARE_CPU_SPECS(Eigen::half);
 DECLARE_CPU_SPECS(float);
@@ -827,7 +829,12 @@ template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
 class BlockLSTMOp : public OpKernel {
  public:
   explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
+    if (ctx->HasAttr("forget_bias")) {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
+    } else {
+      // V2 version does not have "forget_bias" attribute.
+      forget_bias_ = 0.0;
+    }
     OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
   }
@@ -1010,10 +1017,13 @@ class BlockLSTMOp : public OpKernel {
   bool use_peephole_;
 };
 
-#define REGISTER_KERNEL(T)                                         \
-  REGISTER_KERNEL_BUILDER(                                         \
-      Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      BlockLSTMOp<CPUDevice, T, false, ICFO>);
+#define REGISTER_KERNEL(T)                                           \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
+      BlockLSTMOp<CPUDevice, T, false, ICFO>);                       \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("BlockLSTMV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      BlockLSTMOp<CPUDevice, T, false, IFCO>);
 
 REGISTER_KERNEL(Eigen::half);
 REGISTER_KERNEL(float);
@@ -1039,12 +1049,17 @@ DECLARE_GPU_SPECS(float);
 #undef DECLARE_GPU_SPECS
 }  // end namespace functor
 
-#define REGISTER_GPU_KERNEL(T)                           \
-  REGISTER_KERNEL_BUILDER(Name("BlockLSTM")              \
-                              .Device(DEVICE_GPU)        \
-                              .HostMemory("seq_len_max") \
-                              .TypeConstraint<T>("T"),   \
-                          BlockLSTMOp<GPUDevice, T, true, ICFO>);
+#define REGISTER_GPU_KERNEL(T)                                    \
+  REGISTER_KERNEL_BUILDER(Name("BlockLSTM")                       \
+                              .Device(DEVICE_GPU)                 \
+                              .HostMemory("seq_len_max")          \
+                              .TypeConstraint<T>("T"),            \
+                          BlockLSTMOp<GPUDevice, T, true, ICFO>); \
+  REGISTER_KERNEL_BUILDER(Name("BlockLSTMV2")                     \
+                              .Device(DEVICE_GPU)                 \
+                              .HostMemory("seq_len_max")          \
+                              .TypeConstraint<T>("T"),            \
+                          BlockLSTMOp<GPUDevice, T, true, IFCO>);
 
 REGISTER_GPU_KERNEL(Eigen::half);
 REGISTER_GPU_KERNEL(float);
@@ -1284,10 +1299,13 @@ class BlockLSTMGradOp : public OpKernel {
   bool use_peephole_;
 };
 
-#define REGISTER_KERNEL(T)                                             \
-  REGISTER_KERNEL_BUILDER(                                             \
-      Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
-      BlockLSTMGradOp<CPUDevice, T, false, ICFO>);
+#define REGISTER_KERNEL(T)                                               \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
+      BlockLSTMGradOp<CPUDevice, T, false, ICFO>);                       \
+  REGISTER_KERNEL_BUILDER(                                               \
+      Name("BlockLSTMGradV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      BlockLSTMGradOp<CPUDevice, T, false, IFCO>);
 
 REGISTER_KERNEL(Eigen::half);
 REGISTER_KERNEL(float);
@@ -1345,7 +1363,8 @@ namespace functor {
   extern template struct TensorCopy<GPUDevice, T>;                             \
   extern template struct TensorAdd<GPUDevice, T>;                              \
                                                                                \
-  DECLARE_GPU_BPROP(T, ICFO);
+  DECLARE_GPU_BPROP(T, ICFO);                                                  \
+  DECLARE_GPU_BPROP(T, IFCO);
 
 DECLARE_GPU_SPECS(Eigen::half);
 DECLARE_GPU_SPECS(float);
@@ -1353,12 +1372,17 @@ DECLARE_GPU_SPECS(float);
 #undef DECLARE_GPU_BPROP
 }  // end namespace functor
 
-#define REGISTER_GPU_KERNEL(T)                           \
-  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")          \
-                              .Device(DEVICE_GPU)        \
-                              .HostMemory("seq_len_max") \
-                              .TypeConstraint<T>("T"),   \
-                          BlockLSTMGradOp<GPUDevice, T, true, ICFO>);
+#define REGISTER_GPU_KERNEL(T)                                        \
+  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad")                       \
+                              .Device(DEVICE_GPU)                     \
+                              .HostMemory("seq_len_max")              \
+                              .TypeConstraint<T>("T"),                \
+                          BlockLSTMGradOp<GPUDevice, T, true, ICFO>); \
+  REGISTER_KERNEL_BUILDER(Name("BlockLSTMGradV2")                     \
+                              .Device(DEVICE_GPU)                     \
+                              .HostMemory("seq_len_max")              \
+                              .TypeConstraint<T>("T"),                \
+                          BlockLSTMGradOp<GPUDevice, T, true, IFCO>);
 
 REGISTER_GPU_KERNEL(Eigen::half);
 REGISTER_GPU_KERNEL(float);
diff --git a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc
index f3f37858986..3c1ea27b1ea 100644
--- a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc
@@ -468,7 +468,8 @@ void LSTMBlockCellBpropWithCUDA(
   template struct TensorCopyToUnaligned<GPUDevice, T>; \
   template struct TensorAdd<GPUDevice, T>;             \
                                                        \
-  DECLARE_GPU_FBPROP(T, ICFO);
+  DECLARE_GPU_FBPROP(T, ICFO);                         \
+  DECLARE_GPU_FBPROP(T, IFCO);
 
 DECLARE_GPU_SPECS(Eigen::half);
 DECLARE_GPU_SPECS(float);
diff --git a/tensorflow/core/ops/rnn_ops.cc b/tensorflow/core/ops/rnn_ops.cc
index b926feb9d2e..af5dc3d26d1 100644
--- a/tensorflow/core/ops/rnn_ops.cc
+++ b/tensorflow/core/ops/rnn_ops.cc
@@ -199,6 +199,45 @@ REGISTER_OP("BlockLSTM")
       return Status::OK();
     });
 
+REGISTER_OP("BlockLSTMV2")
+    .Input("seq_len_max: int64")
+    .Input("x: T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Output("i: T")
+    .Output("cs: T")
+    .Output("f: T")
+    .Output("o: T")
+    .Output("ci: T")
+    .Output("co: T")
+    .Output("h: T")
+    .Attr("cell_clip: float = 0.0")
+    .Attr("use_peephole: bool = false")
+    .Attr("T: {half, float}")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle x, b;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
+
+      DimensionHandle timelen = c->Dim(x, 0);
+      DimensionHandle batch_size = c->Dim(x, 1);
+      DimensionHandle cell_size;
+      TF_RETURN_IF_ERROR(
+          c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
+
+      DCHECK_EQ(7, c->num_outputs());
+      ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
+      for (int i = 0; i < 7; ++i) {
+        c->set_output(i, output);
+      }
+      return Status::OK();
+    });
+
 REGISTER_OP("BlockLSTMGrad")
     .Input("seq_len_max: int64")
     .Input("x: T")
@@ -251,4 +290,56 @@ REGISTER_OP("BlockLSTMGrad")
       return Status::OK();
     });
 
+REGISTER_OP("BlockLSTMGradV2")
+    .Input("seq_len_max: int64")
+    .Input("x: T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Input("i: T")
+    .Input("cs: T")
+    .Input("f: T")
+    .Input("o: T")
+    .Input("ci: T")
+    .Input("co: T")
+    .Input("h: T")
+    .Input("cs_grad: T")
+    .Input("h_grad: T")
+    .Output("x_grad: T")
+    .Output("cs_prev_grad: T")
+    .Output("h_prev_grad: T")
+    .Output("w_grad: T")
+    .Output("wci_grad: T")
+    .Output("wcf_grad: T")
+    .Output("wco_grad: T")
+    .Output("b_grad: T")
+    .Attr("use_peephole: bool")
+    .Attr("T: {half, float}")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle x, cs_prev, h_prev, w, wci, wcf, wco, b;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wcf));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wco));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
+
+      c->set_output(0, x);
+      c->set_output(1, cs_prev);
+      c->set_output(2, h_prev);
+      c->set_output(3, w);
+      c->set_output(4, wci);
+      c->set_output(5, wcf);
+      c->set_output(6, wco);
+      c->set_output(7, b);
+
+      return Status::OK();
+    });
+
 }  // end namespace tensorflow
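Reviewer note: both new shape functions infer `cell_size` by evenly dividing `b`'s length by 4, since `b` stacks one bias vector per gate, and all seven outputs share the shape `(timelen, batch_size, cell_size)`. A plain-Python restatement of that contract, purely illustrative and not part of the patch:

```python
# Plain-Python restatement of BlockLSTMV2's shape function; illustrative only.
def block_lstm_v2_output_shape(x_shape, b_shape):
  timelen, batch_size, _ = x_shape  # x must be rank 3
  (four_cell,) = b_shape            # b must be rank 1
  if four_cell % 4 != 0:
    raise ValueError("b must stack exactly four gate bias vectors")
  cell_size = four_cell // 4
  # All seven outputs (i, cs, f, o, ci, co, h) share one shape.
  return (timelen, batch_size, cell_size)

print(block_lstm_v2_output_shape((5, 2, 3), (32,)))  # (5, 2, 8)
```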
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 996c31b8635..77a21caaf63 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3691,6 +3691,24 @@ py_library(
     ],
 )
 
+py_test(
+    name = "rnn_grad_test",
+    srcs = ["ops/rnn_grad_test.py"],
+    python_version = "PY3",
+    deps = [
+        ":array_ops",
+        ":client_testlib",
+        ":dtypes",
+        ":framework_ops",
+        ":framework_test_lib",
+        ":gradients",
+        ":math_ops",
+        ":rnn_grad",
+        ":rnn_ops_gen",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_library(
     name = "standard_ops",
     srcs = ["ops/standard_ops.py"],
diff --git a/tensorflow/python/ops/rnn_grad.py b/tensorflow/python/ops/rnn_grad.py
index f2707e178b0..e316b7fb8a1 100644
--- a/tensorflow/python/ops/rnn_grad.py
+++ b/tensorflow/python/ops/rnn_grad.py
@@ -21,7 +21,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_rnn_ops
 
 
-@ops.RegisterGradient("BlockLSTM")
 def _block_lstm_grad(op, *grads):
   """Gradient for the BlockLSTM op."""
   seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs
@@ -50,3 +49,7 @@ def _block_lstm_grad(op, *grads):
        use_peephole=op.get_attr("use_peephole"))
   return (None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
           wco_grad, b_grad)
+
+
+ops.RegisterGradient("BlockLSTM")(_block_lstm_grad)
+ops.RegisterGradient("BlockLSTMV2")(_block_lstm_grad)
diff --git a/tensorflow/python/ops/rnn_grad_test.py b/tensorflow/python/ops/rnn_grad_test.py
new file mode 100644
index 00000000000..2b320234538
--- /dev/null
+++ b/tensorflow/python/ops/rnn_grad_test.py
@@ -0,0 +1,100 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for gradients of (block) LSTM/GRU operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_rnn_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import rnn_grad  # pylint: disable=unused-import
+from tensorflow.python.platform import test
+
+
+class RNNGradTest(test.TestCase):
+
+  @test_util.deprecated_graph_mode_only
+  def testBlockLSTMV1V2Consistency(self):
+    num_steps = 1
+    batch_size = 1
+    input_size = 1
+    hidden_size = 8
+    w = deterministic_random_uniform(
+        [input_size + hidden_size, 4 * hidden_size])
+    b = deterministic_random_uniform([4 * hidden_size])
+    x = deterministic_random_uniform([num_steps, batch_size, input_size])
+    cs_prev = h_prev = deterministic_random_uniform([batch_size, hidden_size])
+
+    all_cs, all_h = self._lstm_block(
+        functools.partial(
+            gen_rnn_ops.BlockLSTM,
+            forget_bias=0.0,  # Disable to match V2 default.
+            cell_clip=0.0),  # Disable to match V2 default.
+        w, b, x, cs_prev, h_prev)
+    w_grad, b_grad = gradients.gradients(all_cs + all_h, [w, b])
+
+    w_ifco, b_ifco = icfo_to_ifco(w, b)
+    all_cs_ifco, all_h_ifco = self._lstm_block(
+        gen_rnn_ops.BlockLSTMV2, w_ifco, b_ifco, x, cs_prev, h_prev)
+    w_ifco_grad, b_ifco_grad = gradients.gradients(
+        all_cs_ifco + all_h_ifco, [w_ifco, b_ifco])
+
+    self.assertAllEqual(all_cs, all_cs_ifco)
+    self.assertAllEqual(all_h, all_h_ifco)
+    self.assertAllEqual(w_grad, w_ifco_grad)
+    self.assertAllEqual(b_grad, b_ifco_grad)
+
+  def _lstm_block(self, op, w, b, x, cs_prev, h_prev):
+    w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
+    _, all_cs, _, _, _, _, all_h = op(
+        seq_len_max=math_ops.cast(array_ops.shape(x)[0], dtypes.int64),
+        x=x,
+        cs_prev=cs_prev,
+        h_prev=h_prev,
+        w=w,
+        wci=w_peephole,
+        wcf=w_peephole,
+        wco=w_peephole,
+        b=b,
+        use_peephole=False)
+    return all_cs, all_h
+
+
+def deterministic_random_uniform(shape):
+  """Uniform values sampled once with NumPy, so both ops see identical inputs."""
+  return ops.convert_to_tensor(np.random.random(shape), dtype=dtypes.float32)
+
+
+def icfo_to_ifco(w, b):
+  """Convert gates' weights and biases from ICFO to IFCO layout."""
+  w_i, w_c, w_f, w_o = array_ops.split(w, num_or_size_splits=4, axis=1)
+  b_i, b_c, b_f, b_o = array_ops.split(b, num_or_size_splits=4)
+  w_ifco = array_ops.concat([w_i, w_f, w_c, w_o], axis=1)
+  b_ifco = array_ops.concat([b_i, b_f, b_c, b_o], axis=0)
+  return w_ifco, b_ifco
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index 0f9e320753f..1136881b8a5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -472,6 +472,14 @@ tf_module {
     name: "BlockLSTMGrad"
     argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "BlockLSTMGradV2"
+    argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "BlockLSTMV2"
+    argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'cell_clip\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'None\'], "
+  }
   member_method {
     name: "BoostedTreesAggregateStats"
     argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index 0f9e320753f..1136881b8a5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -472,6 +472,14 @@ tf_module {
     name: "BlockLSTMGrad"
     argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "BlockLSTMGradV2"
+    argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "BlockLSTMV2"
+    argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'cell_clip\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'None\'], "
+  }
   member_method {
     name: "BoostedTreesAggregateStats"
     argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "