diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 6976a372983..729d84a07b0 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -104,6 +104,7 @@ filegroup(
         "//tensorflow/contrib/testing:all_files",
         "//tensorflow/contrib/util:all_files",
         "//tensorflow/core:all_files",
+        "//tensorflow/core/debug:all_files",
         "//tensorflow/core/distributed_runtime:all_files",
         "//tensorflow/core/distributed_runtime/rpc:all_files",
         "//tensorflow/core/kernels:all_files",
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index 510426cc034..9db453f0dd2 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -38,7 +38,6 @@ namespace {
 const char kFfmpegExecutable[] = "ffmpeg";
 const int32 kDefaultProbeSize = 5000000;  // 5MB
 
-
 std::vector<string> FfmpegCommandLine(const string& input_filename,
                                       const string& output_filename,
                                       const string& input_format_id,
@@ -63,6 +62,39 @@ std::vector<string> FfmpegCommandLine(const string& input_filename,
   };
 }
 
+// Is a named binary installed and executable by the current process?
+// Note that this is harder than it seems like it should be: we have to
+// replicate the shell's $PATH lookup and permission checks by hand.
+bool IsBinaryInstalled(const string& binary_name) {
+  const char* path_env = ::getenv("PATH");  // May be unset; getenv -> nullptr.
+  if (path_env == nullptr) {
+    return false;
+  }
+  for (const string& dir : str_util::Split(path_env, ':')) {
+    const string binary_path = io::JoinPath(dir, binary_name);
+    char absolute_path[PATH_MAX + 1];
+    // realpath() fails for dangling symlinks and non-existent paths, leaving
+    // absolute_path unspecified, so check its result before using it.
+    if (::realpath(binary_path.c_str(), absolute_path) == nullptr) {
+      continue;
+    }
+    struct stat statinfo;
+    int result = ::stat(absolute_path, &statinfo);
+    if (result < 0) {
+      continue;
+    }
+    if (!S_ISREG(statinfo.st_mode)) {
+      continue;
+    }
+
+    // Is the current user able to execute the file?
+    if (statinfo.st_uid == ::geteuid() && statinfo.st_mode & S_IXUSR) {
+      return true;
+    }
+    // Is the current group able to execute the file?
+    if (statinfo.st_gid == ::getegid() && statinfo.st_mode & S_IXGRP) {
+      return true;
+    }
+    // Is anyone able to execute the file?
+    if (statinfo.st_mode & S_IXOTH) {
+      return true;
+    }
+  }
+  return false;
+}
+
 [[noreturn]] int ExecuteFfmpeg(const std::vector<string>& args) {
   std::vector<char*> args_chars;
   std::transform(args.begin(), args.end(), std::back_inserter(args_chars),
@@ -191,6 +223,14 @@ Status ReadAudioFile(const string& filename,
       FfmpegCommandLine(filename, output_filename, audio_format_id,
                         samples_per_second, channel_count);
 
+  // Unfortunately, it's impossible to distinguish an exec failure caused by a
+  // missing binary from an error returned by the binary's own execution.
+  // Therefore, check whether the binary *should* be available. If not, return
+  // an error that the TensorFlow op converts into a helpful error message.
+  if (!IsBinaryInstalled(kFfmpegExecutable)) {
+    return Status(error::Code::NOT_FOUND, "FFmpeg could not be found.");
+  }
+
   // Execute ffmpeg and report errors.
   pid_t child_pid = ::fork();
   if (child_pid < 0) {
@@ -202,7 +242,7 @@ Status ReadAudioFile(const string& filename,
     int status_code;
     ::waitpid(child_pid, &status_code, 0);
     if (status_code) {
-      return Status(error::Code::NOT_FOUND,
+      return Status(error::Code::UNKNOWN,
                     StrCat("FFmpeg execution failed: ", status_code));
     }
     *output_samples = ReadPcmFile(output_filename);
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index 9c9cfe4c99b..4d849894051 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -818,7 +818,7 @@ class DropoutTest(tf.test.TestCase):
     with self.test_session():
       images = np.random.uniform(size=(5, height, width, 3))
       output = tf.contrib.layers.dropout(images)
-      self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
+      self.assertEquals(output.op.name, 'Dropout/dropout/mul')
       output.get_shape().assert_is_compatible_with(
           tf.convert_to_tensor(images).get_shape())
 
@@ -828,7 +828,7 @@ class DropoutTest(tf.test.TestCase):
       is_training = tf.constant(True)
       images = tf.random_uniform((5, height, width, 3), seed=1)
       output = tf.contrib.layers.dropout(images, is_training=is_training)
-      self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
+      self.assertEquals(output.op.name, 'Dropout/dropout/mul')
       output.get_shape().assert_is_compatible_with(images.get_shape())
 
   def testCreateDropoutWithConstantFalse(self):
diff --git a/tensorflow/contrib/layers/python/layers/target_column.py b/tensorflow/contrib/layers/python/layers/target_column.py
index 9f321895025..08280446723 100644
--- a/tensorflow/contrib/layers/python/layers/target_column.py
+++ b/tensorflow/contrib/layers/python/layers/target_column.py
@@ -22,6 +22,7 @@ import inspect
 
 import six
 
+from tensorflow.contrib import losses
 from tensorflow.contrib import metrics as metrics_lib
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
-from tensorflow.python.ops import nn_ops
 
 
 def regression_target(label_name=None,
@@ -297,8 +297,17 @@ class _BinarySvmTargetColumn(_MultiClassTargetColumn):
   """_TargetColumn for binary classification using SVMs."""
 
   def __init__(self, label_name, weight_column_name):
+    def loss_fn(logits, target):
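+      # hinge_loss expects the target to have the same shape as the logits, so
+      # (after asserting rank <= 2) reshape it to [batch_size, 1].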
+      check_shape_op = logging_ops.Assert(
+          math_ops.less_equal(array_ops.rank(target), 2),
+          ["target's shape should be either [batch_size, 1] or [batch_size]"])
+      with ops.control_dependencies([check_shape_op]):
+        target = array_ops.reshape(
+            target, shape=[array_ops.shape(target)[0], 1])
+      return losses.hinge_loss(logits, target)
+
     super(_BinarySvmTargetColumn, self).__init__(
-        loss_fn=_binary_hinge_loss,
+        loss_fn=loss_fn,
         n_classes=2,
         label_name=label_name,
         weight_column_name=weight_column_name)
@@ -331,22 +340,6 @@ def _log_loss_with_two_classes(logits, target):
   return loss_vec
 
 
-# TODO(sibyl-vie3Poto): Move this to contrib/losses/python/losses/loss_ops.py.
-def _binary_hinge_loss(logits, target):
-  """Method that returns the loss vector for binary hinge loss."""
-  check_shape_op = logging_ops.Assert(
-      math_ops.less_equal(
-          array_ops.rank(target), 2),
-      ["target's shape should be either [batch_size, 1] or [batch_size]"])
-  with ops.control_dependencies([check_shape_op]):
-    target = array_ops.reshape(target, shape=[array_ops.shape(target)[0], 1])
-  # First need to convert binary labels to -1/1 labels (as floats).
-  all_ones = array_ops.ones_like(logits)
-  labels = math_ops.sub(2 * math_ops.to_float(target), all_ones)
-  loss_vec = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
-  return loss_vec
-
-
 def _softmax_cross_entropy_loss(logits, target):
   # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
   # Check that we got int32/int64 for classification.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py
index f646cdf477c..a39254e7b49 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py
@@ -61,13 +61,13 @@ class SVM(linear.LinearClassifier):
         whose `value` is a `SparseTensor`.
       - if `column` is a `RealValuedColumn, a feature with `key=column.name`
         whose `value` is a `Tensor`.
-      - if `feauture_columns` is None, then `input` must contains only real
+      - if `feature_columns` is None, then `input` must contain only real
         valued `Tensor`.
 
 
   Parameters:
     example_id_column: A string defining the feature column name representing
-      example ids. Used do initialize the underlying optimizer.
+      example ids. Used to initialize the underlying optimizer.
     feature_columns: An iterable containing all the feature columns used by the
       model. All items in the set should be instances of classes derived from
       `FeatureColumn`.
@@ -75,10 +75,12 @@ class SVM(linear.LinearClassifier):
       weights. It is used to down weight or boost examples during training. It
       will be multiplied by the loss of the example.
     model_dir: Directory to save model parameters, graph and etc. This can also
-        be used to load checkpoints from the directory into a estimator to continue
-        training a previously saved model.
-    l1_regularization: L1-regularization parameter
-    l2_regularization: L2-regularization parameter
+        be used to load checkpoints from the directory into an estimator to
+        continue training a previously saved model.
+    l1_regularization: L1-regularization parameter. Refers to global L1
+      regularization (across all examples).
+    l2_regularization: L2-regularization parameter. Refers to global L2
+      regularization (across all examples).
     kernels: A list of kernels for the SVM. Currently, no kernels are supported.
       Reserved for future use for non-linear SVMs
     config: RunConfig object to configure the runtime settings.
@@ -100,12 +102,13 @@ class SVM(linear.LinearClassifier):
         symmetric_l1_regularization=l1_regularization,
         symmetric_l2_regularization=l2_regularization)
 
-    super(SVM, self).__init__(model_dir=model_dir,
-                              n_classes=2,
-                              weight_column_name=weight_column_name,
-                              feature_columns=feature_columns,
-                              optimizer=optimizer,
-                              config=config)
+    super(SVM, self).__init__(
+        model_dir=model_dir,
+        n_classes=2,
+        weight_column_name=weight_column_name,
+        feature_columns=feature_columns,
+        optimizer=optimizer,
+        config=config)
     self._target_column = layers.binary_svm_target(
         weight_column_name=weight_column_name)
 
diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py
index 081d47e4b55..d8181632bf8 100644
--- a/tensorflow/contrib/losses/python/losses/__init__.py
+++ b/tensorflow/contrib/losses/python/losses/__init__.py
@@ -106,6 +106,7 @@ weighted average over the individual prediction errors:
 
 @@absolute_difference
 @@add_loss
+@@hinge_loss
 @@cosine_distance
 @@get_losses
 @@get_regularization_losses
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 99aab8b44c2..597e6aeda93 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
 
 
 __all__ = ["absolute_difference",
@@ -33,6 +34,7 @@ __all__ = ["absolute_difference",
            "get_losses",
            "get_regularization_losses",
            "get_total_loss",
+           "hinge_loss",
            "log_loss",
            "sigmoid_cross_entropy",
            "softmax_cross_entropy",
@@ -410,6 +412,31 @@ def log_loss(predictions, targets, weight=1.0, epsilon=1e-7, scope=None):
     return _compute_weighted_loss(losses, weight)
 
 
+def hinge_loss(logits, target, scope=None):
+  """Method that returns the loss tensor for hinge loss.
+
+  Args:
+    logits: The logits, a float tensor.
+    target: The ground truth output tensor. Its shape should match the shape of
+      logits. The values of the tensor are expected to be 0.0 or 1.0.
+    scope: The scope for the operations performed in computing the loss.
+
+  Returns:
+    A `Tensor` of the same shape as `logits` and `target` representing the
+      loss values across the batch.
+
+  Raises:
+    ValueError: If the shapes of `logits` and `target` don't match.
+  """
+  with ops.op_scope([logits, target], scope, "hinge_loss") as scope:
+    logits.get_shape().assert_is_compatible_with(target.get_shape())
+    # We first need to convert binary labels to -1/1 labels (as floats).
+    target = math_ops.to_float(target)
+    all_ones = array_ops.ones_like(target)
+    labels = math_ops.sub(2 * target, all_ones)
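+    # Per-element hinge loss: max(0, 1 - labels * logits).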
+    return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
+
+
 def sum_of_squares(predictions, targets, weight=1.0, scope=None):
   """Adds a Sum-of-Squares loss to the training procedure.
 
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 49460ec2279..824c24451be 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -499,6 +499,42 @@ class LogLossTest(tf.test.TestCase):
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
 
+class HingeLossTest(tf.test.TestCase):
+
+  def testIncompatibleShapes(self):
+    with self.test_session():
+      logits = tf.constant([[-1.0], [2.1]])
+      target = tf.constant([0.0, 1.0])
+      with self.assertRaises(ValueError):
+        _ = tf.contrib.losses.hinge_loss(logits, target).eval()
+
+  def testAllOutsideMargin(self):
+    with self.test_session():
+      logits = tf.constant([1.2, -1.4, -1.0, 2.1])
+      target = tf.constant([1.0, 0.0, 0.0, 1.0])
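+      # Every example is classified correctly with |logit| >= 1, i.e. on or
+      # outside the margin, so each per-example loss is zero.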
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
+
+  def testSomeInsideMargin(self):
+    with self.test_session():
+      logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]])
+      target = tf.constant([[0.0], [0.0], [1.0], [1.0]])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      # Examples 1 and 4 are on the correct side of the hyperplane but within
+      # the margin so they incur some (small) loss.
+      self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
+
+  def testSomeMisclassified(self):
+    with self.test_session():
+      logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
+      target = tf.constant([[[1.0], [0.0], [0.0], [1.0]]])
+      loss = tf.contrib.losses.hinge_loss(logits, target)
+      # Examples 2 and 4 are on the wrong side of the hyperplane so they incur
+      # some (fairly large) loss.
+      self.assertAllClose(
+          loss.eval(), [[[0.0], [1.4], [0.0], [2.1]]], atol=1e-3)
+
+
 class SumOfSquaresLossTest(tf.test.TestCase):
 
   def setUp(self):
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 9e819ba62fd..dffd139ec0d 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -9,10 +9,14 @@ exports_files(["LICENSE"])
 package(default_visibility = ["//tensorflow:__subpackages__"])
 
 load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
 
 py_library(
     name = "rnn_py",
     srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+    data = [
+        ":python/ops/_lstm_ops.so",
+    ],
     srcs_version = "PY2AND3",
 )
 
@@ -27,6 +31,33 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "lstm_ops_test",
+    size = "small",
+    srcs = ["python/kernel_tests/lstm_ops_test.py"],
+    additional_deps = [
+        ":rnn_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
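+# Builds the custom op shared library listed in the rnn_py data dependency
+# above (presumably loaded by python/ops/lstm_ops.py via tf.load_op_library).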
+tf_custom_op_library(
+    name = "python/ops/_lstm_ops.so",
+    srcs = [
+        "kernels/lstm_ops.cc",
+        "kernels/lstm_ops.h",
+        "ops/lstm_ops.cc",
+    ],
+    gpu_srcs = [
+        "kernels/lstm_ops_gpu.cu.cc",
+        "kernels/lstm_ops.h",
+    ],
+    deps = [
+        "//tensorflow/core/kernels:eigen_helpers",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index 2193f644849..8ead5f00045 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -12,14 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Ops for representing statistical distributions.
+"""Additional RNN operations and cells.
 
-## This package provides classes for statistical distributions.
+## This package provides additional contributed RNNCells.
 
+### Fused RNNCells
+@@LSTMFusedCell
+
+### LSTM-like cells
+@@CoupledInputForgetGateLSTMCell
+@@TimeFreqLSTMCell
+@@GridLSTMCell
+
+### RNNCell wrappers
+@@AttentionCellWrapper
 """
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,wildcard-import, line-too-long
+from tensorflow.contrib.rnn.python.ops.lstm_ops import *
 from tensorflow.contrib.rnn.python.ops.rnn_cell import *
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
new file mode 100644
index 00000000000..74bede713c1
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
@@ -0,0 +1,1053 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#define EIGEN_USE_GPU
+#endif  // GOOGLE_CUDA
+
+#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/platform/stream_executor.h"
+#endif  // GOOGLE_CUDA
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+#if GOOGLE_CUDA
+
+namespace {
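+// Wraps a raw CUDA device pointer in a typed StreamExecutor DeviceMemory
+// handle so it can be passed to ThenBlasGemm below.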
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+  perftools::gputools::DeviceMemory<T> typed(wrapped);
+  return typed;
+}
+}  // namespace
+
+#endif  // GOOGLE_CUDA
+
+namespace functor {
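+// Launches a GEMM on the given StreamExecutor stream via cuBLAS; in builds
+// without CUDA support this reports an error on the context instead.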
+template <typename T>
+void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx,
+                                     perftools::gputools::Stream* stream,
+                                     bool transa, bool transb, uint64 m,
+                                     uint64 n, uint64 k, T alpha, const T* a,
+                                     int lda, const T* b, int ldb, T beta, T* c,
+                                     int ldc) {
+#if GOOGLE_CUDA
+  perftools::gputools::blas::Transpose trans[] = {
+      perftools::gputools::blas::Transpose::kNoTranspose,
+      perftools::gputools::blas::Transpose::kTranspose};
+
+  auto a_ptr = AsDeviceMemory(a);
+  auto b_ptr = AsDeviceMemory(b);
+  auto c_ptr = AsDeviceMemory(c);
+
+  bool blas_launch_status =
+      stream
+          ->ThenBlasGemm(trans[transa], trans[transb], m, n, k, alpha, a_ptr,
+                         lda, b_ptr, ldb, beta, &c_ptr, ldc)
+          .ok();
+  OP_REQUIRES(ctx, blas_launch_status, errors::Aborted("CuBlasGemm failed!"));
+#else
+  ctx->SetStatus(errors::InvalidArgument("CuBlasGemm needs CUDA."));
+#endif
+}
+
+template struct TensorCuBlasGemm<float>;
+// template struct TensorCuBlasGemm<double>;
+}  // end namespace functor
+
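+// Computes one step of an LSTM cell with the four gate matmuls fused into a
+// single [x, h_prev] * w product. With the peephole terms (in brackets)
+// enabled by use_peephole, the step is roughly:
+//   [i, ci, f, o] = [x, h_prev] * w + b
+//   i  = sigmoid(i [+ wci .* cs_prev])
+//   f  = sigmoid(f + forget_bias [+ wcf .* cs_prev])
+//   ci = tanh(ci)
+//   cs = f .* cs_prev + i .* ci          (optionally clipped to cell_clip)
+//   o  = sigmoid(o [+ wco .* cs])
+//   co = tanh(cs)
+//   h  = o .* co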
+template <typename Device, typename T, bool USE_CUBLAS>
+class LSTMFusedCellOp : public OpKernel {
+ public:
+  explicit LSTMFusedCellOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* x_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
+
+    const Tensor* cs_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
+
+    const Tensor* h_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
+
+    const Tensor* w_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
+
+    const Tensor* wci_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
+
+    const Tensor* wcf_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
+
+    const Tensor* wco_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
+
+    const Tensor* b_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
+
+    const int64 batch_size = x_tensor->dim_size(0);
+    const int64 input_size = x_tensor->dim_size(1);
+    const int64 cell_size = cs_prev_tensor->dim_size(1);
+
+    // Sanity checks for our input shapes.
+    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
+                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
+                                        cs_prev_tensor->dim_size(0), " vs. ",
+                                        batch_size));
+    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
+                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
+                                        cs_prev_tensor->dim_size(1), " vs. ",
+                                        cell_size));
+
+    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
+                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
+                                        h_prev_tensor->dim_size(0), " vs. ",
+                                        batch_size));
+    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
+                errors::InvalidArgument("h_prev.dims(1) != cell_size: ",
+                                        h_prev_tensor->dim_size(1), " vs. ",
+                                        cell_size));
+
+    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
+                errors::InvalidArgument(
+                    "w.dim_size(0) != input_size + cell_size: ",
+                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
+    OP_REQUIRES(
+        ctx, w_tensor->dim_size(1) == cell_size * 4,
+        errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ",
+                                w_tensor->dim_size(1), " vs. ", cell_size * 4));
+
+    OP_REQUIRES(
+        ctx, b_tensor->dim_size(0) == cell_size * 4,
+        errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ",
+                                b_tensor->dim_size(0), " vs. ", cell_size * 4));
+
+    // Allocate our output tensors.
+    Tensor* i_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("i", TensorShape({batch_size, cell_size}),
+                                  &i_tensor));
+
+    Tensor* cs_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("cs", TensorShape({batch_size, cell_size}),
+                                  &cs_tensor));
+
+    Tensor* f_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("f", TensorShape({batch_size, cell_size}),
+                                  &f_tensor));
+
+    Tensor* o_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("o", TensorShape({batch_size, cell_size}),
+                                  &o_tensor));
+
+    Tensor* ci_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("ci", TensorShape({batch_size, cell_size}),
+                                  &ci_tensor));
+
+    Tensor* co_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("co", TensorShape({batch_size, cell_size}),
+                                  &co_tensor));
+
+    Tensor* h_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("h", TensorShape({batch_size, cell_size}),
+                                  &h_tensor));
+
+    // Allocate our temp tensors.
+    Tensor xh_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+                            DataTypeToEnum<T>::v(),
+                            TensorShape({batch_size, input_size + cell_size}),
+                            &xh_tensor));
+
+    Tensor icfo_tensor;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                      TensorShape({batch_size, cell_size * 4}),
+                                      &icfo_tensor));
+
+    const Device& device = ctx->eigen_device<Device>();
+    perftools::gputools::Stream* stream =
+        std::is_same<Device, GPUDevice>::value
+            ? ctx->op_device_context()->stream()
+            : nullptr;
+
+    functor::LSTMFusedCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
+                                                       cell_size)(
+        ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
+        x_tensor->matrix<T>(), cs_prev_tensor->matrix<T>(),
+        h_prev_tensor->matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(),
+        wcf_tensor->vec<T>(), wco_tensor->vec<T>(), b_tensor->vec<T>(),
+        xh_tensor.matrix<T>(), i_tensor->matrix<T>(), cs_tensor->matrix<T>(),
+        f_tensor->matrix<T>(), o_tensor->matrix<T>(), ci_tensor->matrix<T>(),
+        co_tensor->matrix<T>(), icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
+  }
+
+ private:
+  float forget_bias_;
+  float cell_clip_;
+  bool use_peephole_;
+};
+
+#define REGISTER_KERNEL(T)                                             \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("LSTMFusedCell").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      LSTMFusedCellOp<CPUDevice, T, false>);
+REGISTER_KERNEL(float);
+// REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                \
+  template <>                                                              \
+  void LSTMFusedCellFprop<GPUDevice, T, true>::operator()(                 \
+      OpKernelContext* ctx, perftools::gputools::Stream* stream,           \
+      const GPUDevice& d, const T forget_bias, const T cell_clip,          \
+      bool use_peephole, typename TTypes<T>::ConstMatrix x,                \
+      typename TTypes<T>::ConstMatrix cs_prev,                             \
+      typename TTypes<T>::ConstMatrix h_prev,                              \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci, \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,  \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,       \
+      typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,         \
+      typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,          \
+      typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,        \
+      typename TTypes<T>::Matrix icfo, typename TTypes<T>::Matrix h);      \
+                                                                           \
+  extern template struct LSTMFusedCellFprop<GPUDevice, T, true>;
+
+DECLARE_GPU_SPEC(float);
+// DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // end namespace functor
+
+#define REGISTER_GPU_KERNEL(T)                                         \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("LSTMFusedCell").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      LSTMFusedCellOp<GPUDevice, T, true>);
+
+REGISTER_GPU_KERNEL(float);
+// REGISTER_GPU_KERNEL(double);
+#undef REGISTER_GPU_KERNEL
+#endif  // GOOGLE_CUDA
+
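+// Gradient of LSTMFusedCellOp: given the forward activations and the
+// gradients flowing into cs and h, produces gradients for the previous cell
+// state, the fused gate pre-activations (dicfo), and the peephole weights.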
+template <typename Device, typename T, bool USE_CUBLAS>
+class LSTMFusedCellGradOp : public OpKernel {
+ public:
+  explicit LSTMFusedCellGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* x_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("x", &x_tensor));
+
+    const Tensor* cs_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
+
+    const Tensor* h_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
+
+    const Tensor* w_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
+
+    const Tensor* wci_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
+
+    const Tensor* wcf_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
+
+    const Tensor* wco_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
+
+    const Tensor* b_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
+
+    const Tensor* i_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("i", &i_tensor));
+
+    const Tensor* cs_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_tensor));
+
+    const Tensor* f_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("f", &f_tensor));
+
+    const Tensor* o_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("o", &o_tensor));
+
+    const Tensor* ci_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_tensor));
+
+    const Tensor* co_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("co", &co_tensor));
+
+    const Tensor* cs_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad_tensor));
+
+    const Tensor* h_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad_tensor));
+
+    const int64 batch_size = x_tensor->dim_size(0);
+    const int64 input_size = x_tensor->dim_size(1);
+    const int64 cell_size = cs_prev_tensor->dim_size(1);
+
+    // Sanity checks for our input shapes.
+    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
+                errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
+                                        cs_prev_tensor->dim_size(0), " vs. ",
+                                        batch_size));
+    OP_REQUIRES(ctx, cs_prev_tensor->dim_size(1) == cell_size,
+                errors::InvalidArgument("cs_prev.dims(1) != cell_size: ",
+                                        cs_prev_tensor->dim_size(1), " vs. ",
+                                        cell_size));
+
+    OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
+                errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
+                                        h_prev_tensor->dim_size(0), " vs. ",
+                                        batch_size));
+    OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
+                errors::InvalidArgument("h_prev.dims(1) != cell_size: ",
+                                        h_prev_tensor->dim_size(1), " vs. ",
+                                        cell_size));
+
+    OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
+                errors::InvalidArgument(
+                    "w.dim_size(0) != input_size + cell_size: ",
+                    w_tensor->dim_size(0), " vs. ", input_size + cell_size));
+    OP_REQUIRES(
+        ctx, w_tensor->dim_size(1) == cell_size * 4,
+        errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ",
+                                w_tensor->dim_size(1), " vs. ", cell_size * 4));
+
+    OP_REQUIRES(
+        ctx, b_tensor->dim_size(0) == cell_size * 4,
+        errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ",
+                                b_tensor->dim_size(0), " vs. ", cell_size * 4));
+
+    OP_REQUIRES(
+        ctx, i_tensor->dim_size(0) == batch_size,
+        errors::InvalidArgument("i.dim_size(0) != batch_size: ",
+                                i_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(
+        ctx, i_tensor->dim_size(1) == cell_size,
+        errors::InvalidArgument("i.dim_size(1) != cell_size: ",
+                                i_tensor->dim_size(1), " vs. ", cell_size));
+
+    OP_REQUIRES(
+        ctx, cs_tensor->dim_size(0) == batch_size,
+        errors::InvalidArgument("cs.dim_size(0) != batch_size: ",
+                                cs_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(
+        ctx, cs_tensor->dim_size(1) == cell_size,
+        errors::InvalidArgument("cs.dim_size(1) != cell_size: ",
+                                cs_tensor->dim_size(1), " vs. ", cell_size));
+
+    OP_REQUIRES(
+        ctx, f_tensor->dim_size(0) == batch_size,
+        errors::InvalidArgument("f.dim_size(0) != batch_size: ",
+                                f_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(
+        ctx, f_tensor->dim_size(1) == cell_size,
+        errors::InvalidArgument("i.dim_size(1) != cell_size: ",
+                                f_tensor->dim_size(1), " vs. ", cell_size));
+
+    OP_REQUIRES(
+        ctx, o_tensor->dim_size(0) == batch_size,
+        errors::InvalidArgument("o.dim_size(0) != batch_size: ",
+                                o_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(
+        ctx, o_tensor->dim_size(1) == cell_size,
+        errors::InvalidArgument("o.dim_size(1) != cell_size: ",
+                                o_tensor->dim_size(1), " vs. ", cell_size));
+
+    OP_REQUIRES(
+        ctx, ci_tensor->dim_size(0) == batch_size,
+        errors::InvalidArgument("ci.dim_size(0) != batch_size: ",
+                                ci_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(
+        ctx, ci_tensor->dim_size(1) == cell_size,
+        errors::InvalidArgument("ci.dim_size(1) != cell_size: ",
+                                ci_tensor->dim_size(1), " vs. ", cell_size));
+
+    OP_REQUIRES(
+        ctx, co_tensor->dim_size(0) == batch_size,
+        errors::InvalidArgument("co.dim_size(0) != batch_size: ",
+                                co_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(
+        ctx, co_tensor->dim_size(1) == cell_size,
+        errors::InvalidArgument("co.dim_size(1) != cell_size: ",
+                                co_tensor->dim_size(1), " vs. ", cell_size));
+
+    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(0) == batch_size,
+                errors::InvalidArgument(
+                    "cs_grad_tensor.dims(0) != batch_size: ",
+                    cs_grad_tensor->dim_size(0), " vs. ", batch_size));
+    OP_REQUIRES(ctx, cs_grad_tensor->dim_size(1) == cell_size,
+                errors::InvalidArgument("cs_grad_tensor.dims(1) != cell_size: ",
+                                        cs_grad_tensor->dim_size(1), " vs. ",
+                                        cell_size));
+
+    OP_REQUIRES(ctx, h_grad_tensor->dim_size(0) == batch_size,
+                errors::InvalidArgument("h_grad_tensor.dims(0) != batch_size: ",
+                                        h_grad_tensor->dim_size(0), " vs. ",
+                                        batch_size));
+    OP_REQUIRES(ctx, h_grad_tensor->dim_size(1) == cell_size,
+                errors::InvalidArgument("h_grad_tensor.dims(1) != cell_size: ",
+                                        h_grad_tensor->dim_size(1), " vs. ",
+                                        cell_size));
+
+    // Allocate our output tensors.
+    Tensor* cs_prev_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output("cs_prev_grad",
+                                        TensorShape({batch_size, cell_size}),
+                                        &cs_prev_grad_tensor));
+
+    Tensor* dicfo_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(
+                            "dicfo", TensorShape({batch_size, cell_size * 4}),
+                            &dicfo_tensor));
+
+    Tensor* wci_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
+                                             &wci_grad_tensor));
+
+    Tensor* wcf_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
+                                             &wcf_grad_tensor));
+
+    Tensor* wco_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
+                                             &wco_grad_tensor));
+
+    // Allocate our temp tensors.
+    Tensor do_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           TensorShape({batch_size, cell_size}),
+                                           &do_tensor));
+
+    Tensor dcs_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           TensorShape({batch_size, cell_size}),
+                                           &dcs_tensor));
+
+    Tensor dci_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           TensorShape({batch_size, cell_size}),
+                                           &dci_tensor));
+
+    Tensor df_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           TensorShape({batch_size, cell_size}),
+                                           &df_tensor));
+
+    Tensor di_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           TensorShape({batch_size, cell_size}),
+                                           &di_tensor));
+
+    const Device& device = ctx->eigen_device<Device>();
+    perftools::gputools::Stream* stream =
+        std::is_same<Device, GPUDevice>::value
+            ? ctx->op_device_context()->stream()
+            : nullptr;
+
+    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
+
+    functor::LSTMFusedCellBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
+                                                       cell_size)(
+        ctx, stream, device, use_peephole_, x_tensor->matrix<T>(),
+        cs_prev_tensor->matrix<T>(), h_prev_tensor->matrix<T>(),
+        w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
+        wco_tensor->vec<T>(), b_tensor->vec<T>(), i_tensor->matrix<T>(),
+        cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
+        ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
+        cs_grad_tensor->matrix<T>(), h_grad_tensor->matrix<T>(),
+        do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
+        df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor->matrix<T>(),
+        cs_prev_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(),
+        wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>());
+  }
+
+ protected:
+  bool use_peephole_;
+};
+
+#define REGISTER_KERNEL(T)                                                 \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("LSTMFusedCellGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      LSTMFusedCellGradOp<CPUDevice, T, false>);
+REGISTER_KERNEL(float);
+// REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                   \
+  template <>                                                                 \
+  void LSTMFusedCellBprop<GPUDevice, T, true>::operator()(                    \
+      OpKernelContext* ctx, perftools::gputools::Stream* stream,              \
+      const GPUDevice& d, bool use_peephole,                                  \
+      typename TTypes<T>::ConstMatrix x,                                      \
+      typename TTypes<T>::ConstMatrix cs_prev,                                \
+      typename TTypes<T>::ConstMatrix h_prev,                                 \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,    \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,     \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::ConstMatrix i,      \
+      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,  \
+      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,  \
+      typename TTypes<T>::ConstMatrix co,                                     \
+      typename TTypes<T>::ConstMatrix cs_grad,                                \
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_, \
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,         \
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,           \
+      typename TTypes<T>::Matrix dicfo,                                       \
+      typename TTypes<T>::Matrix cs_prev_grad,                                \
+      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,     \
+      typename TTypes<T>::Vec wco_grad);                                      \
+                                                                              \
+  extern template struct LSTMFusedCellBprop<GPUDevice, T, true>;
+
+DECLARE_GPU_SPEC(float);
+// DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // namespace functor
+
+#define REGISTER_GPU_KERNEL(T)                                             \
+  REGISTER_KERNEL_BUILDER(                                                 \
+      Name("LSTMFusedCellGrad").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      LSTMFusedCellGradOp<GPUDevice, T, true>);
+
+REGISTER_GPU_KERNEL(float);
+// REGISTER_GPU_KERNEL(double);
+#undef REGISTER_GPU_KERNEL
+#endif  // GOOGLE_CUDA
+
+template <typename Device, typename T, bool USE_CUBLAS>
+class FusedLSTMOp : public OpKernel {
+ public:
+  explicit FusedLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* seq_len_max_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
+
+    OpInputList x_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list));
+    const int64 batch_size = x_list[0].dim_size(0);
+    const int64 input_size = x_list[0].dim_size(1);
+
+    const Tensor* cs_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
+
+    const Tensor* h_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
+
+    const Tensor* w_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
+
+    const Tensor* wci_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
+
+    const Tensor* wcf_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
+
+    const Tensor* wco_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
+
+    const Tensor* b_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
+    const int64 cell_size = b_tensor->dim_size(0) / 4;
+
+    OpOutputList i_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("i", &i_list));
+
+    OpOutputList cs_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("cs", &cs_list));
+
+    OpOutputList f_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("f", &f_list));
+
+    OpOutputList o_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("o", &o_list));
+
+    OpOutputList ci_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("ci", &ci_list));
+
+    OpOutputList co_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("co", &co_list));
+
+    OpOutputList h_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("h", &h_list));
+
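+    // Pre-allocate all max_len_ per-step outputs up front.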
+    TensorShape batch_cell_shape({batch_size, cell_size});
+    for (int64 t = 0; t < max_len_; ++t) {
+      Tensor* i_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, i_list.allocate(t, batch_cell_shape, &i_tensor));
+
+      Tensor* cs_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, cs_list.allocate(t, batch_cell_shape, &cs_tensor));
+
+      Tensor* f_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, f_list.allocate(t, batch_cell_shape, &f_tensor));
+
+      Tensor* o_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, o_list.allocate(t, batch_cell_shape, &o_tensor));
+
+      Tensor* ci_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, ci_list.allocate(t, batch_cell_shape, &ci_tensor));
+
+      Tensor* co_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, co_list.allocate(t, batch_cell_shape, &co_tensor));
+
+      Tensor* h_tensor = nullptr;
+      OP_REQUIRES_OK(ctx, h_list.allocate(t, batch_cell_shape, &h_tensor));
+    }
+
+    Tensor xh_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+                            DataTypeToEnum<T>::v(),
+                            TensorShape({batch_size, input_size + cell_size}),
+                            &xh_tensor));
+
+    Tensor icfo_tensor;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                      TensorShape({batch_size, cell_size * 4}),
+                                      &icfo_tensor));
+
+    const Device& device = ctx->eigen_device<Device>();
+    perftools::gputools::Stream* stream =
+        std::is_same<Device, GPUDevice>::value
+            ? ctx->op_device_context()->stream()
+            : nullptr;
+
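+    // Unpack the scalar sequence length, then run the fused cell once per
+    // valid time step, threading each step's cs and h into the next.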
+    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
+    for (int64 t = 0; t < seq_len_max; ++t) {
+      const Tensor& x_tensor = x_list[t];
+      const Tensor& cs_prev_tensor2 =
+          t == 0 ? *cs_prev_tensor : *cs_list[t - 1];
+      const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : *h_list[t - 1];
+
+      Tensor* i_tensor = i_list[t];
+      Tensor* cs_tensor = cs_list[t];
+      Tensor* f_tensor = f_list[t];
+      Tensor* o_tensor = o_list[t];
+      Tensor* ci_tensor = ci_list[t];
+      Tensor* co_tensor = co_list[t];
+      Tensor* h_tensor = h_list[t];
+
+      functor::LSTMFusedCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
+                                                         cell_size)(
+          ctx, stream, device, forget_bias_, cell_clip_, use_peephole_,
+          x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
+          h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
+          wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
+          b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor->matrix<T>(),
+          cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
+          ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
+          icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
+    }
+
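+    // Zero-fill the outputs past the valid sequence length so every one of
+    // the max_len_ outputs is initialized.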
+    for (int64 t = seq_len_max; t < max_len_; ++t) {
+      Tensor* cs_tensor = cs_list[t];
+      Tensor* h_tensor = h_list[t];
+
+      functor::TensorZero<Device, T>()(device, cs_tensor->flat<T>());
+      functor::TensorZero<Device, T>()(device, h_tensor->flat<T>());
+    }
+  }
+
+ private:
+  int64 max_len_;
+  float forget_bias_;
+  float cell_clip_;
+  bool use_peephole_;
+};
+
+#define REGISTER_KERNEL(T)                                         \
+  REGISTER_KERNEL_BUILDER(                                         \
+      Name("FusedLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      FusedLSTMOp<CPUDevice, T, false>);
+REGISTER_KERNEL(float);
+// REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                              \
+  template <>                                                            \
+  void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d,          \
+                                            typename TTypes<T>::Flat t); \
+                                                                         \
+  extern template struct TensorZero<GPUDevice, T>;
+
+DECLARE_GPU_SPEC(float);
+// DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // end namespace functor
+
+#define REGISTER_GPU_KERNEL(T)                           \
+  REGISTER_KERNEL_BUILDER(Name("FusedLSTM")              \
+                              .Device(DEVICE_GPU)        \
+                              .HostMemory("seq_len_max") \
+                              .TypeConstraint<T>("T"),   \
+                          FusedLSTMOp<GPUDevice, T, true>);
+
+REGISTER_GPU_KERNEL(float);
+// REGISTER_GPU_KERNEL(double);
+#undef REGISTER_GPU_KERNEL
+#endif  // GOOGLE_CUDA
+
+template <typename Device, typename T, bool USE_CUBLAS>
+class FusedLSTMGradOp : public OpKernel {
+ public:
+  explicit FusedLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* seq_len_max_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
+
+    OpInputList x_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list));
+    const int64 batch_size = x_list[0].dim_size(0);
+    const int64 input_size = x_list[0].dim_size(1);
+
+    const Tensor* cs_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
+
+    const Tensor* h_prev_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
+
+    const Tensor* w_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
+    const int64 cell_size = w_tensor->dim_size(1) / 4;
+    OP_REQUIRES(ctx, input_size + cell_size == w_tensor->dim_size(0),
+                errors::InvalidArgument("w matrix rows don't match: ",
+                                        input_size + cell_size, " vs. ",
+                                        w_tensor->dim_size(0)));
+
+    const Tensor* wci_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
+
+    const Tensor* wcf_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
+
+    const Tensor* wco_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
+
+    const Tensor* b_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
+    OP_REQUIRES(
+        ctx, cell_size == b_tensor->dim_size(0) / 4,
+        errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
+                                " vs. ", b_tensor->dim_size(0) / 4));
+
+    OpInputList i_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("i", &i_list));
+
+    OpInputList cs_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("cs", &cs_list));
+
+    OpInputList f_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("f", &f_list));
+
+    OpInputList o_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("o", &o_list));
+
+    OpInputList ci_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("ci", &ci_list));
+
+    OpInputList co_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("co", &co_list));
+
+    OpInputList h_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("h", &h_list));
+
+    OpInputList cs_grad_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("cs_grad", &cs_grad_list));
+
+    OpInputList h_grad_list;
+    OP_REQUIRES_OK(ctx, ctx->input_list("h_grad", &h_grad_list));
+
+    OpOutputList x_grad_list;
+    OP_REQUIRES_OK(ctx, ctx->output_list("x_grad", &x_grad_list));
+
+    Tensor* cs_prev_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output("cs_prev_grad", cs_prev_tensor->shape(),
+                                        &cs_prev_grad_tensor));
+
+    Tensor* h_prev_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_output("h_prev_grad", h_prev_tensor->shape(),
+                                        &h_prev_grad_tensor));
+
+    Tensor* w_grad_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("w_grad", w_tensor->shape(), &w_grad_tensor));
+
+    Tensor* wci_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("wci_grad", wci_tensor->shape(),
+                                             &wci_grad_tensor));
+
+    Tensor* wcf_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("wcf_grad", wcf_tensor->shape(),
+                                             &wcf_grad_tensor));
+
+    Tensor* wco_grad_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output("wco_grad", wco_tensor->shape(),
+                                             &wco_grad_tensor));
+
+    Tensor* b_grad_tensor = nullptr;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));
+
+    TensorShape batch_input_shape({batch_size, input_size});
+    TensorShape batch_cell_shape({batch_size, cell_size});
+    for (int64 t = 0; t < max_len_; ++t) {
+      Tensor* x_grad_tensor = nullptr;
+      OP_REQUIRES_OK(
+          ctx, x_grad_list.allocate(t, batch_input_shape, &x_grad_tensor));
+    }
+
+    Tensor xh_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+                            DataTypeToEnum<T>::v(),
+                            TensorShape({batch_size, input_size + cell_size}),
+                            &xh_tensor));
+
+    Tensor xh_grad_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           xh_tensor.shape(), &xh_grad_tensor));
+
+    Tensor do_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &do_tensor));
+
+    Tensor dcs_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &dcs_tensor));
+
+    Tensor dci_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &dci_tensor));
+
+    Tensor df_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &df_tensor));
+
+    Tensor di_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &di_tensor));
+
+    Tensor dicfo_tensor;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                      TensorShape({batch_size, cell_size * 4}),
+                                      &dicfo_tensor));
+
+    Tensor cs_grad_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &cs_grad_tensor));
+
+    Tensor h_grad_tensor;
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::v(),
+                                           batch_cell_shape, &h_grad_tensor));
+
+    const Device& device = ctx->eigen_device<Device>();
+    perftools::gputools::Stream* stream =
+        std::is_same<Device, GPUDevice>::value
+            ? ctx->op_device_context()->stream()
+            : nullptr;
+
+    functor::TensorZero<Device, T>()(device, cs_grad_tensor.flat<T>());
+    functor::TensorZero<Device, T>()(device, cs_prev_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, h_grad_tensor.flat<T>());
+    functor::TensorZero<Device, T>()(device, h_prev_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, w_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, wci_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, wcf_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, wco_grad_tensor->flat<T>());
+    functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<T>());
+
+    const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
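+    // Walk backwards in time. cs_grad and h_grad accumulate the recurrent
+    // gradients flowing from step t + 1 into step t.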
+    for (int64 t = seq_len_max - 1; t >= 0; --t) {
+      const Tensor& x_tensor = x_list[t];
+      const Tensor& cs_prev_tensor2 = t == 0 ? *cs_prev_tensor : cs_list[t - 1];
+      const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : h_list[t - 1];
+      const Tensor& i_tensor = i_list[t];
+      const Tensor& cs_tensor = cs_list[t];
+      const Tensor& f_tensor = f_list[t];
+      const Tensor& o_tensor = o_list[t];
+      const Tensor& ci_tensor = ci_list[t];
+      const Tensor& co_tensor = co_list[t];
+
+      // Grab previous CS grad.
+      const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
+      functor::TensorAdd<Device, T>()(
+          device, const_cs_prev_grad_tensor.flat<T>(),
+          cs_grad_list[t].flat<T>(), cs_grad_tensor.flat<T>());
+
+      // Combine previous h grad and h grad coming on top.
+      const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
+      functor::TensorAdd<Device, T>()(
+          device, const_h_prev_grad_tensor.flat<T>(), h_grad_list[t].flat<T>(),
+          h_grad_tensor.flat<T>());
+
+      const Tensor& const_cs_grad_tensor = cs_grad_tensor;
+      const Tensor& const_h_grad_tensor = h_grad_tensor;
+
+      Tensor* x_grad_tensor = x_grad_list[t];
+      functor::FusedLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
+                                                     cell_size)(
+          ctx, stream, device, use_peephole_, x_tensor.matrix<T>(),
+          cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(),
+          w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(),
+          wco_tensor->vec<T>(), b_tensor->vec<T>(), xh_tensor.matrix<T>(),
+          i_tensor.matrix<T>(), cs_tensor.matrix<T>(), f_tensor.matrix<T>(),
+          o_tensor.matrix<T>(), ci_tensor.matrix<T>(), co_tensor.matrix<T>(),
+          const_cs_grad_tensor.matrix<T>(), const_h_grad_tensor.matrix<T>(),
+          do_tensor.matrix<T>(), dcs_tensor.matrix<T>(), dci_tensor.matrix<T>(),
+          df_tensor.matrix<T>(), di_tensor.matrix<T>(),
+          dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
+          h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
+          x_grad_tensor->matrix<T>(), w_grad_tensor->matrix<T>(),
+          wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
+          wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
+    }
+
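+    // Timesteps beyond the valid sequence length contribute no gradient.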
+    for (int64 t = seq_len_max; t < max_len_; ++t) {
+      Tensor* x_grad_tensor = x_grad_list[t];
+      functor::TensorZero<Device, T>()(device, x_grad_tensor->flat<T>());
+    }
+  }
+
+ private:
+  int64 max_len_;
+  bool use_peephole_;
+};
+
+#define REGISTER_KERNEL(T)                                             \
+  REGISTER_KERNEL_BUILDER(                                             \
+      Name("FusedLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      FusedLSTMGradOp<CPUDevice, T, false>);
+REGISTER_KERNEL(float);
+// REGISTER_KERNEL(double);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                                    \
+  template <>                                                                  \
+  void TensorCopy<GPUDevice, T>::operator()(const GPUDevice& d,                \
+                                            typename TTypes<T>::ConstFlat src, \
+                                            typename TTypes<T>::Flat dst);     \
+                                                                               \
+  template <>                                                                  \
+  void TensorAdd<GPUDevice, T>::operator()(                                    \
+      const GPUDevice& d, typename TTypes<T>::ConstFlat a,                     \
+      typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c);            \
+                                                                               \
+  template <>                                                                  \
+  void FusedLSTMBprop<GPUDevice, T, true>::operator()(                         \
+      OpKernelContext* ctx, perftools::gputools::Stream* stream,               \
+      const GPUDevice& d, bool use_peephole,                                   \
+      typename TTypes<T>::ConstMatrix x,                                       \
+      typename TTypes<T>::ConstMatrix cs_prev,                                 \
+      typename TTypes<T>::ConstMatrix h_prev,                                  \
+      typename TTypes<T>::ConstMatrix w, typename TTypes<T>::ConstVec wci,     \
+      typename TTypes<T>::ConstVec wcf, typename TTypes<T>::ConstVec wco,      \
+      typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,           \
+      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,   \
+      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,    \
+      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,  \
+      typename TTypes<T>::ConstMatrix cs_grad,                                 \
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,  \
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,          \
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,            \
+      typename TTypes<T>::Matrix dicfo,                                        \
+      typename TTypes<T>::Matrix cs_prev_grad,                                 \
+      typename TTypes<T>::Matrix h_prev_grad,                                  \
+      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,   \
+      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,     \
+      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,      \
+      typename TTypes<T>::Vec b_grad);                                         \
+                                                                               \
+  extern template struct TensorCopy<GPUDevice, T>;                             \
+  extern template struct TensorAdd<GPUDevice, T>;                              \
+  extern template struct FusedLSTMBprop<GPUDevice, T, true>;
+
+DECLARE_GPU_SPEC(float);
+// DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+}  // end namespace functor
+
+#define REGISTER_GPU_KERNEL(T)                           \
+  REGISTER_KERNEL_BUILDER(Name("FusedLSTMGrad")          \
+                              .Device(DEVICE_GPU)        \
+                              .HostMemory("seq_len_max") \
+                              .TypeConstraint<T>("T"),   \
+                          FusedLSTMGradOp<GPUDevice, T, true>);
+
+REGISTER_GPU_KERNEL(float);
+// REGISTER_GPU_KERNEL(double);
+#undef REGISTER_GPU_KERNEL
+#endif  // GOOGLE_CUDA
+
+}  // end namespace tensorflow
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h
new file mode 100644
index 00000000000..bcb7bfa1e6e
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h
@@ -0,0 +1,420 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_activations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace perftools {
+namespace gputools {
+class Stream;
+}  // end namespace gputools
+}  // end namespace perftools
+
+namespace tensorflow {
+class OpKernelContext;
+
+namespace functor {
+
+template <typename Device, typename T>
+struct TensorZero {
+  void operator()(const Device& d, typename TTypes<T>::Flat t) {
+    t.device(d) = t.constant(T(0));
+  }
+};
+
+template <typename Device, typename T>
+struct TensorCopy {
+  void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
+                  typename TTypes<T>::Flat dst) {
+    dst.device(d) = src;
+  }
+};
+
+template <typename Device, typename T>
+struct TensorAdd {
+  void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
+                  typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
+    c.device(d) = a + b;
+  }
+};
+
+template <typename Device, typename T>
+struct TensorZeroPadding {
+  void operator()(const Device& d, const int64 time_idx,
+                  typename TTypes<int64>::ConstVec seq_len,
+                  typename TTypes<float>::Vec mask,
+                  typename TTypes<float>::Matrix m) {
+    // mask is shape [batch_size].
+    mask.device(d) = seq_len.constant(time_idx) < seq_len;
+
+    // m_shape is [batch_size, 1].
+    Eigen::array<Eigen::DenseIndex, 2> m_shape({m.dimensions()[0], 1});
+    // broadcast_shape is [1, units].
+    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({1, m.dimensions()[1]});
+
+    // m is shape [batch_size, units].
+    m.device(d) = m * mask.reshape(m_shape).broadcast(broadcast_shape);
+  }
+};
+
+template <typename T>
+struct TensorCuBlasGemm {
+  void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                  bool transa, bool transb, uint64 m, uint64 n, uint64 k,
+                  T alpha, const T* a, int lda, const T* b, int ldb, T beta,
+                  T* c, int ldc);
+};
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct TensorBlasGemm;
+
+template <typename Device, typename T>
+struct TensorBlasGemm<Device, T, true /* USE_CUBLAS */> {
+  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                      const Device& d, bool transa, bool transb, T alpha,
+                      typename TTypes<T>::ConstMatrix a,
+                      typename TTypes<T>::ConstMatrix b, T beta,
+                      typename TTypes<T>::Matrix c) {
+    int64 m = c.dimensions()[0];
+    int64 n = c.dimensions()[1];
+    int64 k = transa ? a.dimensions()[0] : a.dimensions()[1];
+
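+    // Eigen stores tensors row-major while cuBLAS expects column-major, so
+    // compute c^T = b^T * a^T by swapping the operands and transpose flags;
+    // the column-major result cuBLAS writes is exactly c in row-major.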
+    TensorCuBlasGemm<T>()(ctx, stream, transb, transa, n, m, k, alpha, b.data(),
+                          transb ? k : n, a.data(), transa ? m : k, beta,
+                          c.data(), n);
+  }
+};
+
+template <typename Device, typename T>
+struct TensorBlasGemm<Device, T, false /* USE_CUBLAS */> {
+  static void compute(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                      const Device& d, bool transa, bool transb, T alpha,
+                      typename TTypes<T>::ConstMatrix a,
+                      typename TTypes<T>::ConstMatrix b, T beta,
+                      typename TTypes<T>::Matrix c) {
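+    // Choose the contraction indices so transa/transb behave like BLAS flags:
+    // contract dimension 1 of a (0 if transposed) with dimension 0 of b
+    // (1 if transposed).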
+    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
+    contract_pairs[0] =
+        Eigen::IndexPair<Eigen::DenseIndex>(transa == false, transb == true);
+    if (alpha == T(1) && beta == T(0)) {
+      c.device(d) = a.contract(b, contract_pairs);
+    } else if (alpha == T(1) && beta == T(1)) {
+      c.device(d) += a.contract(b, contract_pairs);
+    } else {
+      c.device(d) = c.constant(alpha) * a.contract(b, contract_pairs) +
+                    c.constant(beta) * c;
+    }
+  }
+};
+
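+// Geometry shared by the fused LSTM kernels: the fused gate matrix icfo has
+// shape [batch_size, cell_size * 4] with the i, c, f, o blocks laid out
+// contiguously, and xh is the [x, h] concatenation of shape
+// [batch_size, input_size + cell_size].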
+struct LSTMFusedCell {
+  LSTMFusedCell(const int batch_size, const int input_size, const int cell_size)
+      : batch_size_(batch_size),
+        input_size_(input_size),
+        cell_size_(cell_size) {}
+
+  inline Eigen::array<Eigen::DenseIndex, 2> icfo_i_offsets() const {
+    return {0, 0};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> icfo_c_offsets() const {
+    return {0, cell_size_};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> icfo_f_offsets() const {
+    return {0, cell_size_ * 2};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> icfo_o_offsets() const {
+    return {0, cell_size_ * 3};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
+    return {batch_size_, cell_size_};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_offsets() const {
+    return {0, 0};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> xh_x_extents() const {
+    return {batch_size_, input_size_};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_offsets() const {
+    return {0, input_size_};
+  }
+
+  inline Eigen::array<Eigen::DenseIndex, 2> xh_h_extents() const {
+    return {batch_size_, cell_size_};
+  }
+
+ protected:
+  const int batch_size_;
+  const int input_size_;
+  const int cell_size_;
+};
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct LSTMFusedCellFprop : public LSTMFusedCell {
+  LSTMFusedCellFprop(const int batch_size, const int input_size,
+                     const int cell_size)
+      : LSTMFusedCell(batch_size, input_size, cell_size) {}
+
+  void operator()(OpKernelContext* ctx, perftools::gputools::Stream* stream,
+                  const Device& d, const T forget_bias, const T cell_clip,
+                  bool use_peephole, typename TTypes<T>::ConstMatrix x,
+                  typename TTypes<T>::ConstMatrix cs_prev,
+                  typename TTypes<T>::ConstMatrix h_prev,
+                  typename TTypes<T>::ConstMatrix w,
+                  typename TTypes<T>::ConstVec wci,
+                  typename TTypes<T>::ConstVec wcf,
+                  typename TTypes<T>::ConstVec wco,
+                  typename TTypes<T>::ConstVec b, typename TTypes<T>::Matrix xh,
+                  typename TTypes<T>::Matrix i, typename TTypes<T>::Matrix cs,
+                  typename TTypes<T>::Matrix f, typename TTypes<T>::Matrix o,
+                  typename TTypes<T>::Matrix ci, typename TTypes<T>::Matrix co,
+                  typename TTypes<T>::Matrix icfo,
+                  typename TTypes<T>::Matrix h) {
+    // Concat xh = [x, h].
+    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
+    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
+
+    // states1 = xh * w + b
+    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
+    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
+        ctx, stream, d, false, false, T(1), const_xh, w, T(0), icfo);
+    Eigen::array<Eigen::DenseIndex, 2> b_shape({1, b.dimensions()[0]});
+    Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
+    icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape);
+
+    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
+    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
+
+    // Input gate.
+    if (use_peephole) {
+      auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape);
+      i.device(d) =
+          (icfo.slice(icfo_i_offsets(), cell_extents()) + i_peep).sigmoid();
+    } else {
+      i.device(d) = icfo.slice(icfo_i_offsets(), cell_extents()).sigmoid();
+    }
+
+    // Cell input.
+    ci.device(d) = icfo.slice(icfo_c_offsets(), cell_extents()).tanh();
+
+    // Forget gate (w/ bias).
+    if (use_peephole) {
+      auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
+      f.device(d) = (icfo.slice(icfo_f_offsets(), cell_extents()) +
+                     f.constant(forget_bias) + f_peep)
+                        .sigmoid();
+    } else {
+      f.device(d) = (icfo.slice(icfo_f_offsets(), cell_extents()) +
+                     f.constant(forget_bias))
+                        .sigmoid();
+    }
+
+    // cs = ci .* i + f .* cs_prev
+    cs.device(d) = i * ci + f * cs_prev;
+
+    if (cell_clip > 0.0f) {
+      cs.device(d) =
+          cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op<T>());
+    }
+
+    // co = tanh(cs)
+    co.device(d) = cs.tanh();
+
+    // Output gate.
+    if (use_peephole) {
+      auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape);
+      o.device(d) =
+          (icfo.slice(icfo_o_offsets(), cell_extents()) + o_peep).sigmoid();
+    } else {
+      o.device(d) = icfo.slice(icfo_o_offsets(), cell_extents()).sigmoid();
+    }
+
+    // h = o .* co
+    h.device(d) = o * co;
+  }
+};
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct LSTMFusedCellBprop : public LSTMFusedCell {
+  LSTMFusedCellBprop(const int batch_size, const int input_size,
+                     const int cell_size)
+      : LSTMFusedCell(batch_size, input_size, cell_size) {}
+
+  void operator()(
+      OpKernelContext* ctx, perftools::gputools::Stream* stream,
+      const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
+      typename TTypes<T>::ConstMatrix cs_prev,
+      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
+      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
+      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
+      typename TTypes<T>::ConstMatrix i, typename TTypes<T>::ConstMatrix cs,
+      typename TTypes<T>::ConstMatrix f, typename TTypes<T>::ConstMatrix o,
+      typename TTypes<T>::ConstMatrix ci, typename TTypes<T>::ConstMatrix co,
+      typename TTypes<T>::ConstMatrix cs_grad,
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
+      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
+      typename TTypes<T>::Vec wci_grad, typename TTypes<T>::Vec wcf_grad,
+      typename TTypes<T>::Vec wco_grad) {
+    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
+    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
+
+    // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
+    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
+
+    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
+    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
+    if (use_peephole) {
+      dcs.device(d) =
+          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
+    }
+
+    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
+    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
+
+    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
+    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
+
+    // di[t] = sigm'(i[t]) dcs[t] ci[t]
+    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
+
+    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
+    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
+    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
+    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;
+
+    cs_prev_grad.device(d) = dcs * f;
+    if (use_peephole) {
+      cs_prev_grad.device(d) =
+          cs_prev_grad +
+          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
+          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
+    }
+
+    if (use_peephole) {
+      wci_grad.device(d) = (di * cs_prev).sum(Eigen::array<int, 1>({0}));
+      wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array<int, 1>({0}));
+      wco_grad.device(d) = (do_ * cs).sum(Eigen::array<int, 1>({0}));
+    }
+  }
+};
+
+template <typename Device, typename T, bool USE_CUBLAS>
+struct FusedLSTMBprop : public LSTMFusedCell {
+  FusedLSTMBprop(const int batch_size, const int input_size,
+                 const int cell_size)
+      : LSTMFusedCell(batch_size, input_size, cell_size) {}
+
+  void operator()(
+      OpKernelContext* ctx, perftools::gputools::Stream* stream,
+      const Device& d, bool use_peephole, typename TTypes<T>::ConstMatrix x,
+      typename TTypes<T>::ConstMatrix cs_prev,
+      typename TTypes<T>::ConstMatrix h_prev, typename TTypes<T>::ConstMatrix w,
+      typename TTypes<T>::ConstVec wci, typename TTypes<T>::ConstVec wcf,
+      typename TTypes<T>::ConstVec wco, typename TTypes<T>::ConstVec b,
+      typename TTypes<T>::Matrix xh, typename TTypes<T>::ConstMatrix i,
+      typename TTypes<T>::ConstMatrix cs, typename TTypes<T>::ConstMatrix f,
+      typename TTypes<T>::ConstMatrix o, typename TTypes<T>::ConstMatrix ci,
+      typename TTypes<T>::ConstMatrix co,
+      typename TTypes<T>::ConstMatrix cs_grad,
+      typename TTypes<T>::ConstMatrix h_grad, typename TTypes<T>::Matrix do_,
+      typename TTypes<T>::Matrix dcs, typename TTypes<T>::Matrix dci,
+      typename TTypes<T>::Matrix df, typename TTypes<T>::Matrix di,
+      typename TTypes<T>::Matrix dicfo, typename TTypes<T>::Matrix cs_prev_grad,
+      typename TTypes<T>::Matrix h_prev_grad,
+      typename TTypes<T>::Matrix xh_grad, typename TTypes<T>::Matrix x_grad,
+      typename TTypes<T>::Matrix w_grad, typename TTypes<T>::Vec wci_grad,
+      typename TTypes<T>::Vec wcf_grad, typename TTypes<T>::Vec wco_grad,
+      typename TTypes<T>::Vec b_grad) {
+    // do[t] = sigm'(o[t]) .* dh[t] .* co[t]
+    do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co;
+
+    // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1]
+    dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad;
+
+    Eigen::array<Eigen::DenseIndex, 2> p_shape({1, cell_size_});
+    Eigen::array<Eigen::DenseIndex, 2> p_broadcast_shape({batch_size_, 1});
+    if (use_peephole) {
+      dcs.device(d) =
+          dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape);
+    }
+
+    // dci[t] = tanh'(ci[t]) dcs[t] i[t]
+    dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i;
+
+    // df[t] = sigm'(f[t]) dcs[t] cs[t - 1]
+    df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev;
+
+    // di[t] = sigm'(i[t]) dcs[t] ci[t]
+    di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci;
+
+    dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di;
+    dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci;
+    dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df;
+    dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_;
+
+    cs_prev_grad.device(d) = dcs * f;
+    if (use_peephole) {
+      cs_prev_grad.device(d) =
+          cs_prev_grad +
+          di * wci.reshape(p_shape).broadcast(p_broadcast_shape) +
+          df * wcf.reshape(p_shape).broadcast(p_broadcast_shape);
+    }
+
+    // xh_grad.
+    typename TTypes<T>::ConstMatrix const_dicfo(dicfo.data(),
+                                                dicfo.dimensions());
+    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
+        ctx, stream, d, false, true, T(1), const_dicfo, w, T(0), xh_grad);
+
+    // xh.
+    xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x;
+    xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev;
+    typename TTypes<T>::ConstMatrix const_xh(xh.data(), xh.dimensions());
+
+    // x_grad.
+    x_grad.device(d) = xh_grad.slice(xh_x_offsets(), xh_x_extents());
+    h_prev_grad.device(d) = xh_grad.slice(xh_h_offsets(), xh_h_extents());
+
+    // w_grad.
+    TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
+        ctx, stream, d, true, false, T(1), const_xh, const_dicfo, T(1), w_grad);
+
+    // b_grad.
+    b_grad.device(d) += dicfo.sum(Eigen::array<int, 1>({0}));
+
+    if (use_peephole) {
+      wci_grad.device(d) += (di * cs_prev).sum(Eigen::array<int, 1>({0}));
+      wcf_grad.device(d) += (df * cs_prev).sum(Eigen::array<int, 1>({0}));
+      wco_grad.device(d) += (do_ * cs).sum(Eigen::array<int, 1>({0}));
+    }
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_RNN_KERNELS_LSTM_OPS_H_
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
new file mode 100644
index 00000000000..2c5e500c289
--- /dev/null
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
@@ -0,0 +1,41 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/rnn/kernels/lstm_ops.h"
+
+namespace tensorflow {
+namespace functor {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define DEFINE_GPU_SPECS(T)                               \
+  template struct TensorZero<GPUDevice, T>;               \
+  template struct TensorCopy<GPUDevice, T>;               \
+  template struct TensorAdd<GPUDevice, T>;                \
+  template struct LSTMFusedCellFprop<GPUDevice, T, true>; \
+  template struct LSTMFusedCellBprop<GPUDevice, T, true>; \
+  template struct FusedLSTMBprop<GPUDevice, T, true>;
+
+DEFINE_GPU_SPECS(float);
+// DEFINE_GPU_SPECS(double);
+#undef DEFINE_GPU_SPECS
+
+}  // end namespace functor
+}  // end namespace tensorflow
+#endif  // GOOGLE_CUDA
diff --git a/tensorflow/contrib/rnn/ops/lstm_ops.cc b/tensorflow/contrib/rnn/ops/lstm_ops.cc
new file mode 100644
index 00000000000..a55c6232886
--- /dev/null
+++ b/tensorflow/contrib/rnn/ops/lstm_ops.cc
@@ -0,0 +1,180 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("LSTMFusedCell")
+    .Input("x: T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Output("i: T")
+    .Output("cs: T")
+    .Output("f: T")
+    .Output("o: T")
+    .Output("ci: T")
+    .Output("co: T")
+    .Output("h: T")
+    .Attr("forget_bias: float = 1.0")
+    .Attr("cell_clip: float = 3.0")
+    .Attr("use_peephole: bool = false")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+Computes the LSTM cell forward propagation for 1 time step.
+
+This implementation uses one weight matrix and one bias vector; the diagonal
+peephole connections are only used when use_peephole is set.
+
+This kernel op implements the following mathematical equations:
+
+```python
+xh = [x, h_prev]
+[i, f, ci, o] = xh * w + b
+f = f + forget_bias
+
+i = sigmoid(i)
+f = sigmoid(f)
+ci = tanh(ci)
+o = sigmoid(o)
+
+cs = ci .* i + cs_prev .* f
+co = tanh(cs)
+
+h = co .* o
+```
+
+forget_bias: The forget gate bias.
+x: The input to the LSTM cell.
+w: The weight matrix.
+b: The bias vector.
+i: The input gate.
+cs: The cell state before the tanh.
+f: The forget gate.
+o: The output gate.
+ci: The cell input.
+co: The cell after the tanh.
+h: The output h vector.
+)doc");
+
+REGISTER_OP("LSTMFusedCellGrad")
+    .Input("x: T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Input("i: T")
+    .Input("cs: T")
+    .Input("f: T")
+    .Input("o: T")
+    .Input("ci: T")
+    .Input("co: T")
+    .Input("cs_grad: T")
+    .Input("h_grad: T")
+    .Output("cs_prev_grad: T")
+    .Output("dicfo: T")
+    .Output("wci_grad: T")
+    .Output("wcf_grad: T")
+    .Output("wco_grad: T")
+    .Attr("use_peephole: bool")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+Computes the LSTM cell backward propagation for 1 time step.
+
+This implementation is to be used in conjunction with LSTMFusedCell.
+
+x: The input to the LSTM cell.
+cs_prev: The previous cell state.
+h_prev: The previous h state.
+w: The weight matrix.
+b: The bias vector.
+i: The input gate.
+cs: The cell state before the tanh.
+f: The forget gate.
+o: The output gate.
+ci: The cell input.
+co: The cell after the tanh.
+h_grad: The gradient of the h vector.
+cs_prev_grad: The gradient of cs_prev.
+dicfo: The derivative with respect to [i, cs, f, o].
+)doc");
+
+REGISTER_OP("FusedLSTM")
+    .Input("seq_len_max: int64")
+    .Input("x: max_len * T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Output("i: max_len * T")
+    .Output("cs: max_len * T")
+    .Output("f: max_len * T")
+    .Output("o: max_len * T")
+    .Output("ci: max_len * T")
+    .Output("co: max_len * T")
+    .Output("h: max_len * T")
+    .Attr("max_len: int")
+    .Attr("forget_bias: float = 1.0")
+    .Attr("cell_clip: float = 3.0")
+    .Attr("use_peephole: bool = false")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+)doc");
+
+REGISTER_OP("FusedLSTMGrad")
+    .Input("seq_len_max: int64")
+    .Input("x: max_len * T")
+    .Input("cs_prev: T")
+    .Input("h_prev: T")
+    .Input("w: T")
+    .Input("wci: T")
+    .Input("wcf: T")
+    .Input("wco: T")
+    .Input("b: T")
+    .Input("i: max_len * T")
+    .Input("cs: max_len * T")
+    .Input("f: max_len * T")
+    .Input("o: max_len * T")
+    .Input("ci: max_len * T")
+    .Input("co: max_len * T")
+    .Input("h: max_len * T")
+    .Input("cs_grad: max_len * T")
+    .Input("h_grad: max_len * T")
+    .Output("x_grad: max_len * T")
+    .Output("cs_prev_grad: T")
+    .Output("h_prev_grad: T")
+    .Output("w_grad: T")
+    .Output("wci_grad: T")
+    .Output("wcf_grad: T")
+    .Output("wco_grad: T")
+    .Output("b_grad: T")
+    .Attr("max_len: int")
+    .Attr("use_peephole: bool")
+    .Attr("T: {float}")
+    .Doc(R"doc(
+)doc");
+
+}  // end namespace tensorflow
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
new file mode 100644
index 00000000000..70aeb5ff559
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -0,0 +1,290 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""LSTM Fused Cell ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+import tensorflow as tf
+
+from tensorflow.contrib.rnn.python.ops import lstm_ops
+
+
+fused_lstm = lstm_ops._fused_lstm  # pylint: disable=protected-access
+
+
+class LSTMFusedCellTest(tf.test.TestCase):
+  _use_gpu = False
+
+  def testNoneDimsWithDynamicRNN(self):
+    with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()) as sess:
+      batch_size = 4
+      num_steps = 5
+      input_dim = 6
+      cell_size = 7
+
+      cell = tf.contrib.rnn.LSTMFusedCell(cell_size)
+      x = tf.placeholder(tf.float32, shape=(None, None, input_dim))
+
+      output, _ = tf.nn.dynamic_rnn(cell, x, time_major=True, dtype=tf.float32)
+      sess.run(tf.initialize_all_variables())
+      feed = {}
+      feed[x] = np.random.randn(num_steps, batch_size, input_dim)
+      sess.run(output, feed)
+
+  def testLSTMFusedCell(self):
+    with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()) as sess:
+      with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+        x = tf.zeros([1, 2])
+        m0 = tf.zeros([1, 2])
+        m1 = tf.zeros([1, 2])
+        m2 = tf.zeros([1, 2])
+        m3 = tf.zeros([1, 2])
+        g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell(
+            [tf.contrib.rnn.LSTMFusedCell(2)] * 2,
+            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+        sess.run([tf.initialize_all_variables()])
+        res = sess.run([g, out_m0, out_m1, out_m2, out_m3],
+                       {x.name: np.array([[1., 1.]]),
+                        m0.name: 0.1 * np.ones([1, 2]),
+                        m1.name: 0.1 * np.ones([1, 2]),
+                        m2.name: 0.1 * np.ones([1, 2]),
+                        m3.name: 0.1 * np.ones([1, 2])})
+        self.assertEqual(len(res), 5)
+        self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
+        # These numbers are from testBasicLSTMCell and only test c/h.
+        self.assertAllClose(res[1], [[0.68967271, 0.68967271]])
+        self.assertAllClose(res[2], [[0.44848421, 0.44848421]])
+        self.assertAllClose(res[3], [[0.39897051, 0.39897051]])
+        self.assertAllClose(res[4], [[0.24024698, 0.24024698]])
+
+  def testLSTMBasicToBlockCell(self):
+    with self.test_session(use_gpu=self._use_gpu) as sess:
+      x = tf.zeros([1, 2])
+      x_values = np.random.randn(1, 2)
+
+      m0_val = 0.1 * np.ones([1, 2])
+      m1_val = -0.1 * np.ones([1, 2])
+      m2_val = -0.2 * np.ones([1, 2])
+      m3_val = 0.2 * np.ones([1, 2])
+
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212)
+      with tf.variable_scope("basic", initializer=initializer):
+        m0 = tf.zeros([1, 2])
+        m1 = tf.zeros([1, 2])
+        m2 = tf.zeros([1, 2])
+        m3 = tf.zeros([1, 2])
+        g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell(
+            [tf.nn.rnn_cell.BasicLSTMCell(2, state_is_tuple=True)] * 2,
+            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+        sess.run([tf.initialize_all_variables()])
+        basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3],
+                             {x.name: x_values,
+                              m0.name: m0_val,
+                              m1.name: m1_val,
+                              m2.name: m2_val,
+                              m3.name: m3_val})
+
+      with tf.variable_scope("block", initializer=initializer):
+        m0 = tf.zeros([1, 2])
+        m1 = tf.zeros([1, 2])
+        m2 = tf.zeros([1, 2])
+        m3 = tf.zeros([1, 2])
+        g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell(
+            [tf.contrib.rnn.LSTMFusedCell(2)] * 2,
+            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+        sess.run([tf.initialize_all_variables()])
+        block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3],
+                             {x.name: x_values,
+                              m0.name: m0_val,
+                              m1.name: m1_val,
+                              m2.name: m2_val,
+                              m3.name: m3_val})
+
+      self.assertEqual(len(basic_res), len(block_res))
+      for basic, block in zip(basic_res, block_res):
+        self.assertAllClose(basic, block)
+
+  def testLSTMBasicToBlockCellPeeping(self):
+    with self.test_session(use_gpu=self._use_gpu) as sess:
+      x = tf.zeros([1, 2])
+      x_values = np.random.randn(1, 2)
+
+      m0_val = 0.1 * np.ones([1, 2])
+      m1_val = -0.1 * np.ones([1, 2])
+      m2_val = -0.2 * np.ones([1, 2])
+      m3_val = 0.2 * np.ones([1, 2])
+
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212)
+      with tf.variable_scope("basic", initializer=initializer):
+        m0 = tf.zeros([1, 2])
+        m1 = tf.zeros([1, 2])
+        m2 = tf.zeros([1, 2])
+        m3 = tf.zeros([1, 2])
+        g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell(
+            [tf.nn.rnn_cell.LSTMCell(2,
+                                     use_peepholes=True,
+                                     state_is_tuple=True)] * 2,
+            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+        sess.run([tf.initialize_all_variables()])
+        basic_res = sess.run([g, out_m0, out_m1, out_m2, out_m3],
+                             {x.name: x_values,
+                              m0.name: m0_val,
+                              m1.name: m1_val,
+                              m2.name: m2_val,
+                              m3.name: m3_val})
+
+      with tf.variable_scope("block", initializer=initializer):
+        m0 = tf.zeros([1, 2])
+        m1 = tf.zeros([1, 2])
+        m2 = tf.zeros([1, 2])
+        m3 = tf.zeros([1, 2])
+        g, ((out_m0, out_m1), (out_m2, out_m3)) = tf.nn.rnn_cell.MultiRNNCell(
+            [tf.contrib.rnn.LSTMFusedCell(2, use_peephole=True)] * 2,
+            state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
+        sess.run([tf.initialize_all_variables()])
+        block_res = sess.run([g, out_m0, out_m1, out_m2, out_m3],
+                             {x.name: x_values,
+                              m0.name: m0_val,
+                              m1.name: m1_val,
+                              m2.name: m2_val,
+                              m3.name: m3_val})
+
+      self.assertEqual(len(basic_res), len(block_res))
+      for basic, block in zip(basic_res, block_res):
+        self.assertAllClose(basic, block)
+
+  def testLSTMBasicToBlock(self):
+    with self.test_session(use_gpu=self._use_gpu) as sess:
+      batch_size = 2
+      input_size = 3
+      cell_size = 4
+      sequence_length = 5
+
+      inputs = []
+      for _ in range(sequence_length):
+        inp = tf.convert_to_tensor(
+            np.random.randn(batch_size, input_size),
+            dtype=tf.float32)
+        inputs.append(inp)
+
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212)
+      with tf.variable_scope("basic", initializer=initializer):
+        cell = tf.nn.rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True)
+        outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
+        sess.run([tf.initialize_all_variables()])
+        basic_outputs = sess.run(outputs)
+        basic_grads = sess.run(tf.gradients(outputs, inputs))
+        basic_wgrads = sess.run(tf.gradients(outputs, tf.trainable_variables()))
+
+      with tf.variable_scope("block", initializer=initializer):
+        w = tf.get_variable("w",
+                            shape=[input_size + cell_size, cell_size * 4],
+                            dtype=tf.float32)
+        b = tf.get_variable("b",
+                            shape=[cell_size * 4],
+                            dtype=tf.float32,
+                            initializer=tf.zeros_initializer)
+
+        _, _, _, _, _, _, outputs = fused_lstm(
+            tf.convert_to_tensor(sequence_length,
+                                 dtype=tf.int64),
+            inputs,
+            w,
+            b,
+            cell_clip=0)
+
+        sess.run([tf.initialize_all_variables()])
+        block_outputs = sess.run(outputs)
+        block_grads = sess.run(tf.gradients(outputs, inputs))
+        block_wgrads = sess.run(tf.gradients(outputs, [w, b]))
+
+      self.assertAllClose(basic_outputs, block_outputs)
+      self.assertAllClose(basic_grads, block_grads)
+      for basic, block in zip(basic_wgrads, block_wgrads):
+        self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2)
+
+  def testLSTMBasicToBlockPeeping(self):
+    with self.test_session(use_gpu=self._use_gpu) as sess:
+      batch_size = 2
+      input_size = 3
+      cell_size = 4
+      sequence_length = 5
+
+      inputs = []
+      for _ in range(sequence_length):
+        inp = tf.convert_to_tensor(
+            np.random.randn(batch_size, input_size),
+            dtype=tf.float32)
+        inputs.append(inp)
+
+      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=19890212)
+      with tf.variable_scope("basic", initializer=initializer):
+        cell = tf.nn.rnn_cell.LSTMCell(cell_size,
+                                       use_peepholes=True,
+                                       state_is_tuple=True)
+        outputs, _ = tf.nn.rnn(cell, inputs, dtype=tf.float32)
+
+        sess.run([tf.initialize_all_variables()])
+        basic_outputs = sess.run(outputs)
+        basic_grads = sess.run(tf.gradients(outputs, inputs))
+        basic_wgrads = sess.run(tf.gradients(outputs, tf.trainable_variables()))
+
+      with tf.variable_scope("block", initializer=initializer):
+        w = tf.get_variable("w",
+                            shape=[input_size + cell_size, cell_size * 4],
+                            dtype=tf.float32)
+        b = tf.get_variable("b",
+                            shape=[cell_size * 4],
+                            dtype=tf.float32,
+                            initializer=tf.zeros_initializer)
+
+        wci = tf.get_variable("wci", shape=[cell_size], dtype=tf.float32)
+        wcf = tf.get_variable("wcf", shape=[cell_size], dtype=tf.float32)
+        wco = tf.get_variable("wco", shape=[cell_size], dtype=tf.float32)
+
+        _, _, _, _, _, _, outputs = fused_lstm(
+            tf.convert_to_tensor(sequence_length,
+                                 dtype=tf.int64),
+            inputs,
+            w,
+            b,
+            wci=wci,
+            wcf=wcf,
+            wco=wco,
+            cell_clip=0,
+            use_peephole=True)
+
+        sess.run([tf.initialize_all_variables()])
+        block_outputs = sess.run(outputs)
+        block_grads = sess.run(tf.gradients(outputs, inputs))
+        block_wgrads = sess.run(tf.gradients(outputs, [w, b, wci, wcf, wco]))
+
+      self.assertAllClose(basic_outputs, block_outputs)
+      self.assertAllClose(basic_grads, block_grads)
+      for basic, block in zip(basic_wgrads, block_wgrads):
+        self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2)
+
+
+class LSTMFusedCellGpuTest(LSTMFusedCellTest):
+  _use_gpu = True
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
new file mode 100644
index 00000000000..2ecc415d351
--- /dev/null
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -0,0 +1,456 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""LSTM Fused Cell ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import load_library
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import resource_loader
+
+_lstm_ops_so = load_library.load_op_library(
+    resource_loader.get_path_to_datafile("_lstm_ops.so"))
+assert _lstm_ops_so, "Could not load _lstm_ops.so."
+
+
+# pylint: disable=invalid-name
+def _lstm_fused_cell(x,
+                     cs_prev,
+                     h_prev,
+                     w,
+                     b,
+                     wci=None,
+                     wcf=None,
+                     wco=None,
+                     forget_bias=None,
+                     cell_clip=None,
+                     use_peephole=None,
+                     name=None):
+  r"""Computes the LSTM cell forward propagation for 1 time step.
+
+  This implementation uses one weight matrix and one bias vector; the diagonal
+  peephole connections are only used when use_peephole is set.
+
+  This kernel op implements the following mathematical equations:
+
+  ```python
+  xh = [x, h_prev]
+  [i, f, ci, o] = xh * w + b
+  f = f + forget_bias
+
+  i = sigmoid(i)
+  f = sigmoid(f)
+  ci = tanh(ci)
+  o = sigmoid(o)
+
+  cs = ci .* i + cs_prev .* f
+  co = tanh(cs)
+
+  h = co .* o
+  ```
+
+  Args:
+    x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
+      The input to the LSTM cell.
+    cs_prev: A `Tensor`. Must have the same type as `x`. The previous cell
+      state.
+    h_prev: A `Tensor`. Must have the same type as `x`. The previous h state.
+    w: A `Tensor`. Must have the same type as `x`. The weight matrix.
+    b: A `Tensor`. Must have the same type as `x`. The bias vector.
+    wci: A `Tensor`. Must have the same type as `x`.
+    wcf: A `Tensor`. Must have the same type as `x`.
+    wco: A `Tensor`. Must have the same type as `x`.
+    forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
+    cell_clip: An optional `float`. Defaults to `3`.
+    use_peephole: An optional `bool`. Defaults to `False`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
+    i: A `Tensor`. Has the same type as `x`. The input gate.
+    cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
+    f: A `Tensor`. Has the same type as `x`. The forget gate.
+    o: A `Tensor`. Has the same type as `x`. The output gate.
+    ci: A `Tensor`. Has the same type as `x`. The cell input.
+    co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
+    h: A `Tensor`. Has the same type as `x`. The output h vector.
+
+  Raises:
+    ValueError: If cell_size is None.
+  """
+  if wci is None:
+    cell_size = cs_prev.get_shape().with_rank(2)[1].value
+    if cell_size is None:
+      raise ValueError("cell_size from `cs_prev` should not be None.")
+    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
+    wco = wci
+    wcf = wci
+
+  # pylint: disable=protected-access
+  return _lstm_ops_so.lstm_fused_cell(x=x,
+                                      cs_prev=cs_prev,
+                                      h_prev=h_prev,
+                                      w=w,
+                                      wci=wci,
+                                      wco=wco,
+                                      wcf=wcf,
+                                      b=b,
+                                      forget_bias=forget_bias,
+                                      cell_clip=cell_clip,
+                                      use_peephole=use_peephole,
+                                      name=name)
+  # pylint: enable=protected-access
+
+
+def _fused_lstm(seq_len_max,
+                x,
+                w,
+                b,
+                cs_prev=None,
+                h_prev=None,
+                wci=None,
+                wcf=None,
+                wco=None,
+                forget_bias=None,
+                cell_clip=None,
+                use_peephole=None,
+                name=None):
+  r"""TODO(williamchan): add doc.
+
+  Args:
+    seq_len_max: A `Tensor` of type `int64`.
+    x: A list of at least 1 `Tensor` objects of the same type in: `float32`.
+    w: A `Tensor`. Must have the same type as `x`.
+    b: A `Tensor`. Must have the same type as `x`.
+    cs_prev: A `Tensor`. Must have the same type as `x`.
+    h_prev: A `Tensor`. Must have the same type as `x`.
+    wci: A `Tensor`. Must have the same type as `x`.
+    wcf: A `Tensor`. Must have the same type as `x`.
+    wco: A `Tensor`. Must have the same type as `x`.
+    forget_bias: An optional `float`. Defaults to `1`.
+    cell_clip: An optional `float`. Defaults to `3`.
+    use_peephole: An optional `bool`. Defaults to `False`.
+    name: A name for the operation (optional).
+
+  Returns:
+    A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
+    i: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+    cs: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+    f: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+    o: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+    ci: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+    co: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+    h: A list with the same number of `Tensor` objects as `x` of `Tensor`
+    objects of the same type as x.
+
+  Raises:
+    ValueError: If `b` does not have a valid shape.
+  """
+  batch_size = x[0].get_shape().with_rank(2)[0].value
+  cell_size4 = b.get_shape().with_rank(1)[0].value
+  if cell_size4 is None:
+    raise ValueError("`b` shape must not be None.")
+  cell_size = cell_size4 // 4
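+  # Default the initial state and the peephole weights to zeros when the
+  # caller does not provide them.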
+  zero_state = None
+  if cs_prev is None or h_prev is None:
+    zero_state = array_ops.constant(0,
+                                    dtype=dtypes.float32,
+                                    shape=[batch_size, cell_size])
+  if cs_prev is None:
+    cs_prev = zero_state
+  if h_prev is None:
+    h_prev = zero_state
+  if wci is None:
+    wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size])
+    wco = wci
+    wcf = wci
+
+  # pylint: disable=protected-access
+  return _lstm_ops_so.fused_lstm(seq_len_max=seq_len_max,
+                                 x=x,
+                                 cs_prev=cs_prev,
+                                 h_prev=h_prev,
+                                 w=w,
+                                 wci=wci,
+                                 wco=wco,
+                                 wcf=wcf,
+                                 b=b,
+                                 forget_bias=forget_bias,
+                                 cell_clip=cell_clip,
+                                 name=name,
+                                 use_peephole=use_peephole)
+  # pylint: enable=protected-access
+  # pylint: enable=invalid-name
+
+
+_lstm_fused_cell_grad_outputs = ["cs_prev_grad", "dicfo"]
+
+
+@ops.RegisterShape("LSTMFusedCell")
+def _LSTMFusedCellShape(op):
+  batch_size = op.inputs[0].get_shape().with_rank(2)[0].value
+  cell_size = op.inputs[1].get_shape().with_rank(2)[1].value
+
+  return (tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size]))
+
+
+@ops.RegisterGradient("LSTMFusedCell")
+def _LSTMFusedCellGrad(op, *grad):
+  """Gradient for LSTMFusedCell."""
+  (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs
+  (i, cs, f, o, ci, co, _) = op.outputs
+  (_, cs_grad, _, _, _, _, h_grad) = grad
+
+  batch_size = x.get_shape().with_rank(2)[0].value
+  if batch_size is None:
+    batch_size = -1
+  input_size = x.get_shape().with_rank(2)[1].value
+  if input_size is None:
+    raise ValueError("input_size from `x` should not be None.")
+  cell_size = cs_prev.get_shape().with_rank(2)[1].value
+  if cell_size is None:
+    raise ValueError("cell_size from `cs_prev` should not be None.")
+
+  (cs_prev_grad, dicfo, wci_grad, wcf_grad,
+   wco_grad) = _lstm_ops_so.lstm_fused_cell_grad(
+       x,
+       cs_prev,
+       h_prev,
+       w,
+       wci,
+       wcf,
+       wco,
+       b,
+       i,
+       cs,
+       f,
+       o,
+       ci,
+       co,
+       cs_grad,
+       h_grad,
+       use_peephole=op.get_attr("use_peephole"))
+
+  # Backprop from dicfo to xh.
+  xh_grad = math_ops.matmul(dicfo, w, transpose_b=True)
+
+  x_grad = array_ops.slice(xh_grad, (0, 0), (batch_size, input_size))
+  x_grad.get_shape().merge_with(x.get_shape())
+
+  h_prev_grad = array_ops.slice(xh_grad, (0, input_size),
+                                (batch_size, cell_size))
+  h_prev_grad.get_shape().merge_with(h_prev.get_shape())
+
+  # Backprop from dicfo to w.
+  xh = array_ops.concat(1, [x, h_prev])
+  w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
+  w_grad.get_shape().merge_with(w.get_shape())
+
+  # Backprop from dicfo to b.
+  b_grad = nn_ops.bias_add_grad(dicfo)
+  b_grad.get_shape().merge_with(b.get_shape())
+
+  return (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
+          wco_grad, b_grad)
+
+
+@ops.RegisterShape("LSTMFusedCellGrad")
+def _LSTMFusedCellGradShape(op):
+  batch_size = op.inputs[0].get_shape().with_rank(2)[0].value
+  cell_size = op.inputs[1].get_shape().with_rank(2)[1].value
+
+  return [tensor_shape.TensorShape([batch_size, cell_size]),
+          tensor_shape.TensorShape([batch_size, cell_size * 4]),
+          tensor_shape.TensorShape([cell_size]),
+          tensor_shape.TensorShape([cell_size]),
+          tensor_shape.TensorShape([cell_size])]
+
+
+@ops.RegisterShape("FusedLSTM")
+def _FusedLSTMShape(op):
+  max_len = op.get_attr("max_len")
+
+  x = op.inputs[1]
+  b = op.inputs[-1]
+
+  batch_size = x.get_shape().with_rank(2)[0].value
+  cell_size = b.get_shape().with_rank(1)[0].value // 4
+
+  return [tensor_shape.TensorShape([batch_size, cell_size])] * max_len * 7
+
+
+@ops.RegisterGradient("FusedLSTM")
+def _FusedLSTMGrad(op, *grad):
+  """Gradient for FusedLSTM."""
+  max_len = op.get_attr("max_len")
+
+  seq_len_max = op.inputs[0]
+  x = op.inputs[1:1 + max_len]
+  cs_prev = op.inputs[-7]
+  h_prev = op.inputs[-6]
+  w = op.inputs[-5]
+  wci = op.inputs[-4]
+  wco = op.inputs[-3]
+  wcf = op.inputs[-2]
+  b = op.inputs[-1]
+
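+  # The op's outputs come in seven groups of max_len tensors each, in the
+  # order declared by the FusedLSTM registration.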
+  i = op.outputs[0 * max_len:1 * max_len]
+  cs = op.outputs[1 * max_len:2 * max_len]
+  f = op.outputs[2 * max_len:3 * max_len]
+  o = op.outputs[3 * max_len:4 * max_len]
+  ci = op.outputs[4 * max_len:5 * max_len]
+  co = op.outputs[5 * max_len:6 * max_len]
+  h = op.outputs[6 * max_len:7 * max_len]
+
+  cs_grad = grad[-max_len * 2:-max_len]
+  h_grad = grad[-max_len:]
+
+  (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad,
+   b_grad) = _lstm_ops_so.fused_lstm_grad(
+       seq_len_max,
+       x,
+       cs_prev,
+       h_prev,
+       w,
+       wci,
+       wco,
+       wcf,
+       b,
+       i,
+       cs,
+       f,
+       o,
+       ci,
+       co,
+       h,
+       cs_grad,
+       h_grad,
+       use_peephole=op.get_attr("use_peephole"))
+
+  return [None] + x_grad + [cs_prev_grad, h_prev_grad, w_grad, wci_grad,
+                            wco_grad, wcf_grad, b_grad]
+
+
+@ops.RegisterShape("FusedLSTMGrad")
+def _FusedLSTMGradShape(op):
+  """Shape for FusedLSTM."""
+  max_len = op.get_attr("max_len")
+
+  x = op.inputs[1]
+  cs_prev = op.inputs[1 + max_len]
+  h_prev = op.inputs[2 + max_len]
+  w = op.inputs[3 + max_len]
+  wci = op.inputs[4 + max_len]
+  wco = op.inputs[5 + max_len]
+  wcf = op.inputs[6 + max_len]
+  b = op.inputs[7 + max_len]
+
+  x_shape = x.get_shape().with_rank(2)
+  cs_prev_shape = cs_prev.get_shape().with_rank(2)
+  h_prev_shape = h_prev.get_shape().with_rank(2)
+  w_shape = w.get_shape().with_rank(2)
+  wci_shape = wci.get_shape().with_rank(1)
+  wco_shape = wco.get_shape().with_rank(1)
+  wcf_shape = wcf.get_shape().with_rank(1)
+  b_shape = b.get_shape().with_rank(1)
+
+  return [x_shape] * max_len + [cs_prev_shape, h_prev_shape, w_shape, wci_shape,
+                                wco_shape, wcf_shape, b_shape]
+
+
+class LSTMFusedCell(rnn_cell.RNNCell):
+  """Basic LSTM recurrent network cell.
+
+  The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+  We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting in the beginning of the training.
+
+  Unlike BasicLSTMCell, this is a monolithic op and should be much faster. The
+  weight and bias matrices should be compatible as long as the variable scope
+  matches.
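+
+  A minimal usage sketch (the shapes and variable names here are illustrative
+  assumptions, not part of the API):
+
+    cell = LSTMFusedCell(num_units=128, use_peephole=False)
+    # x_t: [batch_size, input_size]; cs_prev and h_prev: [batch_size, 128].
+    output, (cs, h) = cell(x_t, (cs_prev, h_prev))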
+  """
+
+  def __init__(self, num_units, forget_bias=1.0, use_peephole=False):
+    """Initialize the basic LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      forget_bias: float, The bias added to forget gates (see above).
+      use_peephole: Whether to use peephole connections or not.
+    """
+    self._num_units = num_units
+    self._forget_bias = forget_bias
+    self._use_peephole = use_peephole
+
+  @property
+  def state_size(self):
+    return (self._num_units,) * 2
+
+  @property
+  def output_size(self):
+    return self._num_units
+
+  def __call__(self, x, states_prev, scope=None):
+    """Long short-term memory cell (LSTM)."""
+    with vs.variable_scope(scope or type(self).__name__):
+      x_shape = x.get_shape().with_rank(2)
+      if not x_shape[1]:
+        raise ValueError("Expecting x_shape[1] to be sets: %s" % str(x_shape))
+      if len(states_prev) != 2:
+        raise ValueError("Expecting states_prev to be a tuple with length 2.")
+      input_size = x_shape[1]
+      w = vs.get_variable("W", [input_size + self._num_units,
+                                self._num_units * 4])
+      b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]],
+                          initializer=init_ops.constant_initializer(0.0))
+      wci = vs.get_variable("wci", [self._num_units])
+      wco = vs.get_variable("wco", [self._num_units])
+      wcf = vs.get_variable("wcf", [self._num_units])
+      (cs_prev, h_prev) = states_prev
+      (_, cs, _, _, _, _, h) = _lstm_fused_cell(x,
+                                                cs_prev,
+                                                h_prev,
+                                                w,
+                                                b,
+                                                wci=wci,
+                                                wco=wco,
+                                                wcf=wcf,
+                                                forget_bias=self._forget_bias,
+                                                use_peephole=self._use_peephole)
+
+      return (h, (cs, h))
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 7d00e73f90a..0ea41e10102 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -27,12 +27,6 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variable_scope as vs
-from tensorflow.python.ops.math_ops import reduce_sum
-from tensorflow.python.ops.math_ops import sigmoid
-from tensorflow.python.ops.math_ops import tanh
-from tensorflow.python.ops.nn_ops import conv2d
-from tensorflow.python.ops.nn_ops import softmax
-
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 
@@ -104,7 +98,7 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell):
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=1, num_proj_shards=1,
                forget_bias=1.0, state_is_tuple=False,
-               activation=tanh):
+               activation=math_ops.tanh):
     """Initialize the parameters for an LSTM cell.
 
     Args:
@@ -188,6 +182,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell):
       ValueError: If input size cannot be inferred from inputs via
         static shape inference.
     """
+    sigmoid = math_ops.sigmoid
+
     num_proj = self._num_units if self._num_proj is None else self._num_proj
 
     if self._state_is_tuple:
@@ -322,6 +318,8 @@ class TimeFreqLSTMCell(rnn_cell.RNNCell):
       ValueError: if an input_size was specified and the provided inputs have
         a different dimension.
     """
+    sigmoid = math_ops.sigmoid
+    tanh = math_ops.tanh
 
     freq_inputs = self._make_tf_features(inputs)
     dtype = inputs.dtype
@@ -489,6 +487,8 @@ class GridLSTMCell(rnn_cell.RNNCell):
       ValueError: if an input_size was specified and the provided inputs have
         a different dimension.
     """
+    sigmoid = math_ops.sigmoid
+    tanh = math_ops.tanh
 
     freq_inputs = self._make_tf_features(inputs)
     dtype = inputs.dtype
@@ -771,6 +771,11 @@ class AttentionCellWrapper(rnn_cell.RNNCell):
       return output, new_state
 
   def _attention(self, query, attn_states):
+    conv2d = nn_ops.conv2d
+    reduce_sum = math_ops.reduce_sum
+    softmax = nn_ops.softmax
+    tanh = math_ops.tanh
+
     with vs.variable_scope("Attention"):
       k = vs.get_variable("AttnW", [1, 1, self._attn_size, self._attn_vec_size])
       v = vs.get_variable("AttnV", [self._attn_vec_size])
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index da0e8ca4c95..abf4f0ee1f1 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -570,6 +570,7 @@ filegroup(
     name = "android_srcs",
     srcs = [
         ":proto_text_srcs_all",
+        "//tensorflow/core/debug:android_srcs",
         "//tensorflow/core/kernels:android_srcs",
         "//tensorflow/core/platform/default/build_config:android_srcs",
         "//tensorflow/core/util/ctc:android_srcs",
@@ -580,8 +581,6 @@ filegroup(
             "client/**/*.cc",
             "common_runtime/**/*.h",
             "common_runtime/**/*.cc",
-            "debug/**/*.h",
-            "debug/**/*.cc",
             "framework/**/*.h",
             "framework/**/*.cc",
             "graph/**/*.h",
@@ -1103,49 +1102,13 @@ tf_cuda_library(
     linkstatic = 1,
     deps = [
         ":core_cpu_internal",
-        ":debug_graph_utils",
         ":framework",
         ":gpu_tracer",
         ":lib",
         ":lib_internal",
         ":proto_text",
         ":protos_all_cc",
-    ],
-    alwayslink = 1,
-)
-
-tf_cuda_library(
-    name = "debug_gateway_internal",
-    srcs = ["debug/debug_gateway.cc"],
-    hdrs = ["debug/debug_gateway.h"],
-    copts = tf_copts(),
-    linkstatic = 1,
-    deps = [
-        ":core_cpu_internal",
-        ":direct_session_internal",
-        ":framework",
-        ":gpu_tracer",
-        ":lib",
-        ":lib_internal",
-        ":proto_text",
-        ":protos_all_cc",
-    ],
-    alwayslink = 1,
-)
-
-tf_cuda_library(
-    name = "debug_graph_utils",
-    srcs = ["debug/debug_graph_utils.cc"],
-    hdrs = ["debug/debug_graph_utils.h"],
-    copts = tf_copts(),
-    linkstatic = 1,
-    deps = [
-        ":core_cpu_internal",
-        ":framework",
-        ":lib",
-        ":lib_internal",
-        ":proto_text",
-        ":protos_all_cc",
+        "//tensorflow/core/debug:debug_graph_utils",
     ],
     alwayslink = 1,
 )
@@ -1604,35 +1567,6 @@ tf_cc_test(
     ],
 )
 
-tf_cc_test_gpu(
-    name = "debug/debug_gateway_test",
-    size = "small",
-    args = ["--heap_check=local"],
-    linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags() + ["nomac"],
-    deps = [
-        ":all_kernels",
-        ":core_cpu",
-        ":core_cpu_internal",
-        ":debug_gateway_internal",
-        ":debug_graph_utils",
-        ":direct_session",
-        ":direct_session_internal",
-        ":framework",
-        ":framework_internal",
-        ":gpu_runtime",
-        ":lib",
-        ":lib_internal",
-        ":protos_all_cc",
-        ":test",
-        ":test_main",
-        ":testlib",
-        "//tensorflow/cc:cc_ops",
-        "//tensorflow/core/kernels:debug_ops",
-        "//tensorflow/core/kernels:ops_util",
-    ],
-)
-
 tf_cc_test(
     name = "common_runtime/direct_session_with_tracking_alloc_test",
     size = "small",
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
new file mode 100644
index 00000000000..da4c45520e1
--- /dev/null
+++ b/tensorflow/core/debug/BUILD
@@ -0,0 +1,157 @@
+# Description:
+# TensorFlow Debugger (tfdbg).
+#
+# Public Android targets:
+# filegroup ":android_srcs" - Debugger source files for Android.
+
+package(
+    default_visibility = ["//tensorflow:internal"],
+    features = ["-parse_headers"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_copts",
+    "tf_cc_test",
+    "tf_cuda_library",
+)
+load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
+
+# For platform specific build config
+load(
+    "//tensorflow/core:platform/default/build_config.bzl",
+    "tf_kernel_tests_linkstatic",
+)
+load(
+    "//tensorflow/core:platform/default/build_config_root.bzl",
+    "tf_cuda_tests_tags",
+)
+
+tf_cuda_library(
+    name = "debug_gateway_internal",
+    srcs = ["debug_gateway.cc"],
+    hdrs = ["debug_gateway.h"],
+    copts = tf_copts(),
+    linkstatic = 1,
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:direct_session_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_tracer",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:proto_text",
+        "//tensorflow/core:protos_all_cc",
+    ],
+    alwayslink = 1,
+)
+
+tf_cuda_library(
+    name = "debug_graph_utils",
+    srcs = ["debug_graph_utils.cc"],
+    hdrs = ["debug_graph_utils.h"],
+    copts = tf_copts(),
+    linkstatic = 1,
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:proto_text",
+        "//tensorflow/core:protos_all_cc",
+    ],
+    alwayslink = 1,
+)
+
+tf_cuda_library(
+    name = "debug_io_utils",
+    srcs = ["debug_io_utils.cc"],
+    hdrs = ["debug_io_utils.h"],
+    copts = tf_copts(),
+    linkstatic = 1,
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:proto_text",
+        "//tensorflow/core:protos_all_cc",
+    ],
+    alwayslink = 1,
+)
+
+tf_cc_test_gpu(
+    name = "debug_gateway_test",
+    size = "small",
+    args = ["--heap_check=local"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_cuda_tests_tags() + ["nomac"],
+    deps = [
+        ":debug_gateway_internal",
+        ":debug_graph_utils",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:direct_session",
+        "//tensorflow/core:direct_session_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:gpu_runtime",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/kernels:debug_ops",
+        "//tensorflow/core/kernels:ops_util",
+    ],
+)
+
+tf_cc_test(
+    name = "debug_io_utils_test",
+    size = "small",
+    linkstatic = tf_kernel_tests_linkstatic(),
+    deps = [
+        ":debug_io_utils",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+filegroup(
+    name = "android_srcs",
+    srcs = [
+        "debug_graph_utils.cc",
+        "debug_graph_utils.h",
+    ],
+    visibility = ["//visibility:public"],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets.  These must be at the end for syncrepo.
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 8374b245cb6..118847686d3 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -36,6 +36,8 @@ Status DebugNodeInserter::InsertNodes(
   // A map from tensor name (e.g., "node_a:0") to list of debug op names
   // (e.g., {"DebugIdentity", "DebugNanCount"})
   std::unordered_map<string, std::vector<string>> tensor_watches;
+  // A map from tensor name to the list of debug URLs for that tensor.
+  std::unordered_map<string, std::vector<string>> tensor_watch_urls;
 
   // Cache the proto content for fast lookup later
   for (const DebugTensorWatch& watch : watches) {
@@ -58,6 +60,12 @@ Status DebugNodeInserter::InsertNodes(
     }
 
     tensor_watches[tensor_name] = debug_ops;
+
+    std::vector<string> urls;
+    for (const string& url : watch.debug_urls()) {
+      urls.push_back(url);
+    }
+    tensor_watch_urls[tensor_name] = urls;
   }
 
   if (tensor_watches.empty()) {
@@ -150,9 +158,9 @@ Status DebugNodeInserter::InsertNodes(
         const string& debug_op_name = tensor_watches[tensor_name][i];
 
         Node* debug_node;
-        Status debug_s =
-            CreateDebugNode(graph, device_type, copy_node->name(), src_dt,
-                            tensor_name, i, debug_op_name, &debug_node);
+        Status debug_s = CreateDebugNode(
+            graph, device_type, copy_node->name(), src_dt, tensor_name,
+            tensor_watch_urls[tensor_name], i, debug_op_name, &debug_node);
         if (!debug_s.ok()) {
           return Status(
               error::FAILED_PRECONDITION,
@@ -267,17 +275,17 @@ Status DebugNodeInserter::CreateCopyNode(
 Status DebugNodeInserter::CreateDebugNode(
     Graph* graph, const DeviceType device_type,
     const string& src_copy_node_name, const DataType src_dt,
-    const string& tensor_name, const int debug_op_num,
-    const string& debug_op_name, Node** debug_node) {
+    const string& tensor_name, const std::vector<string>& debug_urls,
+    const int debug_op_num, const string& debug_op_name, Node** debug_node) {
   NodeDef node_def;
   const KernelDef* kdef;
 
   const string debug_node_name =
       GetDebugNodeName(tensor_name, debug_op_num, debug_op_name);
-  // TODO(cais): Hook up with DebugTensorWatch proto
   auto builder = NodeDefBuilder(debug_node_name, debug_op_name)
                      .Input(src_copy_node_name, 0, src_dt)
-                     .Attr("tensor_name", tensor_name);
+                     .Attr("tensor_name", tensor_name)
+                     .Attr("debug_urls", debug_urls);
 
   if (!builder.Finalize(&node_def).ok()) {
     return Status(
diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h
index 41789a30ffe..ea61dee4d08 100644
--- a/tensorflow/core/debug/debug_graph_utils.h
+++ b/tensorflow/core/debug/debug_graph_utils.h
@@ -94,6 +94,7 @@ class DebugNodeInserter {
                                 const string& src_copy_node_name,
                                 const DataType src_dt,
                                 const string& tensor_name,
+                                const std::vector<string>& debug_urls,
                                 const int debug_op_num,
                                 const string& debug_op_name, Node** debug_node);
 };
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
new file mode 100644
index 00000000000..474577a79c0
--- /dev/null
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -0,0 +1,211 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debug_io_utils.h"
+
+#include <vector>
+
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Encapsulate the tensor value inside a Summary proto, and then inside an
+// Event proto.
+Event WrapTensorAsEvent(const string& tensor_name, const string& debug_op,
+                        const Tensor& tensor, const uint64 wall_time_us) {
+  Event event;
+  event.set_wall_time(static_cast<double>(wall_time_us));
+
+  Summary::Value* summ_val = event.mutable_summary()->add_value();
+
+  // Create the debug node_name in the Summary proto.
+  // For example, if tensor_name = "foo/node_a:0", and the debug_op is
+  // "DebugIdentity", the debug node_name in the Summary proto will be
+  // "foo/node_a:0:DebugIdentity".
+  const string debug_node_name = strings::StrCat(tensor_name, ":", debug_op);
+  summ_val->set_node_name(debug_node_name);
+
+  if (tensor.dtype() == DT_STRING) {
+    // Treat DT_STRING specially, so that tensor_util.MakeNdarray can convert
+    // the TensorProto to string-type numpy array. MakeNdarray does not work
+    // with strings encoded by AsProtoTensorContent() in tensor_content.
+    tensor.AsProtoField(summ_val->mutable_tensor());
+  } else {
+    tensor.AsProtoTensorContent(summ_val->mutable_tensor());
+  }
+
+  return event;
+}
+
+}  // namespace
+
+// static
+const char* const DebugIO::kFileURLScheme = "file://";
+// static
+const char* const DebugIO::kGrpcURLScheme = "grpc://";
+
+Status DebugIO::PublishDebugTensor(const string& tensor_name,
+                                   const string& debug_op, const Tensor& tensor,
+                                   const uint64 wall_time_us,
+                                   const gtl::ArraySlice<string>& debug_urls) {
+  // Split the tensor_name into node name and output slot index.
+  std::vector<string> name_items = str_util::Split(tensor_name, ':');
+  string node_name;
+  int32 output_slot = 0;
+  if (name_items.size() == 2) {
+    node_name = name_items[0];
+    if (!strings::safe_strto32(name_items[1], &output_slot)) {
+      return Status(error::INVALID_ARGUMENT,
+                    strings::StrCat("Invalid string value for output_slot: \"",
+                                    name_items[1], "\""));
+    }
+  } else if (name_items.size() == 1) {
+    node_name = name_items[0];
+  } else {
+    return Status(
+        error::INVALID_ARGUMENT,
+        strings::StrCat("Failed to parse tensor name: \"", tensor_name, "\""));
+  }
+
+  int num_failed_urls = 0;
+  for (const string& url : debug_urls) {
+    if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
+      const string dump_root_dir = url.substr(strlen(kFileURLScheme));
+
+      Status s =
+          DebugFileIO::DumpTensorToDir(node_name, output_slot, debug_op, tensor,
+                                       wall_time_us, dump_root_dir, nullptr);
+      if (!s.ok()) {
+        num_failed_urls++;
+      }
+    } else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
+      // TODO(cais): Implement PublishTensor with grpc urls.
+      return Status(error::UNIMPLEMENTED,
+                    strings::StrCat("Puslishing to GRPC debug target is not ",
+                                    "implemented yet"));
+    } else {
+      return Status(error::UNAVAILABLE,
+                    strings::StrCat("Invalid debug target URL: ", url));
+    }
+  }
+
+  if (num_failed_urls == 0) {
+    return Status::OK();
+  } else {
+    return Status(
+        error::INTERNAL,
+        strings::StrCat("Puslishing to ", num_failed_urls, " of ",
+                        debug_urls.size(), " debug target URLs failed"));
+  }
+}
+
+// static
+Status DebugFileIO::DumpTensorToDir(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us,
+    const string& dump_root_dir, string* dump_file_path) {
+  const string file_path = GetDumpFilePath(dump_root_dir, node_name,
+                                           output_slot, debug_op, wall_time_us);
+
+  if (dump_file_path != nullptr) {
+    *dump_file_path = file_path;
+  }
+
+  return DumpTensorToEventFile(node_name, output_slot, debug_op, tensor,
+                               wall_time_us, file_path);
+}
+
+// static
+string DebugFileIO::GetDumpFilePath(const string& dump_root_dir,
+                                    const string& node_name,
+                                    const int32 output_slot,
+                                    const string& debug_op,
+                                    const uint64 wall_time_us) {
+  return io::JoinPath(
+      dump_root_dir, strings::StrCat(node_name, "_", output_slot, "_", debug_op,
+                                     "_", wall_time_us));
+}
+
+// static
+Status DebugFileIO::DumpTensorToEventFile(
+    const string& node_name, const int32 output_slot, const string& debug_op,
+    const Tensor& tensor, const uint64 wall_time_us, const string& file_path) {
+  Env* env(Env::Default());
+
+  // Create the directory if necessary.
+  string file_dir = io::Dirname(file_path).ToString();
+  Status s = DebugFileIO::RecursiveCreateDir(env, file_dir);
+
+  if (!s.ok()) {
+    return Status(error::FAILED_PRECONDITION,
+                  strings::StrCat("Failed to create directory  ", file_dir,
+                                  ", due to: ", s.error_message()));
+  }
+
+  const string tensor_name = strings::StrCat(node_name, ":", output_slot);
+  Event event = WrapTensorAsEvent(tensor_name, debug_op, tensor, wall_time_us);
+
+  string event_str;
+  event.SerializeToString(&event_str);
+
+  std::unique_ptr<WritableFile> f = nullptr;
+  TF_CHECK_OK(env->NewWritableFile(file_path, &f));
+  f->Append(event_str);
+  TF_CHECK_OK(f->Close());
+
+  return Status::OK();
+}
+
+// static
+Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
+  if (env->FileExists(dir) && env->IsDirectory(dir).ok()) {
+    // The path already exists as a directory. Return OK right away.
+    return Status::OK();
+  }
+
+  string parent_dir = io::Dirname(dir).ToString();
+  if (!env->FileExists(parent_dir)) {
+    // The parent path does not exist yet, create it first.
+    Status s = RecursiveCreateDir(env, parent_dir);  // Recursive call
+    if (!s.ok()) {
+      return Status(
+          error::FAILED_PRECONDITION,
+          strings::StrCat("Failed to create directory ", parent_dir));
+    }
+  } else if (env->FileExists(parent_dir) &&
+             !env->IsDirectory(parent_dir).ok()) {
+    // The path exists, but it is a file.
+    return Status(error::FAILED_PRECONDITION,
+                  strings::StrCat("Failed to create directory  ", parent_dir,
+                                  " because the path exists as a file "));
+  }
+
+  env->CreateDir(dir);
+  // Guard against potential race in creating directories by doing a check
+  // after the CreateDir call.
+  if (env->FileExists(dir) && env->IsDirectory(dir).ok()) {
+    return Status::OK();
+  } else {
+    return Status(error::ABORTED,
+                  strings::StrCat("Failed to create directory  ", parent_dir));
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
new file mode 100644
index 00000000000..553ae9ab7d2
--- /dev/null
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -0,0 +1,107 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
+#define TENSORFLOW_DEBUG_IO_UTILS_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+class DebugIO {
+ public:
+  // Publish a tensor to a debug target URL.
+  //
+  // Args:
+  //   tensor_name: Name of the tensor being published: node_name followed by
+  //     a colon, followed by the output slot index. E.g., "node_a:0".
+  //   debug_op: Name of the debug op, e.g., "DebugIdentity".
+  //   tensor: The Tensor object being published.
+  //   wall_time_us: Time stamp for the Tensor. Unit: microseconds (us).
+  //   debug_urls: An array of debug target URLs, e.g.,
+  //     "file:///foo/tfdbg_dump", "grpc://localhot:11011"
+  static Status PublishDebugTensor(const string& tensor_name,
+                                   const string& debug_op, const Tensor& tensor,
+                                   const uint64 wall_time_us,
+                                   const gtl::ArraySlice<string>& debug_urls);
+
+ private:
+  static const char* const kFileURLScheme;
+  static const char* const kGrpcURLScheme;
+};
+
+// Helper class for debug ops.
+class DebugFileIO {
+ public:
+  // Encapsulate the Tensor in an Event protobuf and write it to a directory.
+  // The actual path of the dump file is a concatenation of dump_root_dir
+  // and tensor_name, along with the wall time.
+  //
+  // For example:
+  //   let dump_root_dir = "/tmp/tfdbg_dump",
+  //       node_name = "foo/bar",
+  //       output_slot = 0,
+  //       debug_op = DebugIdentity,
+  //       and wall_time_us = 1467891234512345,
+  // the dump file will be generated at path:
+  //   /tmp/tfdbg_dump/foo/bar_0_DebugIdentity_1467891234512345.
+  //
+  // Args:
+  //   node_name: Name of the node from which the tensor is output.
+  //   output_slot: Output slot index.
+  //   debug_op: Name of the debug op, e.g., "DebugIdentity".
+  //   tensor: The Tensor object to be dumped to file.
+  //   wall_time_us: Wall time at which the Tensor is generated during graph
+  //     execution. Unit: microseconds (us).
+  //   dump_root_dir: Root directory for dumping the tensor.
+  //   dump_file_path: Output for the actual dump file path (may be nullptr).
+  static Status DumpTensorToDir(const string& node_name,
+                                const int32 output_slot, const string& debug_op,
+                                const Tensor& tensor, const uint64 wall_time_us,
+                                const string& dump_root_dir,
+                                string* dump_file_path);
+
+  // Get the full path to the dump file.
+  //
+  // Args:
+  //   dump_root_dir: The dump root directory, e.g., /tmp/tfdbg_dump
+  //   node_name: Name of the node from which the dumped tensor is generated,
+  //     e.g., foo/bar/node_a
+  //   output_slot: Output slot index of the said node, e.g., 0.
+  //   debug_op: Name of the debug op, e.g., DebugIdentity.
+  //   wall_time_us: Time stamp of the dumped tensor, in microseconds (us).
+  static string GetDumpFilePath(const string& dump_root_dir,
+                                const string& node_name,
+                                const int32 output_slot, const string& debug_op,
+                                const uint64 wall_time_us);
+
+ private:
+  // Encapsulate the Tensor in an Event protobuf and write it to file.
+  static Status DumpTensorToEventFile(
+      const string& node_name, const int32 output_slot, const string& debug_op,
+      const Tensor& tensor, const uint64 wall_time_us, const string& file_path);
+
+  // Implemented ad hoc here for now.
+  // TODO(cais): Replace with shared implementation once http://b/30497715 is
+  // fixed.
+  static Status RecursiveCreateDir(Env* env, const string& dir);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_DEBUG_IO_UTILS_H_
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
new file mode 100644
index 00000000000..ecdda643c3a
--- /dev/null
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -0,0 +1,382 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debug_io_utils.h"
+
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/util/event.pb.h"
+
+namespace tensorflow {
+namespace {
+
+class DebugIOUtilsTest : public ::testing::Test {
+ public:
+  void Initialize() {
+    env_ = Env::Default();
+
+    tensor_a_.reset(new Tensor(DT_FLOAT, TensorShape({2, 2})));
+    tensor_a_->flat<float>()(0) = 5.0;
+    tensor_a_->flat<float>()(1) = 3.0;
+    tensor_a_->flat<float>()(2) = -1.0;
+    tensor_a_->flat<float>()(3) = 0.0;
+
+    tensor_b_.reset(new Tensor(DT_STRING, TensorShape{2}));
+    tensor_b_->flat<string>()(0) = "corge";
+    tensor_b_->flat<string>()(1) = "garply";
+  }
+
+  Status ReadEventFromFile(const string& dump_file_path, Event* event) {
+    string content;
+    uint64 file_size = 0;
+
+    Status s = env_->GetFileSize(dump_file_path, &file_size);
+    if (!s.ok()) {
+      return s;
+    }
+
+    content.resize(file_size);
+
+    std::unique_ptr<RandomAccessFile> file;
+    s = env_->NewRandomAccessFile(dump_file_path, &file);
+    if (!s.ok()) {
+      return s;
+    }
+
+    StringPiece result;
+    s = file->Read(0, file_size, &result, &(content)[0]);
+    if (!s.ok()) {
+      return s;
+    }
+
+    event->ParseFromString(content);
+    return Status::OK();
+  }
+
+  Env* env_;
+  std::unique_ptr<Tensor> tensor_a_;
+  std::unique_ptr<Tensor> tensor_b_;
+};
+
+TEST_F(DebugIOUtilsTest, DumpFloatTensorToFileSunnyDay) {
+  Initialize();
+
+  const string test_dir = testing::TmpDir();
+
+  // Append several levels of nonexistent directories, to test that the
+  // function can create them.
+  const string kNodeName = "foo/bar/qux/tensor_a";
+  const string kDebugOpName = "DebugIdentity";
+  const int32 output_slot = 0;
+  uint64 wall_time = env_->NowMicros();
+
+  string dump_file_path;
+  TF_ASSERT_OK(DebugFileIO::DumpTensorToDir(kNodeName, output_slot,
+                                            kDebugOpName, *tensor_a_, wall_time,
+                                            test_dir, &dump_file_path));
+
+  // Read the file into an Event proto.
+  Event event;
+  TF_ASSERT_OK(ReadEventFromFile(dump_file_path, &event));
+
+  ASSERT_GE(wall_time, event.wall_time());
+  ASSERT_EQ(1, event.summary().value().size());
+  ASSERT_EQ(strings::StrCat(kNodeName, ":", output_slot, ":", kDebugOpName),
+            event.summary().value(0).node_name());
+
+  Tensor a_prime(DT_FLOAT);
+  ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor()));
+
+  // Verify tensor shape and value.
+  ASSERT_EQ(tensor_a_->shape(), a_prime.shape());
+  for (int i = 0; i < a_prime.flat<float>().size(); ++i) {
+    ASSERT_EQ(tensor_a_->flat<float>()(i), a_prime.flat<float>()(i));
+  }
+
+  // Tear down temporary file and directories.
+  int64 undeleted_files = 0;
+  int64 undeleted_dirs = 0;
+  ASSERT_TRUE(
+      env_->DeleteRecursively(test_dir, &undeleted_files, &undeleted_dirs)
+          .ok());
+  ASSERT_EQ(0, undeleted_files);
+  ASSERT_EQ(0, undeleted_dirs);
+}
+
+TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) {
+  Initialize();
+
+  const string test_dir = testing::TmpDir();
+
+  const string kNodeName = "quux/grault/tensor_b";
+  const string kDebugOpName = "DebugIdentity";
+  const int32 output_slot = 1;
+  uint64 wall_time = env_->NowMicros();
+
+  string dump_file_name;
+  Status s = DebugFileIO::DumpTensorToDir(kNodeName, output_slot, kDebugOpName,
+                                          *tensor_b_, wall_time, test_dir,
+                                          &dump_file_name);
+  ASSERT_TRUE(s.ok());
+
+  // Read the file into an Event proto.
+  Event event;
+  TF_ASSERT_OK(ReadEventFromFile(dump_file_name, &event));
+
+  ASSERT_GE(wall_time, event.wall_time());
+  ASSERT_EQ(1, event.summary().value().size());
+  ASSERT_EQ(strings::StrCat(kNodeName, ":", output_slot, ":", kDebugOpName),
+            event.summary().value(0).node_name());
+
+  Tensor b_prime(DT_STRING);
+  ASSERT_TRUE(b_prime.FromProto(event.summary().value(0).tensor()));
+
+  // Verify tensor shape and value.
+  ASSERT_EQ(tensor_b_->shape(), b_prime.shape());
+  for (int i = 0; i < b_prime.flat<string>().size(); ++i) {
+    ASSERT_EQ(tensor_b_->flat<string>()(i), b_prime.flat<string>()(i));
+  }
+
+  // Tear down temporary file and directories.
+  int64 undeleted_files = 0;
+  int64 undeleted_dirs = 0;
+  ASSERT_TRUE(
+      env_->DeleteRecursively(test_dir, &undeleted_files, &undeleted_dirs)
+          .ok());
+  ASSERT_EQ(0, undeleted_files);
+  ASSERT_EQ(0, undeleted_dirs);
+}
+
+TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) {
+  Initialize();
+
+  // First, create the file at the path.
+  const string test_dir = testing::TmpDir();
+  const string txt_file_name = strings::StrCat(test_dir, "/baz");
+
+  if (!env_->FileExists(test_dir)) {
+    ASSERT_TRUE(env_->CreateDir(test_dir).ok());
+  }
+  ASSERT_FALSE(env_->FileExists(txt_file_name));
+
+  std::unique_ptr<WritableFile> file;
+  ASSERT_TRUE(env_->NewWritableFile(txt_file_name, &file).ok());
+  file->Append("text in baz");
+  file->Flush();
+  file->Close();
+
+  // Verify that the path exists and that it is a file, not a directory.
+  ASSERT_TRUE(env_->FileExists(txt_file_name));
+  ASSERT_FALSE(env_->IsDirectory(txt_file_name).ok());
+
+  // Second, try to dump the tensor to a path that requires "baz" to be a
+  // directory, which should lead to an error.
+  const string kNodeName = "baz/tensor_a";
+  const string kDebugOpName = "DebugIdentity";
+  const int32 output_slot = 0;
+  uint64 wall_time = env_->NowMicros();
+
+  string dump_file_name;
+  Status s = DebugFileIO::DumpTensorToDir(kNodeName, output_slot, kDebugOpName,
+                                          *tensor_a_, wall_time, test_dir,
+                                          &dump_file_name);
+  ASSERT_FALSE(s.ok());
+
+  // Tear down temporary file and directories.
+  int64 undeleted_files = 0;
+  int64 undeleted_dirs = 0;
+  ASSERT_TRUE(
+      env_->DeleteRecursively(test_dir, &undeleted_files, &undeleted_dirs)
+          .ok());
+  ASSERT_EQ(0, undeleted_files);
+  ASSERT_EQ(0, undeleted_dirs);
+}
+
+TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) {
+  Initialize();
+
+  const int kNumDumpRoots = 3;
+  const string kNodeName = "foo/bar/qux/tensor_a";
+  const string kDebugOpName = "DebugIdentity";
+  const int32 output_slot = 0;
+
+  uint64 wall_time = env_->NowMicros();
+
+  std::vector<string> dump_roots;
+  std::vector<string> dump_file_paths;
+  std::vector<string> urls;
+  for (int i = 0; i < kNumDumpRoots; ++i) {
+    string dump_root = strings::StrCat(testing::TmpDir(), "/", i);
+
+    dump_roots.push_back(dump_root);
+    dump_file_paths.push_back(DebugFileIO::GetDumpFilePath(
+        dump_root, kNodeName, output_slot, kDebugOpName, wall_time));
+    urls.push_back(strings::StrCat("file://", dump_root));
+  }
+
+  for (int i = 1; i < kNumDumpRoots; ++i) {
+    ASSERT_NE(dump_roots[0], dump_roots[i]);
+  }
+
+  const string tensor_name = strings::StrCat(kNodeName, ":", output_slot);
+  const string debug_node_name =
+      strings::StrCat(tensor_name, ":", kDebugOpName);
+  Status s = DebugIO::PublishDebugTensor(tensor_name, kDebugOpName, *tensor_a_,
+                                         wall_time, urls);
+  ASSERT_TRUE(s.ok());
+
+  // Read each dump file back into an Event proto and verify its contents.
+  for (int i = 0; i < kNumDumpRoots; ++i) {
+    // Read the file into an Event proto.
+    Event event;
+    TF_ASSERT_OK(ReadEventFromFile(dump_file_paths[i], &event));
+
+    ASSERT_GE(wall_time, event.wall_time());
+    ASSERT_EQ(1, event.summary().value().size());
+    ASSERT_EQ(debug_node_name, event.summary().value(0).node_name());
+
+    Tensor a_prime(DT_FLOAT);
+    ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor()));
+
+    // Verify tensor shape and value.
+    ASSERT_EQ(tensor_a_->shape(), a_prime.shape());
+    for (int j = 0; j < a_prime.flat<float>().size(); ++j) {
+      ASSERT_EQ(tensor_a_->flat<float>()(j), a_prime.flat<float>()(j));
+    }
+  }
+
+  // Tear down temporary file and directories.
+  for (int i = 0; i < kNumDumpRoots; ++i) {
+    int64 undeleted_files = 0;
+    int64 undeleted_dirs = 0;
+    ASSERT_TRUE(env_->DeleteRecursively(dump_roots[i], &undeleted_files,
+                                        &undeleted_dirs)
+                    .ok());
+    ASSERT_EQ(0, undeleted_files);
+    ASSERT_EQ(0, undeleted_dirs);
+  }
+}
+
+TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
+  Initialize();
+
+  const int kConcurrentPubs = 3;
+  const string kNodeName = "tensor_a";
+  const string kDebugOpName = "DebugIdentity";
+  const int32 kOutputSlot = 0;
+
+  thread::ThreadPool* tp =
+      new thread::ThreadPool(Env::Default(), "test", kConcurrentPubs);
+  uint64 wall_time = env_->NowMicros();
+
+  const string dump_root_base = testing::TmpDir();
+  const string tensor_name = strings::StrCat(kNodeName, ":", kOutputSlot);
+  const string debug_node_name =
+      strings::StrCat(tensor_name, ":", kDebugOpName);
+
+  mutex mu;
+  std::vector<string> dump_roots GUARDED_BY(mu);
+  std::vector<string> dump_file_paths GUARDED_BY(mu);
+
+  int dump_count GUARDED_BY(mu) = 0;
+  int done_count GUARDED_BY(mu) = 0;
+  Notification all_done;
+
+  auto fn = [this, &dump_count, &done_count, &mu, &dump_root_base, &dump_roots,
+             &dump_file_paths, &wall_time, &tensor_name, &debug_node_name,
+             &kNodeName, &kDebugOpName, &kConcurrentPubs, &all_done]() {
+    // "gumpy" is the shared directory part of the path.
+    string dump_root;
+    string debug_url;
+    {
+      mutex_lock l(mu);
+      dump_root =
+          strings::StrCat(dump_root_base, "/grumpy/", "dump_", dump_count++);
+
+      dump_roots.push_back(dump_root);
+      dump_file_paths.push_back(DebugFileIO::GetDumpFilePath(
+          dump_root, kNodeName, kOutputSlot, kDebugOpName, wall_time));
+
+      debug_url = strings::StrCat("file://", dump_root);
+    }
+
+    std::vector<string> urls;
+    urls.push_back(debug_url);
+    Status s = DebugIO::PublishDebugTensor(tensor_name, kDebugOpName,
+                                           *tensor_a_, wall_time, urls);
+    ASSERT_TRUE(s.ok());
+
+    {
+      mutex_lock l(mu);
+
+      done_count++;
+      if (done_count == kConcurrentPubs) {
+        all_done.Notify();
+      }
+    }
+  };
+
+  for (int i = 0; i < kConcurrentPubs; ++i) {
+    tp->Schedule(fn);
+  }
+
+  // Wait for all dumping calls to finish.
+  all_done.WaitForNotification();
+  delete tp;
+
+  {
+    mutex_lock l(mu);
+
+    for (int i = 1; i < kConcurrentPubs; ++i) {
+      ASSERT_NE(dump_roots[0], dump_roots[i]);
+    }
+
+    // Read each dump file back into an Event proto and verify its contents.
+    for (int i = 0; i < kConcurrentPubs; ++i) {
+      // Read the file into an Event proto.
+      Event event;
+      TF_ASSERT_OK(ReadEventFromFile(dump_file_paths[i], &event));
+
+      ASSERT_GE(wall_time, event.wall_time());
+      ASSERT_EQ(1, event.summary().value().size());
+      ASSERT_EQ(debug_node_name, event.summary().value(0).node_name());
+
+      Tensor a_prime(DT_FLOAT);
+      ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor()));
+
+      // Verify tensor shape and value.
+      ASSERT_EQ(tensor_a_->shape(), a_prime.shape());
+      for (int j = 0; j < a_prime.flat<float>().size(); ++j) {
+        ASSERT_EQ(tensor_a_->flat<float>()(j), a_prime.flat<float>()(j));
+      }
+    }
+
+    // Tear down temporary file and directories.
+    int64 undeleted_files = 0;
+    int64 undeleted_dirs = 0;
+    ASSERT_TRUE(env_->DeleteRecursively(dump_root_base, &undeleted_files,
+                                        &undeleted_dirs)
+                    .ok());
+    ASSERT_EQ(0, undeleted_files);
+    ASSERT_EQ(0, undeleted_dirs);
+  }
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 8d9066ab52c..a078488dd18 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -423,6 +423,7 @@ tf_kernel_libraries(
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:proto_text",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/debug:debug_io_utils",
         "//third_party/eigen3",
     ],
 )
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 3e46970812f..8132cf1f6b0 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_KERNELS_DEBUG_OP_H_
 
 #include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/debug/debug_io_utils.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_util.h"
@@ -73,10 +74,16 @@ class DebugIdentityOp : public OpKernel {
  public:
   explicit DebugIdentityOp(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name_));
-    // TODO(cais): Add debug_url
+    OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls_));
   }
 
   void Compute(OpKernelContext* context) override {
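+    // Publish the input tensor to every registered debug URL before
+    // forwarding it unchanged. Note that the Status returned by
+    // PublishDebugTensor is currently discarded, so a failed dump does not
+    // fail the op itself.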
+    if (!debug_urls_.empty()) {
+      DebugIO::PublishDebugTensor(tensor_name_, "DebugIdentity",
+                                  context->input(0),
+                                  Env::Default()->NowMicros(), debug_urls_);
+    }
+
     context->set_output(0, context->input(0));
   }
 
@@ -84,6 +91,7 @@ class DebugIdentityOp : public OpKernel {
 
  private:
   string tensor_name_;
+  std::vector<string> debug_urls_;
 };
 
 // NaN-counter op for debugging.
@@ -92,6 +100,7 @@ class DebugNanCountOp : public OpKernel {
  public:
   explicit DebugNanCountOp(OpKernelConstruction* context) : OpKernel(context) {
     OP_REQUIRES_OK(context, context->GetAttr("tensor_name", &tensor_name_));
+    OP_REQUIRES_OK(context, context->GetAttr("debug_urls", &debug_urls_));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -120,6 +129,7 @@ class DebugNanCountOp : public OpKernel {
 
  private:
   string tensor_name_;
+  std::vector<string> debug_urls_;
 };
 
 // TODO(cais): Add DebugInfinityCount
diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc
index e584d43e22d..e526754d316 100644
--- a/tensorflow/core/kernels/debug_ops_test.cc
+++ b/tensorflow/core/kernels/debug_ops_test.cc
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <dirent.h>
+#include <string.h>
+#include <fstream>
+#include <vector>
+
 #include "tensorflow/core/framework/fake_input.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def_builder.h"
@@ -22,20 +27,32 @@ limitations under the License.
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/event.pb.h"
 
 namespace tensorflow {
 namespace {
 
 class DebugIdentityOpTest : public OpsTestBase {
  protected:
-  Status Init(DataType input_type) {
+  Status Init(DataType input_type, const std::vector<string>& debug_urls) {
+    env_ = Env::Default();
+
     TF_CHECK_OK(NodeDefBuilder("op", "DebugIdentity")
                     .Input(FakeInput(input_type))
                     .Attr("tensor_name", "FakeTensor:0")
+                    .Attr("debug_urls", debug_urls)
                     .Finalize(node_def()));
     return InitOp();
   }
+
+  Status Init(DataType input_type) {
+    std::vector<string> empty_debug_urls;
+    return Init(input_type, empty_debug_urls);
+  }
+
+  Env* env_;
 };
 
 TEST_F(DebugIdentityOpTest, Int32Success_6) {
@@ -48,6 +65,80 @@ TEST_F(DebugIdentityOpTest, Int32Success_6) {
   test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
 }
 
+TEST_F(DebugIdentityOpTest, Int32Success_6_FileURLs) {
+  const int kNumDumpDirs = 3;
+
+  const string tmp_dir = testing::TmpDir();
+
+  std::vector<string> dump_roots;
+  std::vector<string> debug_urls;
+  for (int i = 0; i < kNumDumpDirs; ++i) {
+    const string dump_root = strings::StrCat(tmp_dir, "_", i);
+    dump_roots.push_back(dump_root);
+
+    debug_urls.push_back(strings::StrCat("file://", dump_root));
+  }
+
+  uint64 wall_time = Env::Default()->NowMicros();
+
+  TF_ASSERT_OK(Init(DT_INT32, debug_urls));
+  AddInputFromArray<int32>(TensorShape({6}), {1, 2, 3, 4, 5, 6});
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_INT32, TensorShape({6}));
+  test::FillValues<int32>(&expected, {1, 2, 3, 4, 5, 6});
+  // Verify the identity output
+  test::ExpectTensorEqual<int32>(expected, *GetOutput(0));
+
+  for (int i = 0; i < kNumDumpDirs; ++i) {
+    ASSERT_TRUE(env_->FileExists(dump_roots[i]));
+    ASSERT_TRUE(env_->IsDirectory(dump_roots[i]).ok());
+
+    DIR* dir = opendir(dump_roots[i].c_str());
+    struct dirent* ent;
+    int dump_files_found = 0;
+    while ((ent = readdir(dir)) != NULL) {
+      if (strcmp(ent->d_name, ".") && strcmp(ent->d_name, "..")) {
+        dump_files_found++;
+
+        // Try reading the file into an Event proto.
+        const string dump_file_path =
+            strings::StrCat(dump_roots[i], "/", ent->d_name);
+        std::fstream ifs(dump_file_path, std::ios::in | std::ios::binary);
+        Event event;
+        event.ParseFromIstream(&ifs);
+        ifs.close();
+
+        ASSERT_GE(event.wall_time(), wall_time);
+        ASSERT_EQ(1, event.summary().value().size());
+        ASSERT_EQ(strings::StrCat("FakeTensor", ":", 0, ":", "DebugIdentity"),
+                  event.summary().value(0).node_name());
+
+        Tensor tensor_prime(DT_INT32);
+        ASSERT_TRUE(tensor_prime.FromProto(event.summary().value(0).tensor()));
+
+        // Verify tensor shape and value from the dump file.
+        ASSERT_EQ(TensorShape({6}), tensor_prime.shape());
+
+        for (int j = 0; j < 6; ++j) {
+          ASSERT_EQ(j + 1, tensor_prime.flat<int32>()(j));
+        }
+      }
+    }
+    closedir(dir);
+
+    ASSERT_EQ(1, dump_files_found);
+
+    // Remove temporary dump directory and file.
+    int64 undeleted_files = 0;
+    int64 undeleted_dirs = 0;
+    ASSERT_TRUE(env_->DeleteRecursively(dump_roots[i], &undeleted_files,
+                                        &undeleted_dirs)
+                    .ok());
+    ASSERT_EQ(0, undeleted_files);
+    ASSERT_EQ(0, undeleted_dirs);
+  }
+}
+
 TEST_F(DebugIdentityOpTest, Int32Success_2_3) {
   TF_ASSERT_OK(Init(DT_INT32));
   AddInputFromArray<int32>(TensorShape({2, 3}), {1, 2, 3, 4, 5, 6});
@@ -66,8 +157,6 @@ TEST_F(DebugIdentityOpTest, StringSuccess) {
   test::ExpectTensorEqual<string>(expected, *GetOutput(0));
 }
 
-TEST_F(DebugIdentityOpTest, RefInputError) { TF_ASSERT_OK(Init(DT_INT32_REF)); }
-
 // Tests for DebugNanCountOp
 class DebugNanCountOpTest : public OpsTestBase {
  protected:
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index 0861fa99821..63ad0059d45 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -253,6 +253,8 @@ class ExpandDimsOp : public OpKernel {
                            " and output shape ", output_shape.DebugString()));
     }
   }
+
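+  // ExpandDims only manipulates shape metadata, so it is cheap enough for
+  // the executor to run inline instead of scheduling it on a separate thread.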
+  bool IsExpensive() override { return false; }
 };
 REGISTER_KERNEL_BUILDER(Name("ExpandDims").Device(DEVICE_CPU).HostMemory("dim"),
                         ExpandDimsOp);
@@ -342,6 +344,8 @@ class SqueezeOp : public OpKernel {
     }
   }
 
+  bool IsExpensive() override { return false; }
+
  private:
   std::unordered_set<int32> squeeze_dims_;
 };
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fe3d7406961..5ba4e0cce69 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -2985,6 +2985,7 @@ REGISTER_OP("DebugIdentity")
     .Output("output: T")
     .Attr("T: type")
     .Attr("tensor_name: string = ''")
+    .Attr("debug_urls: list(string) = []")
     .Doc(R"doc(
 Debug Identity Op.
 
@@ -2993,6 +2994,8 @@ Provides an identity mapping of the non-Ref type input tensor for debugging.
 input: Input tensor, non-Reference type.
 output: Output tensor that equals the input tensor.
 tensor_name: Name of the input tensor.
+debug_urls: List of URLs to debug targets, e.g.,
+            file:///foo/tfdbg_dump, grpc://localhost:11011
 )doc");
 
 REGISTER_OP("DebugNanCount")
@@ -3000,6 +3003,7 @@ REGISTER_OP("DebugNanCount")
     .Output("output: int64")  // The debug signal (nan count) is int64
     .Attr("T: type")
     .Attr("tensor_name: string = ''")
+    .Attr("debug_urls: list(string) = []")
     .Doc(R"doc(
 Debug NaN Value Counter Op
 
@@ -3008,6 +3012,8 @@ Counts number of NaNs in the input tensor, for debugging.
 input: Input tensor, non-Reference type.
 output: An integer output tensor that is the number of NaNs in the input.
 tensor_name: Name of the input tensor.
+debug_urls: List of URLs to debug targets, e.g.,
+            file:///foo/tfdbg_dump, grpc://localhost:11011
 )doc");
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index 8b3230c3fed..6c7556076a9 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -7876,6 +7876,36 @@ op {
     }
   }
 }
+op {
+  name: "DebugIdentity"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
 op {
   name: "DebugNanCount"
   input_arg {
@@ -7898,6 +7928,36 @@ op {
     }
   }
 }
+op {
+  name: "DebugNanCount"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "output"
+    type: DT_INT64
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "tensor_name"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+  }
+}
 op {
   name: "DecodeCSV"
   input_arg {
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 5da2754fd47..01bb4bc82f8 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4278,6 +4278,15 @@ op {
     }
     description: "Name of the input tensor."
   }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+    description: "List of URLs to debug targets, e.g.,\nfile:///foo/tfdbg_dump, grpc:://localhost:11011"
+  }
   summary: "Debug Identity Op."
   description: "Provides an identity mapping of the non-Ref type input tensor for debugging."
 }
@@ -4305,6 +4314,15 @@ op {
     }
     description: "Name of the input tensor."
   }
+  attr {
+    name: "debug_urls"
+    type: "list(string)"
+    default_value {
+      list {
+      }
+    }
+    description: "List of URLs to debug targets, e.g.,\nfile:///foo/tfdbg_dump, grpc:://localhost:11011"
+  }
   summary: "Debug NaN Value Counter Op"
   description: "Counts number of NaNs in the input tensor, for debugging."
 }
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 1ad9f7175fc..ac213385054 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -662,13 +662,19 @@ keep_dims: If true, retain reduced dimensions with length 1.
 output: `R-K`-D.  The reduced Tensor.
 )doc");
 
-#define SPARSE_DENSE_CWISE_SIGNATURE() \
-  Input("sp_indices: int64")           \
-      .Input("sp_values: T")           \
-      .Input("sp_shape: int64")        \
-      .Input("dense: T")               \
-      .Output("output: T")             \
-      .Attr("T: numbertype")
+#define SPARSE_DENSE_CWISE_SIGNATURE()                           \
+  Input("sp_indices: int64")                                     \
+      .Input("sp_values: T")                                     \
+      .Input("sp_shape: int64")                                  \
+      .Input("dense: T")                                         \
+      .Output("output: T")                                       \
+      .Attr("T: numbertype")                                     \
+      .SetShapeFn([](InferenceContext* c) {                      \
+        const Shape* input;                                      \
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
+        c->set_output(0, c->Vector(c->Dim(input, 0)));           \
+        return Status::OK();                                     \
+      })
 
 REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE().Doc(R"doc(
 Component-wise multiplies a SparseTensor by a dense Tensor.
@@ -722,6 +728,8 @@ dense: `R`-D.  The dense Tensor operand.
 output: 1-D.  The `N` values that are operated on.
 )doc");
 
+#undef SPARSE_DENSE_CWISE_SIGNATURE
+
 REGISTER_OP("SparseSoftmax")
     .Input("sp_indices: int64")
     .Input("sp_values: T")
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 33f38019d01..6f78f8cd8a9 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -180,7 +180,7 @@ message ConfigProto {
   int64 operation_timeout_in_ms = 11;
 };
 
-// EXPERIMENTAL. Option for watching a node
+// EXPERIMENTAL. Option for watching a node.
 message DebugTensorWatch {
   // Name of the node to watch.
   string node_name = 1;
@@ -196,6 +196,12 @@ message DebugTensorWatch {
   // One or more than one probes on a tensor.
   // e.g., {"DebugIdentity", "DebugNanCount"}
   repeated string debug_ops = 3;
+
+  // URL(s) for debug target(s).
+  //   E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011"
+  // Each debug op listed in debug_ops will publish its output tensor (debug
+  // signal) to all URLs in debug_urls.
+  repeated string debug_urls = 4;
 }
 
 // EXPERIMENTAL. Options for a single Run() call.
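The new `debug_urls` field is exercised by the Python test added later in this change; as a minimal sketch, a watch can be configured on `RunOptions` like so (node name and dump path are placeholders):

```python
from tensorflow.core.protobuf import config_pb2

run_options = config_pb2.RunOptions()
watch = run_options.debug_tensor_watch_opts.add()
watch.node_name = "my_node"  # placeholder: name of the node to watch
watch.output_slot = 0
watch.debug_ops.append("DebugIdentity")
watch.debug_urls.append("file:///tmp/tfdbg_dump")  # placeholder dump path
```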
diff --git a/tensorflow/g3doc/api_docs/python/contrib.losses.md b/tensorflow/g3doc/api_docs/python/contrib.losses.md
index 846718e196c..26d297b38f3 100644
--- a/tensorflow/g3doc/api_docs/python/contrib.losses.md
+++ b/tensorflow/g3doc/api_docs/python/contrib.losses.md
@@ -140,6 +140,31 @@ Notice that the function adds the given losses to the regularization losses.
 *  <b>`ValueError`</b>: if `losses` is not iterable.
 
 
+- - -
+
+### `tf.contrib.losses.hinge_loss(logits, target, scope=None)` {#hinge_loss}
+
+Method that returns the loss tensor for hinge loss.
+
+##### Args:
+
+
+*  <b>`logits`</b>: The logits, a float tensor.
+*  <b>`target`</b>: The ground truth output tensor. Its shape should match the shape of
+    logits. The values of the tensor are expected to be 0.0 or 1.0.
+*  <b>`scope`</b>: The scope for the operations performed in computing the loss.
+
+##### Returns:
+
+  A `Tensor` of the same shape as `logits` and `target`, representing the
+    loss values across the batch.
+
+##### Raises:
+
+
+*  <b>`ValueError`</b>: If the shapes of `logits` and `target` don't match.
+
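+A sketch of the math (assuming the standard formulation for {0, 1} targets,
+`max(0, 1 - (2 * target - 1) * logits)`):
+
+```python
+import numpy as np
+
+logits = np.array([0.5, -2.0, 3.0])
+target = np.array([1.0, 0.0, 1.0])
+y = 2.0 * target - 1.0                    # map {0, 1} -> {-1, +1}
+print(np.maximum(0.0, 1.0 - y * logits))  # values: [0.5, 0.0, 0.0]
+```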
+
 - - -
 
 ### `tf.contrib.losses.log_loss(predictions, targets, weight=1.0, epsilon=1e-07, scope=None)` {#log_loss}
diff --git a/tensorflow/g3doc/api_docs/python/contrib.rnn.md b/tensorflow/g3doc/api_docs/python/contrib.rnn.md
new file mode 100644
index 00000000000..201e23c66d3
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/contrib.rnn.md
@@ -0,0 +1,409 @@
+<!-- This file is machine generated: DO NOT EDIT! -->
+
+# RNN (contrib)
+[TOC]
+
+Additional RNN operations and cells.
+
+## This package provides additional contributed RNNCells.
+
+### Fused RNNCells
+- - -
+
+### `class tf.contrib.rnn.LSTMFusedCell` {#LSTMFusedCell}
+
+Basic LSTM recurrent network cell.
+
+The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+We add forget_bias (default: 1) to the biases of the forget gate in order to
+reduce the scale of forgetting in the beginning of the training.
+
+Unlike BasicLSTMCell, this is a monolithic op and should be much faster. The
+weight and bias matrices should be compatible as long as the variable scope
+matches.
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.__init__(num_units, forget_bias=1.0, use_peephole=False)` {#LSTMFusedCell.__init__}
+
+Initialize the basic LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell.
+*  <b>`forget_bias`</b>: float, The bias added to forget gates (see above).
+*  <b>`use_peephole`</b>: Whether to use peephole connections.
+
+
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.output_size` {#LSTMFusedCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.state_size` {#LSTMFusedCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.zero_state(batch_size, dtype)` {#LSTMFusedCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
+
+
+### LSTM-like cells
+- - -
+
+### `class tf.contrib.rnn.CoupledInputForgetGateLSTMCell` {#CoupledInputForgetGateLSTMCell}
+
+Long short-term memory unit (LSTM) recurrent network cell.
+
+The default non-peephole implementation is based on:
+
+  http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+
+S. Hochreiter and J. Schmidhuber.
+"Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+The peephole implementation is based on:
+
+  https://research.google.com/pubs/archive/43905.pdf
+
+Hasim Sak, Andrew Senior, and Francoise Beaufays.
+"Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+The coupling of input and forget gate is based on:
+
+  http://arxiv.org/pdf/1503.04069.pdf
+
+Greff et al. "LSTM: A Search Space Odyssey"
+
+The class uses optional peep-hole connections, and an optional projection
+layer.
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.__init__(num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=False, activation=tanh)` {#CoupledInputForgetGateLSTMCell.__init__}
+
+Initialize the parameters for an LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell
+*  <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections.
+*  <b>`initializer`</b>: (optional) The initializer to use for the weight and
+    projection matrices.
+*  <b>`num_proj`</b>: (optional) int, The output dimensionality for the projection
+    matrices.  If None, no projection is performed.
+*  <b>`proj_clip`</b>: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
+    provided, then the projected values are clipped elementwise to within
+    `[-proj_clip, proj_clip]`.
+*  <b>`num_unit_shards`</b>: How to split the weight matrix.  If >1, the weight
+    matrix is stored across num_unit_shards.
+*  <b>`num_proj_shards`</b>: How to split the projection matrix.  If >1, the
+    projection matrix is stored across num_proj_shards.
+*  <b>`forget_bias`</b>: Biases of the forget gate are initialized by default to 1
+    in order to reduce the scale of forgetting at the beginning of
+    the training.
+*  <b>`state_is_tuple`</b>: If True, accepted and returned states are 2-tuples of
+    the `c_state` and `m_state`.  By default (False), they are concatenated
+    along the column axis.  This default behavior will soon be deprecated.
+*  <b>`activation`</b>: Activation function of the inner states.
+
+
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.output_size` {#CoupledInputForgetGateLSTMCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.state_size` {#CoupledInputForgetGateLSTMCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.zero_state(batch_size, dtype)` {#CoupledInputForgetGateLSTMCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
+
+- - -
+
+### `class tf.contrib.rnn.TimeFreqLSTMCell` {#TimeFreqLSTMCell}
+
+Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
+
+This implementation is based on:
+
+  Tara N. Sainath and Bo Li
+  "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
+  for LVCSR Tasks." submitted to INTERSPEECH, 2016.
+
+It uses peep-hole connections and optional cell clipping.
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.__init__(num_units, use_peepholes=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#TimeFreqLSTMCell.__init__}
+
+Initialize the parameters for an LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell
+*  <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections.
+*  <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped
+    by this value prior to the cell output activation.
+*  <b>`initializer`</b>: (optional) The initializer to use for the weight and
+    projection matrices.
+*  <b>`num_unit_shards`</b>: int, How to split the weight matrix.  If >1, the weight
+    matrix is stored across num_unit_shards.
+*  <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default
+    to 1 in order to reduce the scale of forgetting at the beginning
+    of the training.
+*  <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over.
+*  <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in
+    frequency.
+
+
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.output_size` {#TimeFreqLSTMCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.state_size` {#TimeFreqLSTMCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.zero_state(batch_size, dtype)` {#TimeFreqLSTMCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
+
+- - -
+
+### `class tf.contrib.rnn.GridLSTMCell` {#GridLSTMCell}
+
+Grid Long short-term memory unit (LSTM) recurrent network cell.
+
+The default is based on:
+  Nal Kalchbrenner, Ivo Danihelka and Alex Graves
+  "Grid Long Short-Term Memory," Proc. ICLR 2016.
+  http://arxiv.org/abs/1507.01526
+
+When peephole connections are used, the implementation is based on:
+  Tara N. Sainath and Bo Li
+  "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
+  for LVCSR Tasks." submitted to INTERSPEECH, 2016.
+
+The code uses optional peephole connections, shared_weights and cell clipping.
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.__init__(num_units, use_peepholes=False, share_time_frequency_weights=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#GridLSTMCell.__init__}
+
+Initialize the parameters for an LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell
+*  <b>`use_peepholes`</b>: bool, default False. Set True to enable diagonal/peephole
+    connections.
+*  <b>`share_time_frequency_weights`</b>: bool, default False. Set True to enable
+    shared cell weights between time and frequency LSTMs.
+*  <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped
+    by this value prior to the cell output activation.
+*  <b>`initializer`</b>: (optional) The initializer to use for the weight and
+    projection matrices.
+*  <b>`num_unit_shards`</b>: int, How to split the weight matrix.  If >1, the weight
+    matrix is stored across num_unit_shards.
+*  <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default
+    to 1 in order to reduce the scale of forgetting at the beginning
+    of the training.
+*  <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over.
+*  <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in
+    frequency.
+
+
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.output_size` {#GridLSTMCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.state_size` {#GridLSTMCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.zero_state(batch_size, dtype)` {#GridLSTMCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
+
+
+### RNNCell wrappers
+- - -
+
+### `class tf.contrib.rnn.AttentionCellWrapper` {#AttentionCellWrapper}
+
+Basic attention cell wrapper.
+
+Implementation based on https://arxiv.org/pdf/1601.06733.pdf.
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.__init__(cell, attn_length, attn_size=None, attn_vec_size=None, input_size=None, state_is_tuple=False)` {#AttentionCellWrapper.__init__}
+
+Create a cell with attention.
+
+##### Args:
+
+
+*  <b>`cell`</b>: an RNNCell; attention is added to it.
+*  <b>`attn_length`</b>: integer, the size of an attention window.
+*  <b>`attn_size`</b>: integer, the size of an attention vector. Equal to
+      cell.output_size by default.
+*  <b>`attn_vec_size`</b>: integer, the number of convolutional features calculated
+      on attention state and the size of the hidden layer built from
+      base cell state. Equal to attn_size by default.
+*  <b>`input_size`</b>: integer, the size of a hidden linear layer,
+      built from inputs and attention. Derived from the input tensor
+      by default.
+*  <b>`state_is_tuple`</b>: If True, accepted and returned states are n-tuples, where
+    `n = len(cells)`.  By default (False), the states are all
+    concatenated along the column axis.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: if cell is not an RNNCell.
+*  <b>`ValueError`</b>: if cell returns a state tuple but the flag
+      `state_is_tuple` is `False` or if attn_length is zero or less.
+
+
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.output_size` {#AttentionCellWrapper.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.state_size` {#AttentionCellWrapper.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.zero_state(batch_size, dtype)` {#AttentionCellWrapper.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
+
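+- - -
+
+A hypothetical usage sketch for `AttentionCellWrapper` above (the base cell
+class and sizes are placeholders, not part of this package):
+
+```python
+import tensorflow as tf
+
+# Wrap a base LSTM cell with attention over a 10-step window.
+base_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=64)
+attn_cell = tf.contrib.rnn.AttentionCellWrapper(
+    base_cell, attn_length=10, state_is_tuple=False)
+```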
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index 6108155a0ec..9bc9111fc39 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1105,7 +1105,10 @@ DEPRECATED: Use outputs.
 
 ### `class tf.Tensor` {#Tensor}
 
-Represents a value produced by an `Operation`.
+Represents one of the outputs of an `Operation`.
+
+*Note:* the `Tensor` class will be replaced by `Output` in the future.
+Currently these two are aliases for each other.
 
 A `Tensor` is a symbolic handle to one of the outputs of an
 `Operation`. It does not hold the values of that operation's output,
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.rnn.AttentionCellWrapper.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.rnn.AttentionCellWrapper.md
new file mode 100644
index 00000000000..73f35490f75
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.rnn.AttentionCellWrapper.md
@@ -0,0 +1,70 @@
+Basic attention cell wrapper.
+
+Implementation based on https://arxiv.org/pdf/1601.06733.pdf.
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.__init__(cell, attn_length, attn_size=None, attn_vec_size=None, input_size=None, state_is_tuple=False)` {#AttentionCellWrapper.__init__}
+
+Create a cell with attention.
+
+##### Args:
+
+
+*  <b>`cell`</b>: an RNNCell; attention is added to it.
+*  <b>`attn_length`</b>: integer, the size of an attention window.
+*  <b>`attn_size`</b>: integer, the size of an attention vector. Equal to
+      cell.output_size by default.
+*  <b>`attn_vec_size`</b>: integer, the number of convolutional features calculated
+      on attention state and the size of the hidden layer built from
+      base cell state. Equal to attn_size by default.
+*  <b>`input_size`</b>: integer, the size of a hidden linear layer,
+      built from inputs and attention. Derived from the input tensor
+      by default.
+*  <b>`state_is_tuple`</b>: If True, accepted and returned states are n-tuples, where
+    `n = len(cells)`.  By default (False), the states are all
+    concatenated along the column axis.
+
+##### Raises:
+
+
+*  <b>`TypeError`</b>: if cell is not an RNNCell.
+*  <b>`ValueError`</b>: if cell returns a state tuple but the flag
+      `state_is_tuple` is `False` or if attn_length is zero or less.
+
+
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.output_size` {#AttentionCellWrapper.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.state_size` {#AttentionCellWrapper.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.AttentionCellWrapper.zero_state(batch_size, dtype)` {#AttentionCellWrapper.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md
index 73af134a7a5..6925d9d6d7c 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md
@@ -1,4 +1,7 @@
-Represents a value produced by an `Operation`.
+Represents one of the outputs of an `Operation`.
+
+*Note:* the `Tensor` class will be replaced by `Output` in the future.
+Currently these two are aliases for each other.
 
 A `Tensor` is a symbolic handle to one of the outputs of an
 `Operation`. It does not hold the values of that operation's output,
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.hinge_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.hinge_loss.md
new file mode 100644
index 00000000000..57758e07104
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.losses.hinge_loss.md
@@ -0,0 +1,22 @@
+### `tf.contrib.losses.hinge_loss(logits, target, scope=None)` {#hinge_loss}
+
+Method that returns the loss tensor for hinge loss.
+
+##### Args:
+
+
+*  <b>`logits`</b>: The logits, a float tensor.
+*  <b>`target`</b>: The ground truth output tensor. Its shape should match the shape of
+    logits. The values of the tensor are expected to be 0.0 or 1.0.
+*  <b>`scope`</b>: The scope for the operations performed in computing the loss.
+
+##### Returns:
+
+  A `Tensor` of the same shape as `logits` and `target`, representing the
+    loss values across the batch.
+
+##### Raises:
+
+
+*  <b>`ValueError`</b>: If the shapes of `logits` and `target` don't match.
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.rnn.GridLSTMCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.rnn.GridLSTMCell.md
new file mode 100644
index 00000000000..509f59748cd
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.rnn.GridLSTMCell.md
@@ -0,0 +1,77 @@
+Grid Long short-term memory unit (LSTM) recurrent network cell.
+
+The default is based on:
+  Nal Kalchbrenner, Ivo Danihelka and Alex Graves
+  "Grid Long Short-Term Memory," Proc. ICLR 2016.
+  http://arxiv.org/abs/1507.01526
+
+When peephole connections are used, the implementation is based on:
+  Tara N. Sainath and Bo Li
+  "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
+  for LVCSR Tasks." submitted to INTERSPEECH, 2016.
+
+The code uses optional peephole connections, shared_weights and cell clipping.
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.__init__(num_units, use_peepholes=False, share_time_frequency_weights=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#GridLSTMCell.__init__}
+
+Initialize the parameters for an LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell
+*  <b>`use_peepholes`</b>: bool, default False. Set True to enable diagonal/peephole
+    connections.
+*  <b>`share_time_frequency_weights`</b>: bool, default False. Set True to enable
+    shared cell weights between time and frequency LSTMs.
+*  <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped
+    by this value prior to the cell output activation.
+*  <b>`initializer`</b>: (optional) The initializer to use for the weight and
+    projection matrices.
+*  <b>`num_unit_shards`</b>: int, How to split the weight matrix.  If >1, the weight
+    matrix is stored across num_unit_shards.
+*  <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default
+    to 1 in order to reduce the scale of forgetting at the beginning
+    of the training.
+*  <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over.
+*  <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in
+    frequency.
+
+
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.output_size` {#GridLSTMCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.state_size` {#GridLSTMCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.GridLSTMCell.zero_state(batch_size, dtype)` {#GridLSTMCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.rnn.CoupledInputForgetGateLSTMCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.rnn.CoupledInputForgetGateLSTMCell.md
new file mode 100644
index 00000000000..0e36b224bc6
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.rnn.CoupledInputForgetGateLSTMCell.md
@@ -0,0 +1,93 @@
+Long short-term memory unit (LSTM) recurrent network cell.
+
+The default non-peephole implementation is based on:
+
+  http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
+
+S. Hochreiter and J. Schmidhuber.
+"Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+The peephole implementation is based on:
+
+  https://research.google.com/pubs/archive/43905.pdf
+
+Hasim Sak, Andrew Senior, and Francoise Beaufays.
+"Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+The coupling of input and forget gate is based on:
+
+  http://arxiv.org/pdf/1503.04069.pdf
+
+Greff et al. "LSTM: A Search Space Odyssey"
+
+The class uses optional peep-hole connections, and an optional projection
+layer.
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.__init__(num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=False, activation=tanh)` {#CoupledInputForgetGateLSTMCell.__init__}
+
+Initialize the parameters for an LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell
+*  <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections.
+*  <b>`initializer`</b>: (optional) The initializer to use for the weight and
+    projection matrices.
+*  <b>`num_proj`</b>: (optional) int, The output dimensionality for the projection
+    matrices.  If None, no projection is performed.
+*  <b>`proj_clip`</b>: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
+    provided, then the projected values are clipped elementwise to within
+    `[-proj_clip, proj_clip]`.
+*  <b>`num_unit_shards`</b>: How to split the weight matrix.  If >1, the weight
+    matrix is stored across num_unit_shards.
+*  <b>`num_proj_shards`</b>: How to split the projection matrix.  If >1, the
+    projection matrix is stored across num_proj_shards.
+*  <b>`forget_bias`</b>: Biases of the forget gate are initialized by default to 1
+    in order to reduce the scale of forgetting at the beginning of
+    the training.
+*  <b>`state_is_tuple`</b>: If True, accepted and returned states are 2-tuples of
+    the `c_state` and `m_state`.  By default (False), they are concatenated
+    along the column axis.  This default behavior will soon be deprecated.
+*  <b>`activation`</b>: Activation function of the inner states.
+
+
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.output_size` {#CoupledInputForgetGateLSTMCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.state_size` {#CoupledInputForgetGateLSTMCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.CoupledInputForgetGateLSTMCell.zero_state(batch_size, dtype)` {#CoupledInputForgetGateLSTMCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.rnn.TimeFreqLSTMCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.rnn.TimeFreqLSTMCell.md
new file mode 100644
index 00000000000..e870477b6ba
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.rnn.TimeFreqLSTMCell.md
@@ -0,0 +1,70 @@
+Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
+
+This implementation is based on:
+
+  Tara N. Sainath and Bo Li
+  "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
+  for LVCSR Tasks." submitted to INTERSPEECH, 2016.
+
+It uses peep-hole connections and optional cell clipping.
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.__init__(num_units, use_peepholes=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None)` {#TimeFreqLSTMCell.__init__}
+
+Initialize the parameters for an LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell
+*  <b>`use_peepholes`</b>: bool, set True to enable diagonal/peephole connections.
+*  <b>`cell_clip`</b>: (optional) A float value, if provided the cell state is clipped
+    by this value prior to the cell output activation.
+*  <b>`initializer`</b>: (optional) The initializer to use for the weight and
+    projection matrices.
+*  <b>`num_unit_shards`</b>: int, How to split the weight matrix.  If >1, the weight
+    matrix is stored across num_unit_shards.
+*  <b>`forget_bias`</b>: float, Biases of the forget gate are initialized by default
+    to 1 in order to reduce the scale of forgetting at the beginning
+    of the training.
+*  <b>`feature_size`</b>: int, The size of the input feature the LSTM spans over.
+*  <b>`frequency_skip`</b>: int, The amount the LSTM filter is shifted by in
+    frequency.
+
+
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.output_size` {#TimeFreqLSTMCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.state_size` {#TimeFreqLSTMCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.TimeFreqLSTMCell.zero_state(batch_size, dtype)` {#TimeFreqLSTMCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.rnn.LSTMFusedCell.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.rnn.LSTMFusedCell.md
new file mode 100644
index 00000000000..fec80caecf1
--- /dev/null
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.rnn.LSTMFusedCell.md
@@ -0,0 +1,60 @@
+Basic LSTM recurrent network cell.
+
+The implementation is based on: http://arxiv.org/abs/1409.2329.
+
+We add forget_bias (default: 1) to the biases of the forget gate in order to
+reduce the scale of forgetting in the beginning of the training.
+
+Unlike BasicLSTMCell, this is a monolithic op and should be much faster. The
+weight and bias matrices should be compatible as long as the variable scope
+matches.
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.__init__(num_units, forget_bias=1.0, use_peephole=False)` {#LSTMFusedCell.__init__}
+
+Initialize the basic LSTM cell.
+
+##### Args:
+
+
+*  <b>`num_units`</b>: int, The number of units in the LSTM cell.
+*  <b>`forget_bias`</b>: float, The bias added to forget gates (see above).
+*  <b>`use_peephole`</b>: Whether to use peephole connections.
+
+
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.output_size` {#LSTMFusedCell.output_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.state_size` {#LSTMFusedCell.state_size}
+
+
+
+
+- - -
+
+#### `tf.contrib.rnn.LSTMFusedCell.zero_state(batch_size, dtype)` {#LSTMFusedCell.zero_state}
+
+Return zero-filled state tensor(s).
+
+##### Args:
+
+
+*  <b>`batch_size`</b>: int, float, or unit Tensor representing the batch size.
+*  <b>`dtype`</b>: the data type to use for the state.
+
+##### Returns:
+
+  If `state_size` is an int or TensorShape, then the return value is a
+  `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.
+
+  If `state_size` is a nested list or tuple, then the return value is
+  a nested list or tuple (of the same structure) of `2-D` tensors with
+  the shapes `[batch_size x s]` for each s in `state_size`.
+
+
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index bad44886b63..448a32d72a5 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -745,12 +745,20 @@
   * [`get_losses`](../../api_docs/python/contrib.losses.md#get_losses)
   * [`get_regularization_losses`](../../api_docs/python/contrib.losses.md#get_regularization_losses)
   * [`get_total_loss`](../../api_docs/python/contrib.losses.md#get_total_loss)
+  * [`hinge_loss`](../../api_docs/python/contrib.losses.md#hinge_loss)
   * [`log_loss`](../../api_docs/python/contrib.losses.md#log_loss)
   * [`sigmoid_cross_entropy`](../../api_docs/python/contrib.losses.md#sigmoid_cross_entropy)
   * [`softmax_cross_entropy`](../../api_docs/python/contrib.losses.md#softmax_cross_entropy)
   * [`sum_of_pairwise_squares`](../../api_docs/python/contrib.losses.md#sum_of_pairwise_squares)
   * [`sum_of_squares`](../../api_docs/python/contrib.losses.md#sum_of_squares)
 
+* **[RNN (contrib)](../../api_docs/python/contrib.rnn.md)**:
+  * [`AttentionCellWrapper`](../../api_docs/python/contrib.rnn.md#AttentionCellWrapper)
+  * [`CoupledInputForgetGateLSTMCell`](../../api_docs/python/contrib.rnn.md#CoupledInputForgetGateLSTMCell)
+  * [`GridLSTMCell`](../../api_docs/python/contrib.rnn.md#GridLSTMCell)
+  * [`LSTMFusedCell`](../../api_docs/python/contrib.rnn.md#LSTMFusedCell)
+  * [`TimeFreqLSTMCell`](../../api_docs/python/contrib.rnn.md#TimeFreqLSTMCell)
+
 * **[Metrics (contrib)](../../api_docs/python/contrib.metrics.md)**:
   * [`accuracy`](../../api_docs/python/contrib.metrics.md#accuracy)
   * [`aggregate_metric_map`](../../api_docs/python/contrib.metrics.md#aggregate_metric_map)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index c5c5573f211..fce04e50f7c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1182,6 +1182,18 @@ py_test(
     ],
 )
 
+py_test(
+    name = "session_debug_test",
+    size = "small",
+    srcs = ["debug/session_debug_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework",
+        ":framework_test_lib",
+        ":session",
+    ],
+)
+
 cuda_py_test(
     name = "timeline_test",
     size = "small",
diff --git a/tensorflow/python/debug/session_debug_test.py b/tensorflow/python/debug/session_debug_test.py
new file mode 100644
index 00000000000..d9fdb240c9d
--- /dev/null
+++ b/tensorflow/python/debug/session_debug_test.py
@@ -0,0 +1,298 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for debugger functionalities in tf.Session."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+class SessionDebugTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.dump_root_ = tempfile.mkdtemp()
+
+  def tearDown(self):
+    # Tear down temporary dump directory.
+    shutil.rmtree(self.dump_root_)
+
+  def _addDebugTensorWatch(self,
+                           run_opts,
+                           node_name,
+                           output_slot,
+                           debug_op="DebugIdentity",
+                           debug_urls=None):
+    watch_opts = run_opts.debug_tensor_watch_opts
+
+    # Add a debug tensor watch for the given node and output slot.
+    watch = watch_opts.add()
+    watch.node_name = node_name
+    watch.output_slot = output_slot
+    watch.debug_ops.append(debug_op)
+
+    if debug_urls:
+      for debug_url in debug_urls:
+        watch.debug_urls.append(debug_url)
+
+  def _verifyTensorDumpFile(self, dump_file, expected_tensor_name, debug_op,
+                            wall_time_lower_bound, expected_tensor_val):
+    """Helper method: Verify a Tensor debug dump file and its content.
+
+    Args:
+      dump_file: Path to the dump file.
+      expected_tensor_name: Expected name of the tensor, e.g., node_a:0.
+      debug_op: Name of the debug Op, e.g., DebugIdentity.
+      wall_time_lower_bound: Lower bound of the wall time.
+      expected_tensor_val: Expected tensor value, as a numpy array.
+    """
+    self.assertTrue(os.path.isfile(dump_file))
+
+    event = event_pb2.Event()
+    f = open(dump_file, "rb")
+    event.ParseFromString(f.read())
+
+    wall_time = event.wall_time
+    debug_node_name = event.summary.value[0].node_name
+
+    tensor_value = tensor_util.MakeNdarray(event.summary.value[0].tensor)
+
+    self.assertGreater(wall_time, wall_time_lower_bound)
+    self.assertEqual("%s:%s" % (expected_tensor_name, debug_op), debg_node_name)
+
+    if expected_tensor_val.dtype.type is np.string_:
+      self.assertEqual(str(expected_tensor_val), str(tensor_value))
+    else:
+      self.assertAllClose(expected_tensor_val, tensor_value)
+
+  def testDumpToFileOverlappingParentDir(self):
+    with session.Session() as sess:
+      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
+      v_init_val = np.array([[2.0], [-1.0]])
+
+      # Use node names with overlapping namespace (i.e., parent directory) to
+      # test concurrent, non-racing directory creation.
+      u_name = "testDumpToFile/u"
+      v_name = "testDumpToFile/v"
+
+      u_init = constant_op.constant(u_init_val, shape=[2, 2])
+      u = variables.Variable(u_init, name=u_name)
+      v_init = constant_op.constant(v_init_val, shape=[2, 1])
+      v = variables.Variable(v_init, name=v_name)
+
+      w = math_ops.matmul(u, v, name="testDumpToFile/matmul")
+
+      u.initializer.run()
+      v.initializer.run()
+
+      run_options = config_pb2.RunOptions()
+      debug_url = "file://%s" % self.dump_root_
+
+      # Add debug tensor watch for u.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % u_name, 0, debug_urls=[debug_url])
+      # Add debug tensor watch for v.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % v_name, 0, debug_urls=[debug_url])
+
+      run_metadata = config_pb2.RunMetadata()
+
+      # Invoke Session.run().
+      sess.run(w, options=run_options, run_metadata=run_metadata)
+
+      # Verify the dump file for u.
+      dump_files = os.listdir(os.path.join(self.dump_root_, u_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+
+      dump_file = os.path.join(self.dump_root_, u_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % u_name,
+                                 "DebugIdentity", 0, u_init_val)
+
+      # Verify the dump file for v.
+      dump_files = os.listdir(os.path.join(self.dump_root_, v_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+
+      dump_file = os.path.join(self.dump_root_, v_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % v_name,
+                                 "DebugIdentity", 0, v_init_val)
+
+  def testDumpStringTensorsToFileSystem(self):
+    with session.Session() as sess:
+      str1_init_val = np.array(b"abc")
+      str2_init_val = np.array(b"def")
+
+      str1_init = constant_op.constant(str1_init_val)
+      str2_init = constant_op.constant(str2_init_val)
+
+      str1_name = "str1"
+      str2_name = "str2"
+      str1 = variables.Variable(str1_init, name=str1_name)
+      str2 = variables.Variable(str2_init, name=str2_name)
+      # Concatenate str1 and str2
+      str_concat = math_ops.add(str1, str2, name="str_concat")
+
+      str1.initializer.run()
+      str2.initializer.run()
+
+      run_options = config_pb2.RunOptions()
+      debug_url = "file://%s" % self.dump_root_
+
+      # Add debug tensor watch for str1.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % str1_name, 0, debug_urls=[debug_url])
+      # Add debug tensor watch for str2.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % str2_name, 0, debug_urls=[debug_url])
+
+      run_metadata = config_pb2.RunMetadata()
+
+      # Invoke Session.run().
+      sess.run(str_concat, options=run_options, run_metadata=run_metadata)
+
+      # Verify the dump file for str1.
+      dump_files = os.listdir(os.path.join(self.dump_root_, str1_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+      dump_file = os.path.join(self.dump_root_, str1_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % str1_name,
+                                 "DebugIdentity", 0, str1_init_val)
+
+      # Verify the dump file for str2.
+      dump_files = os.listdir(os.path.join(self.dump_root_, str2_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+      dump_file = os.path.join(self.dump_root_, str2_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % str2_name,
+                                 "DebugIdentity", 0, str2_init_val)
+
+  def testDumpToFileWhileLoop(self):
+    with session.Session() as sess:
+      num_iter = 10
+
+      # "u" is the Variable being updated in the loop.
+      u_name = "testDumpToFileWhileLoop/u"
+      u_namespace = u_name.split("/")[0]
+
+      u_init_val = np.array(11.0)
+      u_init = constant_op.constant(u_init_val)
+      u = variables.Variable(u_init, name=u_name)
+
+      # "v" is the increment.
+      v_name = "testDumpToFileWhileLoop/v"
+      v_namespace = v_name.split("/")[0]
+
+      v_init_val = np.array(2.0)
+      v_init = constant_op.constant(v_init_val)
+      v = variables.Variable(v_init, name=v_name)
+
+      u.initializer.run()
+      v.initializer.run()
+
+      i = constant_op.constant(0, name="testDumpToFileWhileLoop/i")
+
+      def cond(i):
+        return math_ops.less(i, num_iter)
+
+      def body(i):
+        new_u = state_ops.assign_add(u, v)
+        new_i = math_ops.add(i, 1)
+        op = control_flow_ops.group(new_u)
+        new_i = control_flow_ops.with_dependencies([op], new_i)
+        return [new_i]
+
+      loop = control_flow_ops.while_loop(cond, body, [i], parallel_iterations=1)
+
+      # Create RunOptions for debug-watching tensors
+      run_options = config_pb2.RunOptions()
+      debug_url = "file://%s" % self.dump_root_
+
+      # Add debug tensor watch for u.
+      self._addDebugTensorWatch(run_options, u_name, 0, debug_urls=[debug_url])
+      # Add debug tensor watch for v.
+      self._addDebugTensorWatch(
+          run_options, "%s/read" % v_name, 0, debug_urls=[debug_url])
+      # Add debug tensor watch for while/Identity.
+      self._addDebugTensorWatch(
+          run_options, "while/Identity", 0, debug_urls=[debug_url])
+
+      run_metadata = config_pb2.RunMetadata()
+
+      r = sess.run(loop, options=run_options, run_metadata=run_metadata)
+
+      self.assertEqual(num_iter, r)
+
+      u_val_final = sess.run(u)
+      self.assertAllClose(u_init_val + num_iter * v_init_val, u_val_final)
+
+      # Verify dump files
+      self.assertTrue(os.path.isdir(self.dump_root_))
+
+      self.assertTrue(os.path.isdir(os.path.join(self.dump_root_, u_namespace)))
+      self.assertTrue(
+          os.path.isdir(os.path.join(self.dump_root_, v_namespace, "v")))
+
+      # Verify the dump file for tensor "u".
+      dump_files = glob.glob(
+          os.path.join(self.dump_root_, u_namespace, "u_0_*"))
+      self.assertEqual(1, len(dump_files))
+      dump_file = os.path.join(self.dump_root_, u_namespace, dump_files[0])
+      self.assertTrue(os.path.isfile(dump_file))
+      self._verifyTensorDumpFile(dump_file, "%s:0" % u_name, "DebugIdentity", 0,
+                                 u_init_val)
+
+      # Verify the dump file for tensor "v".
+      dump_files = os.listdir(os.path.join(self.dump_root_, v_name))
+      self.assertEqual(1, len(dump_files))
+      self.assertTrue(dump_files[0].startswith("read_0_"))
+
+      dump_file = os.path.join(self.dump_root_, v_name, dump_files[0])
+      self._verifyTensorDumpFile(dump_file, "%s/read:0" % v_name,
+                                 "DebugIdentity", 0, v_init_val)
+
+      # Verify the dump files for tensor while/Identity
+      while_identity_dump_files = sorted(
+          os.listdir(os.path.join(self.dump_root_, "while")))
+      self.assertEqual(num_iter, len(while_identity_dump_files))
+
+      # Verify the content of the individual dump files.
+      for k in xrange(len(while_identity_dump_files)):
+        dump_file_path = os.path.join(self.dump_root_, "while",
+                                      while_identity_dump_files[k])
+        self._verifyTensorDumpFile(dump_file_path, "while/Identity:0",
+                                   "DebugIdentity", 0, np.array(k))
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/framework/framework_lib.py b/tensorflow/python/framework/framework_lib.py
index b06605cf592..3f77187a25c 100644
--- a/tensorflow/python/framework/framework_lib.py
+++ b/tensorflow/python/framework/framework_lib.py
@@ -72,6 +72,7 @@ from tensorflow.python.framework.device import DeviceSpec
 from tensorflow.python.framework.ops import Graph
 from tensorflow.python.framework.ops import Operation
 from tensorflow.python.framework.ops import Tensor
+from tensorflow.python.framework.ops import Output
 from tensorflow.python.framework.ops import SparseTensor
 from tensorflow.python.framework.ops import SparseTensorValue
 from tensorflow.python.framework.ops import IndexedSlices
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index bfea7b6aca7..49d9cec7c19 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -65,6 +65,7 @@ def get_module_to_name():
       tf.contrib.learn.monitors: (
           "tf.contrib.learn.monitors"),
       tf.contrib.losses: "tf.contrib.losses",
+      tf.contrib.rnn: "tf.contrib.rnn",
       tf.contrib.metrics: "tf.contrib.metrics",
       tf.contrib.util: "tf.contrib.util",
   }
@@ -171,6 +172,7 @@ def all_libraries(module_to_name, members, documented):
       library("contrib.learn.monitors", "Monitors (contrib)",
               tf.contrib.learn.monitors),
       library("contrib.losses", "Losses (contrib)", tf.contrib.losses),
+      library("contrib.rnn", "RNN (contrib)", tf.contrib.rnn),
       library("contrib.metrics", "Metrics (contrib)", tf.contrib.metrics),
       library("contrib.util", "Utilities (contrib)", tf.contrib.util),
       library("contrib.copy_graph", "Copying Graph Elements (contrib)",
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f89f3d46972..854d46b955e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -185,7 +185,10 @@ def register_dense_tensor_like_type(tensor_type):
 
 
 class Tensor(object):
-  """Represents a value produced by an `Operation`.
+  """Represents one of the outputs of an `Operation`.
+
+  *Note:* the `Tensor` class will be replaced by `Output` in the future.
+  Currently these two are aliases for each other.
 
   A `Tensor` is a symbolic handle to one of the outputs of an
   `Operation`. It does not hold the values of that operation's output,
@@ -556,6 +559,10 @@ class Tensor(object):
     return _eval_using_default_session(self, feed_dict, self.graph, session)
 
 
+# TODO(josh11b): Switch everyone from "Tensor" to "Output" to match C++ API.
+Output = Tensor
+
+
 def _TensorTensorConversionFunction(t, dtype=None, name=None, as_ref=False):
   _ = name, as_ref
   if dtype and not dtype.is_compatible_with(t.dtype):
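Since `Output` is now a straight alias of `Tensor` (and re-exported through `framework_lib.py` above), the two names refer to the same class for now. A quick sketch, assuming both names are visible on the top-level `tf` module:

```python
import tensorflow as tf

assert tf.Output is tf.Tensor  # aliases of the same class, per ops.py above

t = tf.constant([1, 2, 3])
assert isinstance(t, tf.Output) and isinstance(t, tf.Tensor)
```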
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index f63cf812474..abed6b5777a 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -846,8 +846,7 @@ def per_image_whitening(image):
   stddev = math_ops.sqrt(variance)
 
   # Apply a minimum normalization that protects us against uniform images.
-  min_stddev = math_ops.inv(
-      math_ops.sqrt(math_ops.cast(num_pixels, dtypes.float32)))
+  min_stddev = math_ops.rsqrt(math_ops.cast(num_pixels, dtypes.float32))
   pixel_value_scale = math_ops.maximum(stddev, min_stddev)
   pixel_value_offset = image_mean
 
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 0620a3da2c4..2331f21d479 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -161,7 +161,7 @@ def _SegmentMeanGrad(op, grad):
           array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)])
   ones = array_ops.fill(ones_shape,
                         constant_op.constant(1, dtype=grad.dtype))
-  scaled_grad = grad * math_ops.inv(math_ops.segment_sum(ones, op.inputs[1]))
+  scaled_grad = math_ops.div(grad, math_ops.segment_sum(ones, op.inputs[1]))
   return array_ops.gather(scaled_grad, op.inputs[1]), None
 
 
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 73e51aab7de..562c0408b94 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1125,7 +1125,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
                                                dtype=x.dtype)
     # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
     binary_tensor = math_ops.floor(random_tensor)
-    ret = x * math_ops.inv(keep_prob) * binary_tensor
+    ret = math_ops.div(x, keep_prob) * binary_tensor
     ret.set_shape(x.get_shape())
     return ret
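The three `math_ops.inv` replacements above (`inv(sqrt(n))` becoming `rsqrt(n)`, and the two `x * inv(y)` products becoming `div(x, y)`) are value-preserving identities; a quick NumPy check:

```python
import numpy as np

n, keep_prob = 64.0, 0.5
grad, seg_sum, x = 2.0, 4.0, 3.0

assert np.isclose(1.0 / np.sqrt(n), n ** -0.5)             # inv(sqrt) == rsqrt
assert np.isclose(grad * (1.0 / seg_sum), grad / seg_sum)  # mul-by-inv == div
assert np.isclose(x * (1.0 / keep_prob), x / keep_prob)
```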