From 17fdd80807b817b0940e572fea6e555c02ace71f Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Thu, 20 Apr 2017 13:25:52 -0800
Subject: [PATCH 01/27] [tf contrib seq2seq] Changes to dynamic decoding

1. dynamic_decode now returns a third parameter: the sequence lengths
   from decoding (minibatch entries that finished earlier have a shorter
   sequence length)

2. beam search decoder now uses the gather_tree C++ op

3. the gather_tree c++ op now expects sequence_length to be a matrix shaped
   `[batch_size, beam_width]` (each beam may have its own sequence length).
Change: 153756869
---
 .../seq2seq/kernels/beam_search_ops.cc        | 40 ++++----
 .../contrib/seq2seq/kernels/beam_search_ops.h |  2 +-
 .../seq2seq/kernels/beam_search_ops_gpu.cu.cc |  4 +-
 .../contrib/seq2seq/ops/beam_search_ops.cc    | 11 ++-
 .../kernel_tests/attention_wrapper_test.py    |  2 +-
 .../kernel_tests/beam_search_decoder_test.py  | 96 ++++++++++---------
 .../kernel_tests/beam_search_ops_test.py      |  9 +-
 .../python/kernel_tests/decoder_test.py       | 14 ++-
 .../seq2seq/python/ops/beam_search_decoder.py | 47 ++-------
 .../contrib/seq2seq/python/ops/decoder.py     | 37 ++++---
 10 files changed, 130 insertions(+), 132 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
index 3b0568794dc..ec493b84635 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc
@@ -56,14 +56,19 @@ class GatherTreeOp : public OpKernel {
         errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
                                 step_ids_shape.DebugString()));
     OP_REQUIRES(
-        ctx, TensorShapeUtils::IsVector(sequence_length.shape()),
-        errors::InvalidArgument("sequence_length must be a vector, saw shape: ",
+        ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
+        errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
                                 sequence_length.shape().DebugString()));
     OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
                 errors::InvalidArgument(
-                    "Inconsistent batch sizes: sequence_length.shape[1] (",
+                    "Inconsistent batch sizes: sequence_length.shape[0] (",
                     sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
-                    step_ids_shape.dim_size(0), ")"));
+                    step_ids_shape.dim_size(1), ")"));
+    OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
+                errors::InvalidArgument(
+                    "Inconsistent batch sizes: sequence_length.shape[1] (",
+                    sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
+                    step_ids_shape.dim_size(2), ")"));
     OP_REQUIRES(
         ctx, step_ids_shape == parent_ids.shape(),
         errors::InvalidArgument(
@@ -74,7 +79,7 @@ class GatherTreeOp : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
     typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
     typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
-    typename TTypes<T>::ConstVec seq_len_t = sequence_length.vec<T>();
+    typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
     typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
     functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
                                      seq_len_t, beams_t);
@@ -96,7 +101,7 @@ struct GatherTree<CPUDevice, int32> {
   void operator()(OpKernelContext* ctx, const CPUDevice& d,
                   typename TTypes<int32, 3>::ConstTensor step_ids,
                   typename TTypes<int32, 3>::ConstTensor parent_ids,
-                  typename TTypes<int32>::ConstVec sequence_length,
+                  typename TTypes<int32>::ConstMatrix sequence_length,
                   typename TTypes<int32, 3>::Tensor beams) {
     const int64 max_time = parent_ids.dimension(0);
     const int64 batch_size = parent_ids.dimension(1);
@@ -104,15 +109,10 @@ struct GatherTree<CPUDevice, int32> {
     beams.setConstant(-1);
 
     auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
-      int32 seq_len_b = -1;
-      int32 old_batch = -1;
       for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
         const int32 batch = i / beam_width;
         const int32 beam = i % beam_width;
-        if (batch != old_batch) {
-          seq_len_b = sequence_length(batch);
-          old_batch = batch;
-        }
+        int32 seq_len_b = sequence_length(batch, beam);
         if (seq_len_b == 0) {
           continue;
         }
@@ -148,14 +148,14 @@ struct GatherTree<CPUDevice, int32> {
 
 #if GOOGLE_CUDA
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                          \
-  template <>                                        \
-  void GatherTree<GPUDevice, T>::operator()(         \
-      OpKernelContext* ctx, const GPUDevice& d,      \
-      typename TTypes<T, 3>::ConstTensor step_ids,   \
-      typename TTypes<T, 3>::ConstTensor parent_ids, \
-      typename TTypes<T>::ConstVec sequence_length,  \
-      typename TTypes<T, 3>::Tensor beams);          \
+#define DECLARE_GPU_SPEC(T)                            \
+  template <>                                          \
+  void GatherTree<GPUDevice, T>::operator()(           \
+      OpKernelContext* ctx, const GPUDevice& d,        \
+      typename TTypes<T, 3>::ConstTensor step_ids,     \
+      typename TTypes<T, 3>::ConstTensor parent_ids,   \
+      typename TTypes<T>::ConstMatrix sequence_length, \
+      typename TTypes<T, 3>::Tensor beams);            \
   extern template struct GatherTree<GPUDevice, T>;
 
 DECLARE_GPU_SPEC(int32);
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
index 501a2eae848..124d07264e7 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h
@@ -31,7 +31,7 @@ struct GatherTree {
   void operator()(OpKernelContext* ctx, const Device& d,
                   typename TTypes<T, 3>::ConstTensor step_ids,
                   typename TTypes<T, 3>::ConstTensor parent_ids,
-                  typename TTypes<T>::ConstVec sequence_length,
+                  typename TTypes<T>::ConstMatrix sequence_length,
                   typename TTypes<T, 3>::Tensor beams);
 };
 
diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
index 8d8fc810015..e3c0d0bfa98 100644
--- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
+++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc
@@ -33,7 +33,7 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
   CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
     const int32 batch = i / beam_width;
     const int32 beam = i % beam_width;
-    const int32 seq_len_b = ldg(sequence_length + batch);
+    const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
 #define GET_IX(time_ix, beam_ix) \
   (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
     const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
@@ -59,7 +59,7 @@ struct GatherTree<GPUDevice, T> {
   void operator()(OpKernelContext* ctx, const GPUDevice& d,
                   typename TTypes<T, 3>::ConstTensor step_ids,
                   typename TTypes<T, 3>::ConstTensor parent_ids,
-                  typename TTypes<T>::ConstVec sequence_length,
+                  typename TTypes<T>::ConstMatrix sequence_length,
                   typename TTypes<T, 3>::Tensor beams) {
     const int32 max_time = parent_ids.dimension(0);
     const int32 batch_size = parent_ids.dimension(1);
diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
index c167736d882..6c445cd4606 100644
--- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
+++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc
@@ -32,17 +32,20 @@ REGISTER_OP("GatherTree")
       ShapeHandle step_ids, parent_ids, sequence_length;
 
       // step_ids, parent_ids, and output are all shaped:
-      //   [batch_size, max_time, beam_width].
-      // sequence_length is shaped [batch_size].
+      //   [max_time, batch_size, beam_width].
+      // sequence_length is shaped [batch_size, beam_width].
       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));
 
       DimensionHandle batch_size = c->Dim(step_ids, 1);
+      DimensionHandle beam_width = c->Dim(step_ids, 2);
 
       TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
       TF_RETURN_IF_ERROR(
           c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
+      TF_RETURN_IF_ERROR(
+          c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));
 
       c->set_output(0, step_ids);
       return tensorflow::Status::OK();
@@ -58,7 +61,7 @@ TODO(ebrevdo): fill in
 
 step_ids: `[max_time, batch_size, beam_width]`.
 parent_ids: `[max_time, batch_size, beam_width]`.
-sequence_length: `[batch_size]`.
+sequence_length: `[batch_size, beam_width]`.
 beams: `[max_time, batch_size, beam_width]`.
 )doc");
 
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index aa84ae060c9..888479e218e 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -109,7 +109,7 @@ class AttentionWrapperTest(test.TestCase):
             initial_state=cell.zero_state(
                 dtype=dtypes.float32, batch_size=batch_size))
 
-        final_outputs, final_state = decoder.dynamic_decode(my_decoder)
+        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
 
       self.assertTrue(
           isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
index 512df183171..a72d962d784 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py
@@ -24,6 +24,7 @@ import numpy as np
 from tensorflow.contrib.rnn import core_rnn_cell
 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
 from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
+from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -41,24 +42,32 @@ class TestGatherTree(test.TestCase):
   """Tests the gather_tree function."""
 
   def test_gather_tree(self):
-    predicted_ids = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
-                              [[2, 3, 4], [5, 6, 7],
-                               [8, 9, 10]]]).transpose([1, 0, 2])
-    parent_ids = np.array([
-        [[0, 0, 0], [0, 1, 1], [2, 1, 2]],
-        [[0, 0, 0], [1, 2, 0], [2, 1, 1]],
-    ]).transpose([1, 0, 2])
-    expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
-                                [[2, 4, 4], [7, 6, 6],
-                                 [8, 9, 10]]]).transpose([1, 0, 2])
+    # (max_time = 3, batch_size = 2, beam_width = 3)
 
-    res = beam_search_decoder._gather_tree(
-        ops.convert_to_tensor(predicted_ids), ops.convert_to_tensor(parent_ids))
+    # create (batch_size, max_time, beam_width) matrix and transpose it
+    predicted_ids = np.array(
+        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+         [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
+        dtype=np.int32).transpose([1, 0, 2])
+    parent_ids = np.array(
+        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
+         [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
+        dtype=np.int32).transpose([1, 0, 2])
+
+    # sequence_lengths is shaped (batch_size = 2, beam_width = 3)
+    sequence_lengths = [[3, 3, 3], [3, 3, 3]]
+
+    expected_result = np.array(
+        [[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
+         [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
+
+    res = beam_search_ops.gather_tree(
+        predicted_ids, parent_ids, sequence_lengths)
 
     with self.test_session() as sess:
       res_ = sess.run(res)
 
-    np.testing.assert_array_equal(expected_result, res_)
+    self.assertAllEqual(expected_result, res_)
 
 
 class TestEosMasking(test.TestCase):
@@ -80,18 +89,18 @@ class TestEosMasking(test.TestCase):
       probs = sess.run(probs)
       masked = sess.run(masked)
 
-      np.testing.assert_array_equal(probs[0][0], masked[0][0])
-      np.testing.assert_array_equal(probs[0][2], masked[0][2])
-      np.testing.assert_array_equal(probs[1][0], masked[1][0])
+      self.assertAllEqual(probs[0][0], masked[0][0])
+      self.assertAllEqual(probs[0][2], masked[0][2])
+      self.assertAllEqual(probs[1][0], masked[1][0])
 
-      np.testing.assert_equal(masked[0][1][0], 0)
-      np.testing.assert_equal(masked[1][1][0], 0)
-      np.testing.assert_equal(masked[1][2][0], 0)
+      self.assertEqual(masked[0][1][0], 0)
+      self.assertEqual(masked[1][1][0], 0)
+      self.assertEqual(masked[1][2][0], 0)
 
       for i in range(1, 5):
-        np.testing.assert_approx_equal(masked[0][1][i], np.finfo('float32').min)
-        np.testing.assert_approx_equal(masked[1][1][i], np.finfo('float32').min)
-        np.testing.assert_approx_equal(masked[1][2][i], np.finfo('float32').min)
+        self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
+        self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
+        self.assertAllClose(masked[1][2][i], np.finfo('float32').min)
 
 
 class TestBeamStep(test.TestCase):
@@ -142,12 +151,11 @@ class TestBeamStep(test.TestCase):
       outputs_, next_state_, state_, log_probs_ = sess.run(
           [outputs, next_beam_state, beam_state, log_probs])
 
-    np.testing.assert_array_equal(outputs_.predicted_ids, [[3, 3, 2], [2, 2,
-                                                                       1]])
-    np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
-    np.testing.assert_array_equal(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
-    np.testing.assert_array_equal(next_state_.finished, [[False, False, False],
-                                                         [False, False, False]])
+    self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
+    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
+    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
+    self.assertAllEqual(next_state_.finished, [[False, False, False],
+                                               [False, False, False]])
 
     expected_log_probs = []
     expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -158,7 +166,7 @@ class TestBeamStep(test.TestCase):
     expected_log_probs[1][0] += log_probs_[1, 2, 2]
     expected_log_probs[1][1] += log_probs_[1, 1, 2]
     expected_log_probs[1][2] += log_probs_[1, 0, 1]
-    np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
+    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
 
   def test_step_with_eos(self):
     dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@@ -197,12 +205,11 @@ class TestBeamStep(test.TestCase):
       outputs_, next_state_, state_, log_probs_ = sess.run(
           [outputs, next_beam_state, beam_state, log_probs])
 
-    np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
-    np.testing.assert_array_equal(outputs_.predicted_ids, [[0, 3, 2], [2, 0,
-                                                                       1]])
-    np.testing.assert_array_equal(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
-    np.testing.assert_array_equal(next_state_.finished, [[True, False, False],
-                                                         [False, True, False]])
+    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
+    self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
+    self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
+    self.assertAllEqual(next_state_.finished, [[True, False, False],
+                                               [False, True, False]])
 
     expected_log_probs = []
     expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@@ -211,7 +218,7 @@ class TestBeamStep(test.TestCase):
     expected_log_probs[0][2] += log_probs_[0, 0, 2]
     expected_log_probs[1][0] += log_probs_[1, 1, 2]
     expected_log_probs[1][2] += log_probs_[1, 0, 1]
-    np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
+    self.assertAllEqual(next_state_.log_probs, expected_log_probs)
 
 
 class BeamSearchDecoderTest(test.TestCase):
@@ -259,8 +266,9 @@ class BeamSearchDecoderTest(test.TestCase):
           output_layer=output_layer,
           length_penalty_weight=0.0)
 
-      final_outputs, final_state = decoder.dynamic_decode(
-          bsd, output_time_major=time_major, maximum_iterations=max_out)
+      final_outputs, final_state, final_sequence_lengths = (
+          decoder.dynamic_decode(
+              bsd, output_time_major=time_major, maximum_iterations=max_out))
 
       def _t(shape):
         if time_major:
@@ -284,16 +292,18 @@ class BeamSearchDecoderTest(test.TestCase):
       sess.run(variables.global_variables_initializer())
       sess_results = sess.run({
           'final_outputs': final_outputs,
-          'final_state': final_state
+          'final_state': final_state,
+          'final_sequence_lengths': final_sequence_lengths
       })
 
-      # Mostly a smoke test
-      time_steps = max_out
+      max_sequence_length = np.max(sess_results['final_sequence_lengths'])
+
+      # A smoke test
       self.assertEqual(
-          _t((batch_size, time_steps, beam_width)),
+          _t((batch_size, max_sequence_length, beam_width)),
           sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
       self.assertEqual(
-          _t((batch_size, time_steps, beam_width)), sess_results[
+          _t((batch_size, max_sequence_length, beam_width)), sess_results[
               'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
 
   def testDynamicDecodeRNNBatchMajorNoAttention(self):
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
index 542254854a4..491d87f62d8 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py
@@ -38,7 +38,7 @@ class GatherTreeTest(test.TestCase):
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
-    sequence_length = [3]
+    sequence_length = [[3, 3, 3]]
     expected_result = _transpose_batch_time(
         [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     beams = beam_search_ops.gather_tree(
@@ -54,7 +54,7 @@ class GatherTreeTest(test.TestCase):
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
-    sequence_length = [3]
+    sequence_length = [[3, 3, 3]]
     with ops.device("/cpu:0"):
       beams = beam_search_ops.gather_tree(
           step_ids=step_ids, parent_ids=parent_ids,
@@ -73,7 +73,7 @@ class GatherTreeTest(test.TestCase):
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
-    sequence_length = [3]
+    sequence_length = [[3, 3, 3]]
     expected_result = _transpose_batch_time(
         [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     with ops.device("/gpu:0"):
@@ -84,7 +84,8 @@ class GatherTreeTest(test.TestCase):
       self.assertAllEqual(expected_result, beams.eval())
 
   def testGatherTreeBatch(self):
-    sequence_length = [0, 1, 2, 3]
+    # sequence_length is [batch_size, beam_width] = [4, 5]
+    sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5]
 
     with self.test_session(use_gpu=True):
       # (max_time = 4, batch_size = 4, beam_width = 5)
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index 340ec9bbb22..96dc7b4beee 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -60,9 +60,9 @@ class DynamicDecodeRNNTest(test.TestCase):
           initial_state=cell.zero_state(
               dtype=dtypes.float32, batch_size=batch_size))
 
-      final_outputs, final_state = decoder.dynamic_decode(
-          my_decoder, output_time_major=time_major,
-          maximum_iterations=maximum_iterations)
+      final_outputs, final_state, final_sequence_length = (
+          decoder.dynamic_decode(my_decoder, output_time_major=time_major,
+                                 maximum_iterations=maximum_iterations))
 
       def _t(shape):
         if time_major:
@@ -73,6 +73,9 @@ class DynamicDecodeRNNTest(test.TestCase):
           isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
       self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))
 
+      self.assertEqual(
+          (batch_size,),
+          tuple(final_sequence_length.get_shape().as_list()))
       self.assertEqual(
           _t((batch_size, None, cell_depth)),
           tuple(final_outputs.rnn_output.get_shape().as_list()))
@@ -83,7 +86,8 @@ class DynamicDecodeRNNTest(test.TestCase):
       sess.run(variables.global_variables_initializer())
       sess_results = sess.run({
           "final_outputs": final_outputs,
-          "final_state": final_state
+          "final_state": final_state,
+          "final_sequence_length": final_sequence_length,
       })
 
       # Mostly a smoke test
@@ -131,7 +135,7 @@ class DynamicDecodeRNNTest(test.TestCase):
       # Match the variable scope of dynamic_rnn below so we end up
       # using the same variables
       with vs.variable_scope("root") as scope:
-        final_decoder_outputs, final_decoder_state = decoder.dynamic_decode(
+        final_decoder_outputs, final_decoder_state, _ = decoder.dynamic_decode(
             my_decoder,
             # impute_finished=True ensures outputs and final state
             # match those of dynamic_rnn called with sequence_length not None
diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
index 8f1f74ab09d..55ef21a5a0d 100644
--- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py
@@ -19,9 +19,9 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import numpy as np
 
 from tensorflow.contrib.rnn import core_rnn_cell
+from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
 from tensorflow.contrib.seq2seq.python.ops import decoder
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import script_ops
 from tensorflow.python.util import nest
 
 
@@ -202,20 +201,24 @@ class BeamSearchDecoder(decoder.Decoder):
 
     return (finished, start_inputs, initial_state)
 
-  def finalize(self, outputs, final_state):
+  def finalize(self, outputs, final_state, sequence_lengths):
     """Finalize and return the predicted_ids.
 
     Args:
       outputs: An instance of BeamSearchDecoderOutput.
       final_state: An instance of BeamSearchDecoderState. Passed through to the
         output.
+      sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
+        The sequence lengths determined for each beam during decode.
 
     Returns:
       outputs: An instance of FinalBeamSearchDecoderOutput where the
         predicted_ids are the result of calling _gather_tree.
       final_state: The same input instance of BeamSearchDecoderState.
     """
-    predicted_ids = _gather_tree(outputs.predicted_ids, outputs.parent_ids)
+    predicted_ids = beam_search_ops.gather_tree(
+        outputs.predicted_ids, outputs.parent_ids,
+        sequence_length=sequence_lengths)
     outputs = FinalBeamSearchDecoderOutput(
         beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
     return outputs, final_state
@@ -536,42 +539,6 @@ def _mask_probs(probs, eos_token, finished):
   return finished_examples + non_finished_examples
 
 
-def _gather_tree_py(values, parents):
-  """Gathers path through a tree backwards from the leave nodes.
-
-  Used to reconstruct beams given their parents.
-
-  Args:
-    values: A [T, batch_size, beam_width] tensor of indices.
-    parents: A [T, batch_size, beam_width] tensor of parent beam ids.
-
-  Returns:
-    The [T, batch_size, beam_width] numpy array of paths. For a given batch
-      entry b, the best path is given by ret[:, b, 0].
-  """
-  num_timesteps = values.shape[0]
-  num_beams = values.shape[2]
-  batch_size = values.shape[1]
-  ret = np.zeros_like(values)  # [T, MB, BW]
-  ret[-1, :, :] = values[-1, :, :]
-  for beam_id in range(num_beams):
-    for batch in range(batch_size):
-      parent = parents[-1][batch][beam_id]
-      for timestep in reversed(range(num_timesteps - 1)):
-        ret[timestep, batch, beam_id] = values[timestep][batch][parent]
-        parent = parents[timestep][batch][parent]
-  # now we are going to return ret as a [ts, mb, bw] tensor
-  return np.array(ret).astype(values.dtype)
-
-
-def _gather_tree(values, parents):
-  """Tensor version of _gather_tree_py."""
-  ret = script_ops.py_func(
-      func=_gather_tree_py, inp=[values, parents], Tout=values.dtype)
-  ret.set_shape(values.get_shape().as_list())
-  return ret
-
-
 def _tensor_gather_helper(gather_indices, gather_from, range_input, range_size,
                           final_shape):
   range_ = array_ops.expand_dims(math_ops.range(range_input) * range_size, 1)
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index ee287b0cf65..ff705715e01 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -154,11 +154,11 @@ def dynamic_decode(decoder,
     scope: Optional variable scope to use.
 
   Returns:
-    `(final_outputs, final_state)`.
+    `(final_outputs, final_state, final_sequence_lengths)`.
 
   Raises:
     TypeError: if `decoder` is not an instance of `Decoder`.
-    ValueError: if maximum_iterations is provided but is not a scalar.
+    ValueError: if `maximum_iterations` is provided but is not a scalar.
   """
   if not isinstance(decoder, Decoder):
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
@@ -184,6 +184,8 @@ def dynamic_decode(decoder,
     if maximum_iterations is not None:
       initial_finished = math_ops.logical_or(
           initial_finished, 0 >= maximum_iterations)
+    initial_sequence_lengths = array_ops.zeros_like(
+        initial_finished, dtype=dtypes.int32)
     initial_time = constant_op.constant(0, dtype=dtypes.int32)
 
     def _shape(batch_size, from_shape):
@@ -206,10 +208,10 @@ def dynamic_decode(decoder,
                                             decoder.output_dtype)
 
     def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
-                  finished):
+                  finished, unused_sequence_lengths):
       return math_ops.logical_not(math_ops.reduce_all(finished))
 
-    def body(time, outputs_ta, state, inputs, finished):
+    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
       """Internal while_loop body.
 
       Args:
@@ -217,10 +219,13 @@ def dynamic_decode(decoder,
         outputs_ta: structure of TensorArray.
         state: (structure of) state tensors and TensorArrays.
         inputs: (structure of) input tensors.
-        finished: 1-D bool tensor.
+        finished: bool tensor (keeping track of what's finished).
+        sequence_lengths: int32 tensor (keeping track of time of finish).
 
       Returns:
-        `(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
+        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
+          next_sequence_lengths)`.
+        ```
       """
       (next_outputs, decoder_state, next_inputs,
        decoder_finished) = decoder.step(time, inputs, state)
@@ -228,6 +233,10 @@ def dynamic_decode(decoder,
       if maximum_iterations is not None:
         next_finished = math_ops.logical_or(
             next_finished, time + 1 >= maximum_iterations)
+      next_sequence_lengths = array_ops.where(
+          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
+          sequence_lengths)
 
       nest.assert_same_structure(state, decoder_state)
       nest.assert_same_structure(outputs_ta, next_outputs)
@@ -260,26 +269,30 @@ def dynamic_decode(decoder,
 
       outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                       outputs_ta, emit)
-      return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
+      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
+              next_sequence_lengths)
 
     res = control_flow_ops.while_loop(
         condition,
         body,
         loop_vars=[
             initial_time, initial_outputs_ta, initial_state, initial_inputs,
-            initial_finished
+            initial_finished, initial_sequence_lengths,
         ],
         parallel_iterations=parallel_iterations,
         swap_memory=swap_memory)
 
     final_outputs_ta = res[1]
     final_state = res[2]
+    final_sequence_lengths = res[5]
 
     final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
+
+    if hasattr(decoder, "finalize"):
+      final_outputs, final_state = decoder.finalize(
+          final_outputs, final_state, final_sequence_lengths)
+
     if not output_time_major:
       final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)
 
-  if hasattr(decoder, "finalize"):
-    final_outputs, final_state = decoder.finalize(final_outputs, final_state)
-
-  return final_outputs, final_state
+  return final_outputs, final_state, final_sequence_lengths

From 20a52139efd7ddf34740db99d993c66976fc94c1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 20 Apr 2017 13:34:58 -0800
Subject: [PATCH 02/27] Add unreduced NONE, and reduced MEAN options for
 losses. Remove "WEIGHTED_" prefix from other Reduction constants. Change:
 153758104

---
 tensorflow/python/kernel_tests/losses_test.py | 77 +++++++++++++---
 tensorflow/python/ops/losses/losses_impl.py   | 90 ++++++++++++-------
 .../golden/tensorflow.losses.-reduction.pbtxt | 12 ++-
 3 files changed, 134 insertions(+), 45 deletions(-)

diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 15eeb762d8f..40fddd76ffa 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -447,7 +447,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
                                      [-100.0, -100.0, 100.0]])
       labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
       loss = losses.sigmoid_cross_entropy(labels, logits)
-      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+      self.assertEquals(logits.dtype, loss.dtype)
+      self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
       self.assertAlmostEqual(0.0, loss.eval(), 3)
 
   def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self):
@@ -456,6 +457,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
     weights = array_ops.ones_like(logits, dtype=dtypes.float32)
 
     loss = losses.sigmoid_cross_entropy(labels, logits, weights)
+    self.assertEquals(logits.dtype, loss.dtype)
 
     with self.test_session() as sess:
       loss = sess.run(loss,
@@ -471,6 +473,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
     weights = array_ops.ones_like(logits, dtype=dtypes.float32)
 
     loss = losses.sigmoid_cross_entropy(labels, logits, weights)
+    self.assertEquals(logits.dtype, loss.dtype)
 
     with self.test_session() as sess:
       loss = sess.run(loss,
@@ -487,7 +490,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
                                      [-100.0, -100.0, 100.0]])
       labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
       loss = losses.sigmoid_cross_entropy(labels, logits)
-      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+      self.assertEquals(logits.dtype, loss.dtype)
+      self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
       self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
 
   def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
@@ -498,7 +502,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
       labels = constant_op.constant([[0, 0, 1], [1, 0, 0], [0, 1, 0]])
       weights = constant_op.constant([[3, 4, 5], [2, 6, 0], [8, 0, 1]])
       loss = losses.sigmoid_cross_entropy(labels, logits, weights)
-      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+      self.assertEquals(logits.dtype, loss.dtype)
+      self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
       self.assertAlmostEqual(1700.0 / 7.0, loss.eval(), 3)
 
   def testMultiCorrectSigmoid(self):
@@ -507,10 +512,43 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
                                    [-100.0, 100.0, 100.0]])
     labels = constant_op.constant([[1, 0, 1], [1, 1, 0], [0, 1, 1]])
     loss = losses.sigmoid_cross_entropy(labels, logits)
-    self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+    self.assertEquals(logits.dtype, loss.dtype)
+    self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
 
     with self.test_session():
-      self.assertAlmostEqual(loss.eval(), 0.0, 3)
+      self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+  def testSigmoidFloat64(self):
+    logits = constant_op.constant((
+        (100.0, -100.0, 100.0),
+        (100.0, -100.0, 100.0),
+        (100.0, 100.0, -100.0)
+    ), dtype=dtypes.float64)
+    labels = constant_op.constant((
+        (1, 0, 1), (1, 1, 0), (0, 1, 1)
+    ), dtype=dtypes.int64)
+    loss = losses.sigmoid_cross_entropy(labels, logits)
+    self.assertEquals(logits.dtype, loss.dtype)
+
+    with self.test_session():
+      self.assertAlmostEqual(44.444, loss.eval(), 3)
+
+  def testSigmoidNoReduction(self):
+    logits = constant_op.constant((
+        (100.0, -100.0, 100.0),
+        (100.0, -100.0, 100.0),
+        (100.0, 100.0, -100.0)))
+    labels = constant_op.constant(((1, 0, 1), (1, 1, 0), (0, 1, 1)))
+    loss = losses.sigmoid_cross_entropy(
+        labels, logits, reduction=losses.Reduction.NONE)
+    self.assertEquals(logits.dtype, loss.dtype)
+
+    with self.test_session():
+      self.assertAllClose((
+          (0., 0., 0.),
+          (0., 100., 100.),
+          (100., 0., 100.)
+      ), loss.eval(), 3)
 
   def testSigmoidLabelSmoothingCorrect(self):
     with self.test_session():
@@ -530,7 +568,8 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
       label_smoothing = 0.1
       loss = losses.sigmoid_cross_entropy(
           labels, logits, label_smoothing=label_smoothing)
-      self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+      self.assertEquals(logits.dtype, loss.dtype)
+      self.assertEquals('sigmoid_cross_entropy_loss/value', loss.op.name)
       expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
       self.assertAlmostEqual(loss.eval(), expected_value, 3)
 
@@ -541,6 +580,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
       sigmoid_labels = constant_op.constant([[1, 0, 1]])
       sigmoid_loss = losses.sigmoid_cross_entropy(
           sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing)
+      self.assertEquals(sigmoid_logits.dtype, sigmoid_loss.dtype)
 
       softmax_logits = constant_op.constant(
           [[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]])
@@ -1254,10 +1294,14 @@ class ComputeWeightedLossTest(test.TestCase):
         self.assertEqual(9, len(util.get_losses()))
         with self.test_session(g):
           for unweighted_loss in unweighted_losses:
-            if reduction == losses.Reduction.WEIGHTED_SUM:
+            if reduction == losses.Reduction.NONE:
+              self.assertAllClose(self._raw_losses, unweighted_loss.eval())
+            elif reduction == losses.Reduction.SUM:
               self.assertAllClose(
                   np.sum(self._raw_losses), unweighted_loss.eval())
-            else:  # losses.Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS
+            else:
+              # reduction one of losses.Reduction.MEAN and
+              # losses.Reduction.SUM_BY_NONZERO_WEIGHTS.
               self.assertAllClose(
                   np.mean(self._raw_losses), unweighted_loss.eval())
 
@@ -1341,13 +1385,20 @@ class ComputeWeightedLossTest(test.TestCase):
         with self.test_session(g):
           weighted_losses = weights * self._raw_losses
           weighted_sum = np.sum(weighted_losses)
-          if reduction == losses.Reduction.WEIGHTED_SUM:
+          if reduction == losses.Reduction.NONE:
+            self.assertAllClose(weighted_losses, weighted_loss.eval())
+          elif reduction == losses.Reduction.SUM:
             self.assertAllClose(weighted_sum, weighted_loss.eval())
-          else:  # losses.Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS
+          else:
             broadcast_weights = weights * np.ones_like(self._raw_losses)
-            self.assertAllClose(
-                weighted_sum / np.count_nonzero(broadcast_weights),
-                weighted_loss.eval())
+            if reduction == losses.Reduction.MEAN:
+              self.assertAllClose(
+                  weighted_sum / np.sum(broadcast_weights),
+                  weighted_loss.eval())
+            elif reduction == losses.Reduction.SUM_BY_NONZERO_WEIGHTS:
+              self.assertAllClose(
+                  weighted_sum / np.count_nonzero(broadcast_weights),
+                  weighted_loss.eval())
 
   def test1x1x1Weight(self):
     self._test_valid_weights((((17.0,),),))
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 45075d4d3c4..fc54553b0c3 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -27,21 +27,31 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import weights_broadcast_ops
 from tensorflow.python.ops.losses import util
+from tensorflow.python.platform import tf_logging as logging
 
 
-# TODO(ptucker): Per-example? Divided by batch_size? Divided by sum of weights?
 class Reduction(object):
   """Types of loss reduction."""
 
-  # Batch sum of weighted losses.
-  WEIGHTED_SUM = "weighted_sum"
+  # Un-reduced weighted losses with the same shape as input.
+  NONE = "none"
 
-  # `WEIGHTED_SUM` divided by number of non-zero weights.
-  WEIGHTED_SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
+  # Scalar sum of `NONE`.
+  SUM = "weighted_sum"
+
+  # Scalar `SUM` divided by sum of weights.
+  MEAN = "weighted_mean"
+
+  # Scalar `SUM` divided by number of non-zero weights.
+  SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"
 
   @classmethod
   def all(cls):
-    return (cls.WEIGHTED_SUM, cls.WEIGHTED_SUM_BY_NONZERO_WEIGHTS)
+    return (
+        cls.NONE,
+        cls.SUM,
+        cls.MEAN,
+        cls.SUM_BY_NONZERO_WEIGHTS)
 
   @classmethod
   def validate(cls, key):
@@ -127,7 +137,7 @@ def _num_present(losses, weights, per_batch=False):
 
 def compute_weighted_loss(
     losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Computes the weighted loss.
 
   Args:
@@ -140,7 +150,8 @@ def compute_weighted_loss(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss `Tensor` of the same type as `losses`. If `reduction` is
+    `NONE`, this has the same shape as `losses`; otherwise, it is scalar.
 
   Raises:
     ValueError: If `weights` is `None` or the shape is not compatible with
@@ -156,9 +167,16 @@ def compute_weighted_loss(
       losses = math_ops.to_float(losses)
       weights = math_ops.to_float(weights)
       weighted_losses = math_ops.multiply(losses, weights)
-      loss = math_ops.reduce_sum(weighted_losses)
-      if reduction == Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS:
-        loss = _safe_mean(loss, _num_present(losses, weights))
+      if reduction == Reduction.NONE:
+        loss = weighted_losses
+      else:
+        loss = math_ops.reduce_sum(weighted_losses)
+        if reduction == Reduction.MEAN:
+          loss = _safe_mean(
+              loss,
+              math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
+        elif reduction == Reduction.SUM_BY_NONZERO_WEIGHTS:
+          loss = _safe_mean(loss, _num_present(losses, weights))
 
       # Convert the result back to the input type.
       loss = math_ops.cast(loss, input_dtype)
@@ -169,7 +187,7 @@ def compute_weighted_loss(
 def absolute_difference(
     labels, predictions, weights=1.0, scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Adds an Absolute Difference loss to the training procedure.
 
   `weights` acts as a coefficient for the loss. If a scalar is provided, then
@@ -191,7 +209,8 @@ def absolute_difference(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shape of `predictions` doesn't match that of `labels` or
@@ -210,7 +229,7 @@ def absolute_difference(
 def cosine_distance(
     labels, predictions, dim=None, weights=1.0, scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Adds a cosine-distance loss to the training procedure.
 
   Note that the function assumes that `predictions` and `labels` are already
@@ -228,7 +247,8 @@ def cosine_distance(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If `predictions` shape doesn't match `labels` shape, or
@@ -250,7 +270,7 @@ def cosine_distance(
 
 def hinge_loss(labels, logits, weights=1.0, scope=None,
                loss_collection=ops.GraphKeys.LOSSES,
-               reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+               reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Adds a hinge loss to the training procedure.
 
   Args:
@@ -265,7 +285,8 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shapes of `logits` and `labels` don't match.
@@ -285,7 +306,7 @@ def hinge_loss(labels, logits, weights=1.0, scope=None,
 
 def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
                loss_collection=ops.GraphKeys.LOSSES,
-               reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+               reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Adds a Huber Loss term to the training procedure.
 
   For each value x in `error=labels-predictions`, the following is calculated:
@@ -320,7 +341,8 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shape of `predictions` doesn't match that of `labels` or
@@ -347,7 +369,7 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None,
 
 def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
              loss_collection=ops.GraphKeys.LOSSES,
-             reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+             reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Adds a Log Loss term to the training procedure.
 
   `weights` acts as a coefficient for the loss. If a scalar is provided, then
@@ -370,7 +392,8 @@ def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shape of `predictions` doesn't match that of `labels` or
@@ -474,7 +497,7 @@ def mean_pairwise_squared_error(
 def mean_squared_error(
     labels, predictions, weights=1.0, scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Adds a Sum-of-Squares loss to the training procedure.
 
   `weights` acts as a coefficient for the loss. If a scalar is provided, then
@@ -496,7 +519,8 @@ def mean_squared_error(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shape of `predictions` doesn't match that of `labels` or
@@ -515,7 +539,7 @@ def mean_squared_error(
 def sigmoid_cross_entropy(
     multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
 
   `weights` acts as a coefficient for the loss. If a scalar is provided,
@@ -531,7 +555,7 @@ def sigmoid_cross_entropy(
   Args:
     multi_class_labels: `[batch_size, num_classes]` target integer labels in
       `(0, 1)`.
-    logits: `[batch_size, num_classes]` logits outputs of the network.
+    logits: Float `[batch_size, num_classes]` logits outputs of the network.
     weights: Optional `Tensor` whose rank is either 0, or the same rank as
       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
       be either `1`, or the same as the corresponding `losses` dimension).
@@ -541,7 +565,8 @@ def sigmoid_cross_entropy(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
+    `NONE`, this has the same shape as `logits`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shape of `logits` doesn't match that of
@@ -551,7 +576,9 @@ def sigmoid_cross_entropy(
   with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
                       (logits, multi_class_labels, weights)) as scope:
     logits = ops.convert_to_tensor(logits)
+    logging.info("logits.dtype=%s.", logits.dtype)
     multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
+    logging.info("multi_class_labels.dtype=%s.", multi_class_labels.dtype)
     logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
 
     if label_smoothing > 0:
@@ -561,6 +588,7 @@ def sigmoid_cross_entropy(
     losses = nn.sigmoid_cross_entropy_with_logits(labels=multi_class_labels,
                                                   logits=logits,
                                                   name="xentropy")
+    logging.info("losses.dtype=%s.", losses.dtype)
     return compute_weighted_loss(
         losses, weights, scope, loss_collection, reduction=reduction)
 
@@ -568,7 +596,7 @@ def sigmoid_cross_entropy(
 def softmax_cross_entropy(
     onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
 
   `weights` acts as a coefficient for the loss. If a scalar is provided,
@@ -593,7 +621,8 @@ def softmax_cross_entropy(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
+    `NONE`, this has shape `[batch_size]`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
@@ -673,7 +702,7 @@ def _remove_squeezable_dimensions(
 def sparse_softmax_cross_entropy(
     labels, logits, weights=1.0, scope=None,
     loss_collection=ops.GraphKeys.LOSSES,
-    reduction=Reduction.WEIGHTED_SUM_BY_NONZERO_WEIGHTS):
+    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
   """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
 
   `weights` acts as a coefficient for the loss. If a scalar is provided,
@@ -696,7 +725,8 @@ def sparse_softmax_cross_entropy(
     reduction: Type of reduction to apply to loss.
 
   Returns:
-    A scalar `Tensor` that returns the weighted loss.
+    Weighted loss `Tensor` of the same type as `logits`. If `reduction` is
+    `NONE`, this has the same shape as `labels`; otherwise, it is scalar.
 
   Raises:
     ValueError: If the shapes of logits, labels, and weight are incompatible, or
diff --git a/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt b/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt
index f2fc1f003e7..4bdc73370bf 100644
--- a/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.losses.-reduction.pbtxt
@@ -3,11 +3,19 @@ tf_class {
   is_instance: "<class \'tensorflow.python.ops.losses.losses_impl.Reduction\'>"
   is_instance: "<type \'object\'>"
   member {
-    name: "WEIGHTED_SUM"
+    name: "MEAN"
     mtype: "<type \'str\'>"
   }
   member {
-    name: "WEIGHTED_SUM_BY_NONZERO_WEIGHTS"
+    name: "NONE"
+    mtype: "<type \'str\'>"
+  }
+  member {
+    name: "SUM"
+    mtype: "<type \'str\'>"
+  }
+  member {
+    name: "SUM_BY_NONZERO_WEIGHTS"
     mtype: "<type \'str\'>"
   }
   member_method {

From f942e8c21658e20cdba20b3bba025a71cbb5e60a Mon Sep 17 00:00:00 2001
From: Stephan Hoyer <shoyer@google.com>
Date: Thu, 20 Apr 2017 15:34:26 -0800
Subject: [PATCH 03/27] Fix typo in "First TensorFlow program" from open source
 README

This is sort of pedantic, but the Python prompt prints strings with quotes.
Change: 153772799
---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index cd0bffde796..3ab47736813 100644
--- a/README.md
+++ b/README.md
@@ -52,7 +52,7 @@ $ python
 >>> hello = tf.constant('Hello, TensorFlow!')
 >>> sess = tf.Session()
 >>> sess.run(hello)
-Hello, TensorFlow!
+'Hello, TensorFlow!'
 >>> a = tf.constant(10)
 >>> b = tf.constant(32)
 >>> sess.run(a+b)

From 7c068a3784d797ebdfb8f768ecb486479cbc4b93 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 20 Apr 2017 16:16:07 -0800
Subject: [PATCH 04/27] Fix _get_arguments to work with partials. Change:
 153776958

---
 tensorflow/python/estimator/estimator.py      |  45 ++++---
 tensorflow/python/estimator/estimator_test.py | 115 ++++++++++++++++++
 2 files changed, 142 insertions(+), 18 deletions(-)

diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 449cb54c841..a1b0d9358d4 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -524,7 +524,7 @@ class Estimator(object):
     Raises:
       ValueError: if model_fn returns invalid objects.
     """
-    model_fn_args = _get_arguments(self._model_fn).args
+    model_fn_args = _model_fn_args(self._model_fn)
     kwargs = {}
     if 'mode' in model_fn_args:
       kwargs['mode'] = mode
@@ -704,35 +704,44 @@ def _get_replica_device_setter(config):
     return None
 
 
-def _get_arguments(func):
-  """Returns a spec of given func."""
-  if hasattr(func, '__code__'):
-    # Regular function.
-    return inspect.getargspec(func)
-  elif hasattr(func, '__call__'):
-    # Callable object.
-    return _get_arguments(func.__call__)
-  elif hasattr(func, 'func'):
-    # Partial function.
-    return _get_arguments(func.func)
+def _model_fn_args(fn):
+  """Get argument names for function-like object.
+
+  Args:
+    fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+  Returns:
+    `tuple` of string argument names.
+
+  Raises:
+    ValueError: if partial function has positionally bound arguments
+  """
+  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
+    # Handle functools.partial and similar objects.
+    return tuple([
+        arg for arg in inspect.getargspec(fn.func).args[len(fn.args):]
+        if arg not in set(fn.keywords.keys())
+    ])
+  # Handle function.
+  return tuple(inspect.getargspec(fn).args)
 
 
 def _verify_model_fn_args(model_fn, params):
   """Verifies model fn arguments."""
-  fn_spec = _get_arguments(model_fn)
-  if 'features' not in fn_spec.args:
+  args = _model_fn_args(model_fn)
+  if 'features' not in args:
     raise ValueError('model_fn (%s) must include features argument.' % model_fn)
-  if 'labels' not in fn_spec.args:
+  if 'labels' not in args:
     raise ValueError('model_fn (%s) must include labels argument.' % model_fn)
-  if params is not None and 'params' not in fn_spec.args:
+  if params is not None and 'params' not in args:
     raise ValueError('model_fn (%s) does not include params argument, '
                      'but params (%s) is passed to Estimator.' % (model_fn,
                                                                   params))
-  if params is None and 'params' in fn_spec.args:
+  if params is None and 'params' in args:
     logging.warning('Estimator\'s model_fn (%s) includes params '
                     'argument, but params are not passed to Estimator.',
                     model_fn)
-  non_valid_args = list(set(fn_spec.args) - _VALID_MODEL_FN_ARGS)
+  non_valid_args = list(set(args) - _VALID_MODEL_FN_ARGS)
   if non_valid_args:
     raise ValueError('model_fn (%s) has following not expected args: %s' %
                      (model_fn, non_valid_args))
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 89a9483e201..3b46db59e30 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -18,10 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import os
 import tempfile
 
 import numpy as np
+import six
 
 from google.protobuf import text_format
 
@@ -38,6 +40,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.layers import layers
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import init_ops
@@ -262,8 +265,120 @@ def model_fn_global_step_incrementer(features, labels, mode):
       train_op=state_ops.assign_add(global_step, 1))
 
 
+def _estimator_spec(
+    expected_features, expected_labels, actual_features, actual_labels, mode):
+  assert_ops = tuple([
+      check_ops.assert_equal(
+          expected_features[k], actual_features[k], name='assert_%s' % k)
+      for k in expected_features
+  ] + [
+      check_ops.assert_equal(
+          expected_labels, actual_labels, name='assert_labels')
+  ])
+  with ops.control_dependencies(assert_ops):
+    return model_fn_lib.EstimatorSpec(
+        mode=mode,
+        predictions=constant_op.constant(0.),
+        loss=constant_op.constant(0.),
+        train_op=constant_op.constant(0.))
+
+
+def _make_input_fn(features, labels):
+  def _input_fn():
+    return {
+        k: constant_op.constant(v)
+        for k, v in six.iteritems(features)
+    }, constant_op.constant(labels)
+  return _input_fn
+
+
 class EstimatorTrainTest(test.TestCase):
 
+  def test_minimal_model_fn_args(self):
+    expected_features = {'x': 42., 'y': 43.}
+    expected_labels = 44.
+
+    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments
+    # doesn't work with mock fns.
+    model_fn_call_count = [0]
+
+    def _model_fn(features, labels):
+      model_fn_call_count[0] += 1
+      self.assertItemsEqual(expected_features.keys(), features.keys())
+      return _estimator_spec(
+          expected_features, expected_labels, features, labels,
+          model_fn_lib.ModeKeys.TRAIN)
+
+    with self.assertRaisesRegexp(ValueError, 'does not include params'):
+      estimator.Estimator(model_fn=_model_fn, params={'a': 'b'})
+    est = estimator.Estimator(model_fn=_model_fn, config=run_config.RunConfig())
+    self.assertEqual(0, model_fn_call_count[0])
+    est.train(
+        input_fn=_make_input_fn(expected_features, expected_labels), steps=1)
+    self.assertEqual(1, model_fn_call_count[0])
+
+  def test_all_model_fn_args(self):
+    expected_features = {'x': 42., 'y': 43.}
+    expected_labels = 44.
+    expected_params = {'some_param': 'some_value'}
+    expected_config = run_config.RunConfig()
+    expected_config.i_am_test = True
+
+    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments
+    # doesn't work with mock fns.
+    model_fn_call_count = [0]
+
+    # Note that args are all passed by keyword, so can be in any order.
+    def _model_fn(mode, params, features, labels, config):
+      model_fn_call_count[0] += 1
+      self.assertItemsEqual(expected_features.keys(), features.keys())
+      self.assertEqual(model_fn_lib.ModeKeys.TRAIN, mode)
+      self.assertEqual(expected_params, params)
+      self.assertTrue(config.i_am_test)
+      return _estimator_spec(
+          expected_features, expected_labels, features, labels, mode)
+
+    est = estimator.Estimator(
+        model_fn=_model_fn, params=expected_params, config=expected_config)
+    self.assertEqual(0, model_fn_call_count[0])
+    est.train(
+        input_fn=_make_input_fn(expected_features, expected_labels), steps=1)
+    self.assertEqual(1, model_fn_call_count[0])
+
+  def test_partial_model_fn_args(self):
+    expected_features = {'x': 42., 'y': 43.}
+    expected_labels = 44.
+    expected_params = {'some_param': 'some_value'}
+    expected_config = run_config.RunConfig()
+    expected_config.i_am_test = True
+    expected_foo = 45.
+    expected_bar = 46.
+
+    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments
+    # doesn't work with mock fns.
+    model_fn_call_count = [0]
+
+    def _model_fn(features, labels, foo, mode, params, config, bar):
+      model_fn_call_count[0] += 1
+      self.assertEqual(expected_foo, foo)
+      self.assertEqual(expected_bar, bar)
+      self.assertItemsEqual(expected_features.keys(), features.keys())
+      self.assertEqual(model_fn_lib.ModeKeys.TRAIN, mode)
+      self.assertEqual(expected_params, params)
+      self.assertTrue(config.i_am_test)
+      return _estimator_spec(
+          expected_features, expected_labels, features, labels, mode)
+    partial_model_fn = functools.partial(
+        _model_fn, foo=expected_foo, bar=expected_bar)
+
+    est = estimator.Estimator(
+        model_fn=partial_model_fn, params=expected_params,
+        config=expected_config)
+    self.assertEqual(0, model_fn_call_count[0])
+    est.train(
+        input_fn=_make_input_fn(expected_features, expected_labels), steps=1)
+    self.assertEqual(1, model_fn_call_count[0])
+
   def test_model_fn_must_return_estimator_spec(self):
 
     def model_fn(features, labels):

From 35d358f92dad1aaafd6afe4995db7b1db0e53d13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dandelion=20Man=C3=A9?= <dandelion@google.com>
Date: Thu, 20 Apr 2017 16:57:28 -0800
Subject: [PATCH 05/27] Fix serious bug which broke TensorBoard when there are
 many scalar runs and few text runs. Change: 153780710

---
 .../tf_dashboard_common/tf-multi-checkbox.html | 18 +-----------------
 1 file changed, 1 insertion(+), 17 deletions(-)

diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html
index ac407844f0b..e2c99772072 100644
--- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html
+++ b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html
@@ -183,6 +183,7 @@ handle these situations gracefully.
       // if undefined, default value (enable for first k runs, disable after).
         type: Object,
         value: TF.URIStorage.getObjectInitializer('runSelectionState', {}),
+        observer: "_storeRunToIsCheckedMapping",
       },
       // (Allows state to persist across regex filtering)
       outSelected: {
@@ -231,24 +232,7 @@ handle these situations gracefully.
     },
     observers: [
       "_setIsolatorIcon(runSelectionState, names)",
-      "_storeRunToIsCheckedMappingWithDefault(runSelectionState, namesMatchingRegex)",
     ],
-    _storeRunToIsCheckedMappingWithDefault() {
-      var runSelectionStateIsDefault = Object.keys(this.runSelectionState).length == 0;
-      if (runSelectionStateIsDefault || this.namesMatchingRegex == null) {
-        return;
-      }
-      var _this = this;
-      var allToggledOn = this.namesMatchingRegex
-              .every(function(n) {return _this.runSelectionState[n]});
-      var allToggledOff = this.namesMatchingRegex
-              .every(function(n) {return !_this.runSelectionState[n]});
-      var defaultOff = this.namesMatchingRegex.length > this.maxRunsToEnableByDefault;
-      if (defaultOff && allToggledOff || !defaultOff && allToggledOn) {
-        this.runSelectionState = {};
-      }
-      this._storeRunToIsCheckedMapping(this.runSelectionState);
-    },
     _storeRunToIsCheckedMapping: TF.URIStorage.getObjectObserver('runSelectionState', {}),
     _makeRegex: function(regex) {
       try {

From 7dba5ab8740064f916bfc127c910c02a01eabe11 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dandelion=20Man=C3=A9?= <dandelion@google.com>
Date: Thu, 20 Apr 2017 16:58:00 -0800
Subject: [PATCH 06/27] Autogenerated Change: Change TensorBoard TAG to 54
 Change: 153780747

---
 tensorflow/tensorboard/TAG | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG
index 59343b09ec7..fb1e7bc8699 100644
--- a/tensorflow/tensorboard/TAG
+++ b/tensorflow/tensorboard/TAG
@@ -1 +1 @@
-53
+54

From c1bd0fe248c63b58b0b663a8c8529791354fdf75 Mon Sep 17 00:00:00 2001
From: Bjarke Hammersholt Roune <broune@google.com>
Date: Thu, 20 Apr 2017 17:20:10 -0800
Subject: [PATCH 07/27] - Bug fix in ShapeInference  - CommaSeparatedString and
 VectorString added to xla_util.h  - ReferenceUtil can now do more general Pad
 ops. Change: 153782516

---
 tensorflow/compiler/xla/reference_util.cc     | 35 ++++++++++++++++
 tensorflow/compiler/xla/reference_util.h      |  5 +++
 .../compiler/xla/service/shape_inference.cc   |  6 ++-
 tensorflow/compiler/xla/util.h                | 41 +++++++++++++++++++
 tensorflow/compiler/xla/util_test.cc          | 20 +++++++++
 5 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 86c9c3b1ac3..5630033ac89 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -649,4 +649,39 @@ ReferenceUtil::ReduceToRowArray2D(
   return result;
 }
 
+/* static */ Array4D<float> ReferenceUtil::PadArray4D(
+    const Array4D<float>& operand, const PaddingConfig& padding,
+    const float pad) {
+  CHECK_EQ(padding.dimensions_size(), 4);
+
+  const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
+                                           operand.n3(), operand.n4()};
+  std::vector<int64> pad_low(4);
+  std::vector<int64> pad_high(4);
+  std::vector<int64> output_bounds(4);
+  for (int64 i = 0; i < 4; ++i) {
+    pad_low[i] = padding.dimensions(i).edge_padding_low();
+    pad_high[i] = padding.dimensions(i).edge_padding_high();
+    CHECK_EQ(padding.dimensions(i).interior_padding(), 0) << "not implemented";
+
+    output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i];
+  }
+
+  Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
+                        output_bounds[3]);
+  result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+    for (int i = 0; i < 4; ++i) {
+      bool in_low_padding = indices[i] < pad_low[i];
+      bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
+      if (in_low_padding || in_high_padding) {
+        *value = pad;
+        return;
+      }
+    }
+    *value = operand(indices[0] - pad_low[0], indices[1] - pad_low[1],
+                     indices[2] - pad_low[2], indices[3] - pad_low[3]);
+  });
+  return result;
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 9e0f2472038..eb1eea7fc4c 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -395,6 +395,11 @@ class ReferenceUtil {
       const Array2D<float>& operand, const PaddingConfig& padding,
       const float pad);
 
+  // Returns the result of a 4D pad on an input array.
+  static Array4D<float> PadArray4D(const Array4D<float>& operand,
+                                   const PaddingConfig& padding,
+                                   const float pad);
+
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
 };
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 9472086e2b4..338d63f1a00 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -309,6 +309,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
     return InvalidArgument(
         "the rank of the operand and the padding configuration do not match.");
   }
+  if (operand_shape.element_type() != padding_value_shape.element_type()) {
+    return InvalidArgument(
+        "the element types of the operands to pad do not match");
+  }
   std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
     dimensions[i] = operand_shape.dimensions(i) +
@@ -338,7 +342,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 
   // Check if both element types are the same.
   if (lhs.element_type() != rhs.element_type()) {
-    return fail("element types mismatch");
+    return fail("element types do not match");
   }
 
   if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
index 8ec4f1b528d..32b5fbba003 100644
--- a/tensorflow/compiler/xla/util.h
+++ b/tensorflow/compiler/xla/util.h
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/math/math_util.h"
 #include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -200,6 +201,46 @@ int64 PositionInContainer(const Container& container, int64 value) {
                        std::find(container.begin(), container.end(), value));
 }
 
+// Formats the container as a comma-separated string. StrAppend must support
+// appending the elements of the container. Prefix is prepended and suffix is
+// appended to the returned string.
+template <typename Container>
+string CommaSeparatedString(const Container& c, const char* prefix = "",
+                            const char* suffix = "") {
+  // Not using Join() since the implementation here is simple anyway and this
+  // avoids copying the string to append prefix.
+  string comma_separated = prefix;
+  const char* separator = "";
+  for (const auto& entry : c) {
+    tensorflow::strings::StrAppend(&comma_separated, separator, entry);
+    separator = ", ";
+  }
+  comma_separated += suffix;
+  return comma_separated;
+}
+
+// Overload needed to allow the container to be an initializer list. The default
+// type for T makes an empty initializer list work as well.
+template <typename T = int>
+string CommaSeparatedString(const std::initializer_list<T>& c,
+                            const char* prefix = "", const char* suffix = "") {
+  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
+}
+
+// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
+// 7). StrAppend must support appending the elements of c.
+template <typename Container>
+string VectorString(const Container& c) {
+  return CommaSeparatedString(c, "(", ")");
+}
+
+// Overload needed to allow the container to be an initializer list. The default
+// type for T makes an empty initializer list work as well.
+template <typename T = int>
+string VectorString(const std::initializer_list<T>& c) {
+  return VectorString<std::initializer_list<T>>(c);
+}
+
 // Returns a PaddingConfig object that represents no padding for the given rank.
 PaddingConfig MakeNoPaddingConfig(int64 rank);
 
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
index a81014f3b7a..547b924180b 100644
--- a/tensorflow/compiler/xla/util_test.cc
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -80,6 +80,26 @@ TEST(UtilTest, HumanReadableNumFlopsExample) {
   ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
 }
 
+TEST(UtilTest, CommaSeparatedString) {
+  EXPECT_EQ(CommaSeparatedString({}), "");
+  EXPECT_EQ(CommaSeparatedString({"hello world"}), "hello world");
+  EXPECT_EQ(CommaSeparatedString({1, 57, 2}, "foo", "bar"), "foo1, 57, 2bar");
+}
+
+TEST(UtilTest, VectorString) {
+  std::list<int64> empty_list;
+  EXPECT_EQ(VectorString(empty_list), "()");
+
+  std::vector<float> float_vector = {5.5};
+  EXPECT_EQ(VectorString(float_vector), "(5.5)");
+
+  std::set<const char*> string_set = {"a", "b"};
+  EXPECT_EQ(VectorString(string_set), "(a, b)");
+
+  EXPECT_EQ(VectorString({}), "()");
+  EXPECT_EQ(VectorString({1, 57, 2}), "(1, 57, 2)");
+}
+
 TEST(UtilTest, LogLines) {
   // Just make sure this code runs (not verifying the output).
   LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);

From b0594e1b82180efe5b1d0558b4410137f3974b93 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 20 Apr 2017 21:32:47 -0800
Subject: [PATCH 08/27] [XLA] Fixes some div-by-zero bugs. Change: 153795265

---
 .../xla/service/hlo_execution_profile.cc      | 53 +++++++++++--------
 1 file changed, 31 insertions(+), 22 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index 447892c8dec..9e25f1aceb1 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -70,6 +70,7 @@ string HloExecutionProfile::ToString(
   string result;
   const int64 total_cycles = total_cycles_executed(computation);
   double clock_rate_ghz = device_description.clock_rate_ghz();
+  CHECK_GE(clock_rate_ghz, 1e-9);
 
   const auto cycles_to_microseconds = [&](double cycles) {
     return cycles / clock_rate_ghz / 1000.0;
@@ -80,14 +81,19 @@ string HloExecutionProfile::ToString(
     double nsecs = cycles / clock_rate_ghz;
     string bytes_per_sec;
     string bytes_per_cycle;
-    if (bytes_accessed >= 0) {
+    if (cycles <= 0 || bytes_accessed < 0) {
+      bytes_per_sec = "<unknown>";
+      bytes_per_cycle = "<unknown>";
+    } else {
       bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
           bytes_accessed / (nsecs / 1e9));
       bytes_per_cycle =
           tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
-    } else {
-      bytes_per_sec = "<unknown>";
-      bytes_per_cycle = "<unknown>";
+    }
+
+    double cycles_percent = 0;
+    if (total_cycles > 0) {
+      cycles_percent = cycles / static_cast<double>(total_cycles) * 100;
     }
 
     tensorflow::strings::StrAppend(
@@ -97,8 +103,7 @@ string HloExecutionProfile::ToString(
             ":: "
             "%12s/cycle :: "
             "%s",
-            cycles, cycles / static_cast<double>(total_cycles) * 100,
-            cycles_to_microseconds(cycles),
+            cycles, cycles_percent, cycles_to_microseconds(cycles),
             flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
             bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
   };
@@ -114,26 +119,30 @@ string HloExecutionProfile::ToString(
   for (const auto& item : items) {
     const HloInstruction* hlo = item.first;
     tensorflow::strings::StrAppend(&result, "\n\t");
-    int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo);
-    int64 bytes_accessed =
-        hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo);
-    string display = hlo == nullptr ? "<none>" : hlo->ToString();
+    const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo);
+    const int64 bytes_accessed =
+        (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo);
+    const string display = (hlo == nullptr) ? "<none>" : hlo->ToString();
     append_item(item.second, flops, bytes_accessed, display);
   }
 
-  MetricTableReport table;
-  table.SetMetricName("microseconds");
-  table.SetEntryName("ops");
-  table.SetShowCategoryTable();
-  for (const auto& item : items) {
-    MetricTableReport::Entry entry;
-    entry.text = item.first->ToString();
-    entry.short_text = item.first->ToString(/*compact_operands=*/true);
-    entry.category_text = item.first->ToCategory();
-    entry.metric = cycles_to_microseconds(item.second);
-    table.AddEntry(std::move(entry));
+  if (total_cycles <= 0) {
+    result += "****** 0 total cycles ******\n";
+  } else {
+    MetricTableReport table;
+    table.SetMetricName("microseconds");
+    table.SetEntryName("ops");
+    table.SetShowCategoryTable();
+    for (const auto& item : items) {
+      MetricTableReport::Entry entry;
+      entry.text = item.first->ToString();
+      entry.short_text = item.first->ToString(/*compact_operands=*/true);
+      entry.category_text = item.first->ToCategory();
+      entry.metric = cycles_to_microseconds(item.second);
+      table.AddEntry(std::move(entry));
+    }
+    result += table.MakeReport(cycles_to_microseconds(total_cycles));
   }
-  result += table.MakeReport(cycles_to_microseconds(total_cycles));
 
   return result;
 }

From 858e0afcc45c39b6428bf82ab1444323e925cfd8 Mon Sep 17 00:00:00 2001
From: Derek Murray <mrry@google.com>
Date: Thu, 20 Apr 2017 22:22:29 -0800
Subject: [PATCH 09/27] Switch DirectSession to use _Arg and _Retval ops for
 feeding and fetching.

This change reduces the overhead imposed by string processing and
rendezvous invocation in the DirectSession::Run() call by 1--2 microseconds
per value fed or fetched.

RELNOTES: Improved DirectSession::Run() overhead and error checking. Feeding a value of the wrong type will now synchronously raise an INVALID_ARGUMENT error instead of asynchronously raising an INTERNAL error. Code that depends on the (undefined) behavior when feeding a tensor of the wrong type may need to be updated.
Change: 153797943
---
 tensorflow/core/BUILD                         |   1 +
 .../core/common_runtime/build_graph_options.h |   5 +
 .../core/common_runtime/direct_session.cc     | 144 ++++++++++++++----
 .../core/common_runtime/direct_session.h      |  22 ++-
 .../core/common_runtime/graph_runner.cc       |   4 +-
 .../resource_variable_read_optimizer.cc       |   9 +-
 .../simple_graph_execution_state.cc           |  20 ++-
 .../simple_graph_execution_state.h            |  20 ++-
 tensorflow/core/framework/function.cc         |  15 +-
 tensorflow/core/framework/function.h          |   1 +
 tensorflow/core/graph/subgraph.cc             | 111 +++++++++-----
 tensorflow/core/graph/subgraph.h              |  15 +-
 tensorflow/core/graph/subgraph_test.cc        |  96 ++++++++++--
 tensorflow/python/debug/lib/debug_data.py     |   2 +-
 .../kernel_tests/control_flow_ops_py_test.py  |   2 +-
 .../graph_transforms/fold_constants_lib.cc    |   3 +-
 16 files changed, 370 insertions(+), 100 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 1b78b25ff51..d6143493877 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1563,6 +1563,7 @@ tf_cuda_library(
         ":lib_internal",
         ":proto_text",
         ":protos_all_cc",
+        "//tensorflow/core/kernels:function_ops",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index c6d4bdad9c1..49566c8fa8f 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -30,6 +30,11 @@ struct BuildGraphOptions {
   // the former via "ref" fetch_endpoints.
   std::vector<string> target_nodes;
 
+  // If `true`, uses Arg/Retval to implement feeds/fetches; otherwise
+  // uses Recv/Send to implement feeds/fetches.
+  // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
+  bool use_function_convention = false;
+
   string DebugString() const;
 };
 
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index c05cceced11..002e246b80d 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -361,7 +361,6 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
   return Status::OK();
 }
 
-// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()".
 Status DirectSession::Run(const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           const std::vector<string>& target_nodes,
@@ -426,13 +425,34 @@ Status DirectSession::Run(const RunOptions& run_options,
         executor_step_count, input_tensor_names, output_names, target_nodes));
   }
 
+  // Configure a call frame for the step, which we use to feed and
+  // fetch values to and from the executors.
+  FunctionCallFrame call_frame(executors_and_keys->input_types,
+                               executors_and_keys->output_types);
+  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
+  for (const auto& it : inputs) {
+    if (it.second.dtype() == DT_RESOURCE) {
+      Tensor tensor_from_handle;
+      TF_RETURN_IF_ERROR(
+          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
+      feed_args[executors_and_keys->input_name_to_index[it.first]] =
+          tensor_from_handle;
+    } else {
+      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
+    }
+  }
+  Status s = call_frame.SetArgs(feed_args);
+  if (errors::IsInternal(s)) {
+    return errors::InvalidArgument(s.error_message());
+  } else if (!s.ok()) {
+    return s;
+  }
+
   // Create a run state and start execution.
   RunState run_state(args.step_id, &devices_);
   run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
   CancellationManager step_cancellation_manager;
-
-  // Send inputs.
-  TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
+  args.call_frame = &call_frame;
 
   // Start parallel Executors.
   const size_t num_executors = executors_and_keys->items.size();
@@ -535,8 +555,22 @@ Status DirectSession::Run(const RunOptions& run_options,
   }
 
   // Receive outputs.
-  TF_RETURN_IF_ERROR(
-      RecvOutputs(output_names, executors_and_keys, &run_state, outputs));
+  if (outputs) {
+    std::vector<Tensor> sorted_outputs;
+    Status s = call_frame.ConsumeRetvals(&sorted_outputs);
+    if (errors::IsInternal(s)) {
+      return errors::InvalidArgument(s.error_message());
+    } else if (!s.ok()) {
+      return s;
+    }
+    outputs->clear();
+    outputs->reserve(sorted_outputs.size());
+    for (const string& output_name : output_names) {
+      outputs->emplace_back(
+          std::move(sorted_outputs[executors_and_keys
+                                       ->output_name_to_index[output_name]]));
+    }
+  }
 
   // Save the output tensors of this run we choose to keep.
   TF_RETURN_IF_ERROR(
@@ -706,11 +740,11 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
       CheckFetch(inputs, output_names, executors_and_keys, run_state));
 
   // Send inputs.
-  Status s = SendInputs(inputs, executors_and_keys, run_state->rendez);
+  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
 
   // Receive outputs.
   if (s.ok()) {
-    s = RecvOutputs(output_names, executors_and_keys, run_state, outputs);
+    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
   }
 
   // Save the output tensors of this run we choose to keep.
@@ -770,16 +804,17 @@ Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
   }
 }
 
-Status DirectSession::SendInputs(const NamedTensorList& inputs,
-                                 const ExecutorsAndKeys* executors_and_keys,
-                                 IntraProcessRendezvous* rendez) {
+Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
+                                     const ExecutorsAndKeys* executors_and_keys,
+                                     IntraProcessRendezvous* rendez) {
   Status s;
   Rendezvous::ParsedKey parsed;
   // Insert the input tensors into the local rendezvous by their
   // rendezvous key.
   for (const auto& input : inputs) {
-    auto it = executors_and_keys->input_keys.find(input.first);
-    if (it == executors_and_keys->input_keys.end()) {
+    auto it =
+        executors_and_keys->input_name_to_rendezvous_key.find(input.first);
+    if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
       return errors::Internal("'", input.first, "' is not a pre-defined feed.");
     }
     const string& input_key = it->second;
@@ -808,10 +843,10 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
   return Status::OK();
 }
 
-Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
-                                  const ExecutorsAndKeys* executors_and_keys,
-                                  RunState* run_state,
-                                  std::vector<Tensor>* outputs) {
+Status DirectSession::RecvPRunOutputs(
+    const std::vector<string>& output_names,
+    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
+    std::vector<Tensor>* outputs) {
   Status s;
   if (!output_names.empty()) {
     outputs->resize(output_names.size());
@@ -822,8 +857,9 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
   for (size_t output_offset = 0; output_offset < output_names.size();
        ++output_offset) {
     const string& output_name = output_names[output_offset];
-    auto it = executors_and_keys->output_keys.find(output_name);
-    if (it == executors_and_keys->output_keys.end()) {
+    auto it =
+        executors_and_keys->output_name_to_rendezvous_key.find(output_name);
+    if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
       return errors::Internal("'", output_name,
                               "' is not a pre-defined fetch.");
     }
@@ -987,14 +1023,16 @@ Status DirectSession::GetOrCreateExecutors(
   options.feed_endpoints = inputs_sorted;
   options.fetch_endpoints = outputs_sorted;
   options.target_nodes = tn_sorted;
+  options.use_function_convention = !run_state_args->is_partial_run;
 
   std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
 
   // The executor_lock_ is intentionally released while executor is
   // being created.
   std::unordered_map<string, std::unique_ptr<Graph>> graphs;
-  TF_RETURN_IF_ERROR(
-      CreateGraphs(options, &graphs, &ek->flib_def, run_state_args));
+  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
+                                  run_state_args, &ek->input_types,
+                                  &ek->output_types));
 
   if (run_state_args->is_partial_run) {
     ek->graph = std::move(run_state_args->graph);
@@ -1079,17 +1117,37 @@ Status DirectSession::GetOrCreateExecutors(
     item->executor.reset(executor);
   }
 
-  // Compute the rendezvous keys to avoid recomputing them every time.
-  //
-  // We always use the first device as the device name portion of the
-  // key, even if we're feeding another graph.
-  for (const string& input : inputs) {
-    ek->input_keys[input] = GetRendezvousKey(
-        input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
-  }
-  for (const string& output : outputs) {
-    ek->output_keys[output] = GetRendezvousKey(
-        output, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
+  // Cache the mapping from input/output names to graph elements to
+  // avoid recomputing it every time.
+  if (!run_state_args->is_partial_run) {
+    // For regular `Run()`, we use the function calling convention, and so
+    // maintain a mapping from input/output names to
+    // argument/return-value ordinal index.
+    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
+      const string& input = inputs_sorted[i];
+      ek->input_name_to_index[input] = i;
+    }
+    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
+      const string& output = outputs_sorted[i];
+      ek->output_name_to_index[output] = i;
+    }
+  } else {
+    // For `PRun()`, we use the rendezvous calling convention, and so
+    // maintain a mapping from input/output names to rendezvous keys.
+    //
+    // We always use the first device as the device name portion of the
+    // key, even if we're feeding another graph.
+    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
+      const string& input = inputs_sorted[i];
+      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
+          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
+    }
+    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
+      const string& output = outputs_sorted[i];
+      ek->output_name_to_rendezvous_key[output] =
+          GetRendezvousKey(output, device_set_.client_device()->attributes(),
+                           FrameAndIter(0, 0));
+    }
   }
 
   // Reacquire the lock, try to insert into the map.
@@ -1110,7 +1168,8 @@ Status DirectSession::CreateGraphs(
     const BuildGraphOptions& subgraph_options,
     std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
     std::unique_ptr<FunctionLibraryDefinition>* flib_def,
-    RunStateArgs* run_state_args) {
+    RunStateArgs* run_state_args, DataTypeVector* input_types,
+    DataTypeVector* output_types) {
   mutex_lock l(graph_def_lock_);
   std::unique_ptr<SimpleClientGraph> client_graph;
 
@@ -1135,6 +1194,23 @@ Status DirectSession::CreateGraphs(
         execution_state->BuildGraph(subgraph_options, &client_graph));
   }
 
+  if (subgraph_options.feed_endpoints.size() !=
+      client_graph->feed_types.size()) {
+    return errors::Internal(
+        "Graph pruning failed: requested number of feed endpoints = ",
+        subgraph_options.feed_endpoints.size(),
+        " versus number of pruned feed endpoints = ",
+        client_graph->feed_types.size());
+  }
+  if (subgraph_options.fetch_endpoints.size() !=
+      client_graph->fetch_types.size()) {
+    return errors::Internal(
+        "Graph pruning failed: requested number of fetch endpoints = ",
+        subgraph_options.fetch_endpoints.size(),
+        " versus number of pruned fetch endpoints = ",
+        client_graph->fetch_types.size());
+  }
+
   auto current_stateful_placements = execution_state->GetStatefulPlacements();
   // Update our current state based on the execution_state's
   // placements.  If there are any mismatches for a node,
@@ -1240,6 +1316,8 @@ Status DirectSession::CreateGraphs(
     }
   }
   *flib_def = std::move(client_graph->flib_def);
+  std::swap(*input_types, client_graph->feed_types);
+  std::swap(*output_types, client_graph->fetch_types);
   return s;
 }
 
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index b9d22ac522c..848ef3bc62d 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -132,8 +132,13 @@ class DirectSession : public Session {
     NameNodeMap name_to_node;
     std::unique_ptr<FunctionLibraryDefinition> flib_def;
     std::vector<PerPartitionExecutorsAndLib> items;
-    std::unordered_map<string, string> input_keys;
-    std::unordered_map<string, string> output_keys;
+    std::unordered_map<string, size_t> input_name_to_index;
+    std::unordered_map<string, string> input_name_to_rendezvous_key;
+    std::unordered_map<string, size_t> output_name_to_index;
+    std::unordered_map<string, string> output_name_to_rendezvous_key;
+
+    DataTypeVector input_types;
+    DataTypeVector output_types;
   };
 
   // For each live partial execution, the session maintains a RunState.
@@ -187,7 +192,8 @@ class DirectSession : public Session {
       const BuildGraphOptions& options,
       std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
       std::unique_ptr<FunctionLibraryDefinition>* flib_def,
-      RunStateArgs* run_state_args);
+      RunStateArgs* run_state_args, DataTypeVector* input_types,
+      DataTypeVector* output_types);
 
   ::tensorflow::Status ExtendLocked(const GraphDef& graph)
       EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
@@ -196,17 +202,17 @@ class DirectSession : public Session {
       const Tensor& resource_tensor, Tensor* retrieved_tensor);
 
   // Feeds more inputs to the executors, triggering further execution.
-  ::tensorflow::Status SendInputs(
+  ::tensorflow::Status SendPRunInputs(
       const std::vector<std::pair<string, Tensor>>& inputs,
       const ExecutorsAndKeys* executors_and_keys,
       IntraProcessRendezvous* rendez);
 
   // Fetches more outputs from the executors. It waits until the output
   // tensors are computed.
-  ::tensorflow::Status RecvOutputs(const std::vector<string>& output_names,
-                                   const ExecutorsAndKeys* executors_and_keys,
-                                   RunState* run_state,
-                                   std::vector<Tensor>* outputs);
+  ::tensorflow::Status RecvPRunOutputs(
+      const std::vector<string>& output_names,
+      const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
+      std::vector<Tensor>* outputs);
 
   // Check if the specified fetches can be computed from the feeds
   // that we have already provided.
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 514a63590b1..a85fbbf88ff 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -130,9 +130,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
   }
 
   // Call RewriteGraphForExecution
+  subgraph::RewriteGraphMetadata metadata;
   TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
       graph_to_run.get(), input_names, output_names, {} /* target nodes */,
-      cpu_device_->attributes()));
+      cpu_device_->attributes(), false /* use_function_convention */,
+      &metadata));
 
   // Create the local executor and the Rendezvous for fetching back the
   // constants.
diff --git a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
index 85a29e11e23..c179e94c36b 100644
--- a/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
+++ b/tensorflow/core/common_runtime/resource_variable_read_optimizer.cc
@@ -21,9 +21,9 @@ limitations under the License.
 namespace tensorflow {
 namespace {
 
-// Replaces ReadVariableOp nodes which are only used by Sends and sinks with
-// _UnsafeReadVariable nodes, as this transforamtion is safe and will improve
-// performance.
+// Replaces ReadVariableOp nodes which are only used by Sends, sinks,
+// and function Retvals with _UnsafeReadVariable nodes, as this
+// transformation is safe and will improve performance.
 class ResourceVariableReadPass : public GraphOptimizationPass {
  public:
   Status Run(const GraphOptimizationPassOptions& options) override {
@@ -43,7 +43,8 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
       if (n->type_string() == "ReadVariableOp") {
         bool skip = false;
         for (const Edge* e : n->out_edges()) {
-          if (!e->dst()->IsSend() && e->dst()->name() != "_SINK") {
+          if (!e->dst()->IsSend() && e->dst()->type_string() != "_Retval" &&
+              e->dst()->name() != "_SINK") {
             skip = true;
           }
         }
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.cc b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
index c2ac15b345d..31e63a9ef75 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.cc
@@ -284,9 +284,11 @@ Status SimpleGraphExecutionState::InitBaseGraph(
   if (session_options_ &&
       session_options_->config.graph_options().place_pruned_graph()) {
     // Rewrite the graph before placement.
+    rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
     TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
         new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
-        options.target_nodes, device_set_->client_device()->attributes()));
+        options.target_nodes, device_set_->client_device()->attributes(),
+        options.use_function_convention, rewrite_metadata_.get()));
   }
 
   // Save stateful placements before placing.
@@ -333,15 +335,26 @@ Status SimpleGraphExecutionState::BuildGraph(
   std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
   CopyGraph(*graph_, ng.get());
 
+  subgraph::RewriteGraphMetadata rewrite_metadata;
   if (session_options_ == nullptr ||
       !session_options_->config.graph_options().place_pruned_graph()) {
     // Extract the subset of the graph that needs to be run, adding feed/fetch
     // ops as needed.
     TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
         ng.get(), options.feed_endpoints, options.fetch_endpoints,
-        options.target_nodes, device_set_->client_device()->attributes()));
+        options.target_nodes, device_set_->client_device()->attributes(),
+        options.use_function_convention, &rewrite_metadata));
+  } else {
+    // This SimpleGraphExecutionState represents a graph that was
+    // pruned when this was constructed, so we copy the metadata from
+    // a member variable.
+    CHECK(rewrite_metadata_);
+    rewrite_metadata = *rewrite_metadata_;
   }
 
+  CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
+  CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());
+
   // Make a fresh copy of the function library for the client graph.
   std::unique_ptr<FunctionLibraryDefinition> flib(
       new FunctionLibraryDefinition(*flib_def_));
@@ -363,7 +376,8 @@ Status SimpleGraphExecutionState::BuildGraph(
   // since the local CostModel used to record its stats is sized by
   // the largest node id.
   std::unique_ptr<SimpleClientGraph> dense_copy(
-      new SimpleClientGraph(std::move(flib)));
+      new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
+                            rewrite_metadata.fetch_types));
   CopyGraph(*ng, &dense_copy->graph);
 
   // TODO(vrv): We should check invariants of the graph here.
diff --git a/tensorflow/core/common_runtime/simple_graph_execution_state.h b/tensorflow/core/common_runtime/simple_graph_execution_state.h
index 3b6ce23c754..00b5509fd78 100644
--- a/tensorflow/core/common_runtime/simple_graph_execution_state.h
+++ b/tensorflow/core/common_runtime/simple_graph_execution_state.h
@@ -39,6 +39,10 @@ struct SessionOptions;
 class StepStats;
 class Timeline;
 
+namespace subgraph {
+struct RewriteGraphMetadata;
+}
+
 struct SimpleGraphExecutionStateOptions {
   const DeviceSet* device_set = nullptr;
   const SessionOptions* session_options = nullptr;
@@ -50,13 +54,19 @@ struct SimpleGraphExecutionStateOptions {
 // A SimpleClientGraph is simply a sub-graph of the full graph as induced by
 // BuildGraphOptions.
 struct SimpleClientGraph {
-  explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib)
-      : flib_def(std::move(flib)), graph(flib_def.get()) {}
+  explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
+                             DataTypeVector feed_types,
+                             DataTypeVector fetch_types)
+      : flib_def(std::move(flib)),
+        graph(flib_def.get()),
+        feed_types(std::move(feed_types)),
+        fetch_types(std::move(fetch_types)) {}
   // Each client-graph gets its own function library since optimization passes
   // post rewrite for execution might want to introduce new functions.
   std::unique_ptr<FunctionLibraryDefinition> flib_def;
   Graph graph;
-  int32 placement_version;
+  DataTypeVector feed_types;
+  DataTypeVector fetch_types;
 };
 
 // SimpleGraphExecutionState is responsible for generating an
@@ -190,6 +200,10 @@ class SimpleGraphExecutionState {
   // and may be updated by a graph optimization pass.
   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
 
+  // `rewrite_metadata_` is only set for SimpleGraphExecutionState
+  // objects created by `MakeForPrunedGraph()`.
+  std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
+
   // The dataflow graph owned by this object.
   Graph* graph_;
 
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index edb52737d94..8a7d96c38a9 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -789,7 +789,7 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
   rets->clear();
   rets->reserve(rets_.size());
   for (size_t i = 0; i < rets_.size(); ++i) {
-    auto item = rets_[i];
+    const auto& item = rets_[i];
     if (item.has_val) {
       rets->push_back(item.val);
     } else {
@@ -799,6 +799,19 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
   return Status::OK();
 }
 
+Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
+  rets->clear();
+  rets->reserve(rets_.size());
+  for (size_t i = 0; i < rets_.size(); ++i) {
+    if (rets_[i].has_val) {
+      rets->emplace_back(std::move(rets_[i].val));
+    } else {
+      return errors::Internal("Retval[", i, "] does not have value");
+    }
+  }
+  return Status::OK();
+}
+
 Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
   if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
     return errors::InvalidArgument("GetArg ", index, " is not within [0, ",
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 63c868ac9b8..210e5b949a5 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -259,6 +259,7 @@ class FunctionCallFrame {
   // Caller methods.
   Status SetArgs(gtl::ArraySlice<Tensor> args);
   Status GetRetvals(std::vector<Tensor>* rets) const;
+  Status ConsumeRetvals(std::vector<Tensor>* rets);
 
   // Callee methods.
   Status GetArg(int index, Tensor* val) const;
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index 91292500e1e..9849d9a1596 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -55,8 +55,13 @@ namespace {
 // state).
 static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
                          const gtl::ArraySlice<string>& fed_outputs,
-                         subgraph::NameIndex* name_index) {
-  for (const string& t : fed_outputs) {
+                         bool use_function_convention,
+                         subgraph::NameIndex* name_index,
+                         DataTypeVector* out_feed_types) {
+  out_feed_types->clear();
+  out_feed_types->reserve(fed_outputs.size());
+  for (size_t i = 0; i < fed_outputs.size(); ++i) {
+    const string& t = fed_outputs[i];
     TensorId id(ParseTensorName(t));
 
     auto iter = name_index->find(id.first);
@@ -71,17 +76,31 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
     }
 
     Node* recv_node;
-    TF_RETURN_IF_ERROR(
-        NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
-                    "_Recv")
-            .Attr("tensor_type", BaseType(n->output_type(id.second)))
-            .Attr("tensor_name", t)
-            .Attr("send_device", device_info.name())
-            .Attr("recv_device", device_info.name())
-            .Attr("send_device_incarnation",
-                  static_cast<int64>(device_info.incarnation()))
-            .Attr("client_terminated", true)
-            .Finalize(g, &recv_node));
+
+    if (!use_function_convention) {
+      TF_RETURN_IF_ERROR(
+          NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
+                      "_Recv")
+              .Attr("tensor_type", BaseType(n->output_type(id.second)))
+              .Attr("tensor_name", t)
+              .Attr("send_device", device_info.name())
+              .Attr("recv_device", device_info.name())
+              .Attr("send_device_incarnation",
+                    static_cast<int64>(device_info.incarnation()))
+              .Attr("client_terminated", true)
+              .Finalize(g, &recv_node));
+    } else {
+      // NOTE(mrry): We must include the index as part of the node
+      // name, because _Arg is a "stateful" kernel and therefore
+      // its name must uniquely identify a kernel instance across all
+      // graphs in the same session.
+      TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_",
+                                                     id.second, "_", i),
+                                     "_Arg")
+                             .Attr("T", BaseType(n->output_type(id.second)))
+                             .Attr("index", static_cast<int32>(i))
+                             .Finalize(g, &recv_node));
+    }
     recv_node->set_assigned_device_name(device_info.name());
 
     // Copy the _output_shapes from the original node to the feed node,
@@ -130,6 +149,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
       }
       g->RemoveEdge(e);
     }
+    out_feed_types->push_back(BaseType(n->output_type(id.second)));
   }
   return Status::OK();
 }
@@ -181,9 +201,14 @@ namespace subgraph {
 
 Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
                     const gtl::ArraySlice<string>& fetch_outputs,
-                    NameIndex* name_index, std::vector<Node*>* fetch_nodes) {
-  fetch_nodes->clear();
-  for (const string& t : fetch_outputs) {
+                    bool use_function_convention, NameIndex* name_index,
+                    std::vector<Node*>* out_fetch_nodes,
+                    DataTypeVector* out_fetch_types) {
+  out_fetch_nodes->clear();
+  out_fetch_nodes->reserve(fetch_outputs.size());
+  for (size_t i = 0; i < fetch_outputs.size(); ++i) {
+    const string& t = fetch_outputs[i];
+
     // Parse t into node_name and output_index.
     TensorId id(ParseTensorName(t));
 
@@ -213,25 +238,39 @@ Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
 
     // Create the fetch Node and connect it up
     Node* send_node;
-    TF_RETURN_IF_ERROR(
-        NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
-                    "_Send")
-            .Input(n, id.second)
-            .Attr("tensor_name", t)
-            .Attr("send_device", device_info.name())
-            .Attr("recv_device", device_info.name())
-            .Attr("send_device_incarnation",
-                  static_cast<int64>(device_info.incarnation()))
-            .Attr("client_terminated", true)
-            .Finalize(g, &send_node));
+    if (!use_function_convention) {
+      TF_RETURN_IF_ERROR(
+          NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
+                      "_Send")
+              .Input(n, id.second)
+              .Attr("tensor_name", t)
+              .Attr("send_device", device_info.name())
+              .Attr("recv_device", device_info.name())
+              .Attr("send_device_incarnation",
+                    static_cast<int64>(device_info.incarnation()))
+              .Attr("client_terminated", true)
+              .Finalize(g, &send_node));
+    } else {
+      // NOTE(mrry): We must include the index as part of the node
+      // name, because _Retval is a "stateful" kernel and therefore
+      // its name must uniquely identify a kernel instance across all
+      // graphs in the same session.
+      TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_",
+                                                     id.second, "_", i),
+                                     "_Retval")
+                             .Input(n, id.second)
+                             .Attr("T", BaseType(n->output_type(id.second)))
+                             .Attr("index", static_cast<int32>(i))
+                             .Finalize(g, &send_node));
+    }
     send_node->set_assigned_device_name(device_info.name());
-    VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());
 
     // Update the index.
     (*name_index)[send_node->name()] = send_node;
 
     g->AddControlEdge(send_node, g->sink_node());
-    fetch_nodes->push_back(send_node);
+    out_fetch_nodes->push_back(send_node);
+    out_fetch_types->push_back(BaseType(n->output_type(id.second)));
   }
 
   return Status::OK();
@@ -241,7 +280,8 @@ Status RewriteGraphForExecution(
     Graph* g, const gtl::ArraySlice<string>& fed_outputs,
     const gtl::ArraySlice<string>& fetch_outputs,
     const gtl::ArraySlice<string>& target_node_names,
-    const DeviceAttributes& device_info) {
+    const DeviceAttributes& device_info, bool use_function_convention,
+    RewriteGraphMetadata* out_metadata) {
   if (fetch_outputs.empty() && target_node_names.empty()) {
     return errors::InvalidArgument(
         "Must specify at least one target to fetch or execute.");
@@ -274,18 +314,21 @@ Status RewriteGraphForExecution(
   // currently listed in "fetch_nodes".  We pass "name_index" so the index is
   // kept up to date.
   if (!fed_outputs.empty()) {
-    TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
+    TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs,
+                                  use_function_convention, &name_index,
+                                  &out_metadata->feed_types));
   }
 
   // Add the fetch nodes, also updating "name_index".
   std::vector<Node*> fetch_nodes;
   if (!fetch_outputs.empty()) {
-    TF_RETURN_IF_ERROR(
-        FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
+    TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs,
+                                    use_function_convention, &name_index,
+                                    &fetch_nodes, &out_metadata->fetch_types));
   }
 
   // Prune the graph to only compute what is needed for the fetch nodes and the
-  // targets nodes.
+  // target nodes.
   if (!fetch_nodes.empty() || !target_node_names.empty()) {
     TF_RETURN_IF_ERROR(
         PruneForTargets(g, name_index, fetch_nodes, target_node_names));
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
index d94d983d000..8ccc27914bc 100644
--- a/tensorflow/core/graph/subgraph.h
+++ b/tensorflow/core/graph/subgraph.h
@@ -26,6 +26,18 @@ limitations under the License.
 namespace tensorflow {
 namespace subgraph {
 
+// Information about a graph rewritten by `RewriteGraphForExecution()`.
+struct RewriteGraphMetadata {
+  // The element type of each tensor fed to this subgraph. The order
+  // of types corresponds to the order of tensor names in
+  // `fed_outputs` when calling `RewriteGraphForExecution()`.
+  DataTypeVector feed_types;
+  // The element type of each tensor fetched from this subgraph. The
+  // order of types corresponds to the order of tensor names in
+  // `fetch_outputs` when calling `RewriteGraphForExecution()`.
+  DataTypeVector fetch_types;
+};
+
 // Rewrite the graph structure of "*g" to deal with feeding node
 // outputs, fetching node outputs, and only running a subset of the
 // graph.  "fed_outputs" and "fetch_outputs" are both lists of
@@ -56,7 +68,8 @@ Status RewriteGraphForExecution(
     Graph* g, const gtl::ArraySlice<string>& fed_outputs,
     const gtl::ArraySlice<string>& fetch_outputs,
     const gtl::ArraySlice<string>& target_node_names,
-    const DeviceAttributes& device_info);
+    const DeviceAttributes& device_info, bool use_function_convention,
+    RewriteGraphMetadata* out_metadata);
 
 typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
 
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
index ee4960121f5..3dc11b7a166 100644
--- a/tensorflow/core/graph/subgraph_test.cc
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -104,7 +104,8 @@ class SubgraphTest : public ::testing::Test {
   }
 
   string Subgraph(const string& fed_str, const string& fetch_str,
-                  const string& targets_str) {
+                  const string& targets_str,
+                  bool use_function_convention = false) {
     Graph* subgraph = new Graph(OpRegistry::Global());
     CopyGraph(*g_, subgraph);
     std::vector<string> fed =
@@ -114,13 +115,18 @@ class SubgraphTest : public ::testing::Test {
     std::vector<string> targets =
         str_util::Split(targets_str, ',', str_util::SkipEmpty());
 
-    Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, targets,
-                                                  device_info_);
+    subgraph::RewriteGraphMetadata metadata;
+    Status s = subgraph::RewriteGraphForExecution(
+        subgraph, fed, fetch, targets, device_info_, use_function_convention,
+        &metadata);
     if (!s.ok()) {
       delete subgraph;
       return s.ToString();
     }
 
+    EXPECT_EQ(fed.size(), metadata.feed_types.size());
+    EXPECT_EQ(fetch.size(), metadata.fetch_types.size());
+
     // Replace the graph with the subgraph for the rest of the display program
     g_.reset(subgraph);
     return "OK";
@@ -178,6 +184,20 @@ TEST_F(SubgraphTest, FedOutputs1) {
   ExpectNodes("W1,W2,_recv_input_1,t1,t2");
 }
 
+TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) {
+  ExpectOK(
+      "node { name: 'W1' op: 'TestParams' }"
+      "node { name: 'W2' op: 'TestParams' }"
+      "node { name: 'input' op: 'TestInput' }"
+      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+  EXPECT_EQ("OK",
+            Subgraph("input:1", "", "t2", true /* use_function_convention */));
+  ExpectNodes("W1,W2,_arg_input_1_0,t1,t2");
+}
+
 TEST_F(SubgraphTest, FedRefNode) {
   ExpectOK(
       "node { name: 'W1' op: 'TestParams' }"
@@ -189,7 +209,19 @@ TEST_F(SubgraphTest, FedRefNode) {
   EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
 }
 
-TEST_F(SubgraphTest, FedOutputs2) {
+TEST_F(SubgraphTest, FedRefNode_FunctionConvention) {
+  ExpectOK(
+      "node { name: 'W1' op: 'TestParams' }"
+      "node { name: 'W2' op: 'TestParams' }"
+      "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
+  EXPECT_EQ("OK",
+            Subgraph("W1:0", "", "t1", true /* use_function_convention */));
+  ExpectNodes("_arg_W1_0_0,W2,t1");
+  Node* n = FindNode("_arg_W1_0_0");
+  EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
+}
+
+TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) {
   ExpectOK(
       "node { name: 'W1' op: 'TestParams' }"
       "node { name: 'W2' op: 'TestParams' }"
@@ -200,8 +232,9 @@ TEST_F(SubgraphTest, FedOutputs2) {
       "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
   // We feed input:1, but nothing connects to it, so the _recv(input:1)
   // node also disappears.
-  EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
-  ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
+  EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2",
+                           true /* use_function_convention */));
+  ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2");
 }
 
 TEST_F(SubgraphTest, FetchOutputs1) {
@@ -218,6 +251,22 @@ TEST_F(SubgraphTest, FetchOutputs1) {
       "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
 }
 
+TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) {
+  ExpectOK(
+      "node { name: 'W1' op: 'TestParams' }"
+      "node { name: 'W2' op: 'TestParams' }"
+      "node { name: 'input' op: 'TestInput' }"
+      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+  EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2",
+                           true /* use_function_convention */));
+  ExpectNodes(
+      "W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_"
+      "retval_t2_0_3");
+}
+
 TEST_F(SubgraphTest, FetchOutputs2) {
   ExpectOK(
       "node { name: 'W1' op: 'TestParams' }"
@@ -231,6 +280,20 @@ TEST_F(SubgraphTest, FetchOutputs2) {
   ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
 }
 
+TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) {
+  ExpectOK(
+      "node { name: 'W1' op: 'TestParams' }"
+      "node { name: 'W2' op: 'TestParams' }"
+      "node { name: 'input' op: 'TestInput' }"
+      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
+      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
+      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
+      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
+  EXPECT_EQ("OK",
+            Subgraph("", "t3_a", "t2", true /* use_function_convention */));
+  ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0");
+}
+
 TEST_F(SubgraphTest, ChainOfFools) {
   ExpectOK(
       "node { name: 'a' op: 'TestParams' }"
@@ -315,7 +378,8 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
 REGISTER_OP("In").Output("o: float");
 REGISTER_OP("Op").Input("i: float").Output("o: float");
 
-static void BM_Subgraph(int iters, int num_nodes) {
+static void BM_SubgraphHelper(int iters, int num_nodes,
+                              bool use_function_convention) {
   DeviceAttributes device_info;
   device_info.set_name("/job:a/replica:0/task:0/cpu:0");
   device_info.set_device_type(DeviceType(DEVICE_CPU).type());
@@ -347,12 +411,26 @@ static void BM_Subgraph(int iters, int num_nodes) {
   while (--iters > 0) {
     Graph* subgraph = new Graph(OpRegistry::Global());
     CopyGraph(g, subgraph);
-    TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
-                                                   targets, device_info));
+    subgraph::RewriteGraphMetadata metadata;
+    TF_CHECK_OK(subgraph::RewriteGraphForExecution(
+        subgraph, fed, fetch, targets, device_info, use_function_convention,
+        &metadata));
     delete subgraph;
   }
 }
+
+static void BM_Subgraph(int iters, int num_nodes) {
+  BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */);
+}
+static void BM_SubgraphFunctionConvention(int iters, int num_nodes) {
+  BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */);
+}
 BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
+BENCHMARK(BM_SubgraphFunctionConvention)
+    ->Arg(100)
+    ->Arg(1000)
+    ->Arg(10000)
+    ->Arg(100000);
 
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index a76dd4f6d60..bb457a01b23 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -820,7 +820,7 @@ class DebugDumpDir(object):
     self._node_op_types[node.name] = node.op
 
     for inp in node.input:
-      if is_copy_node(inp) and node.op == "_Send":
+      if is_copy_node(inp) and (node.op == "_Send" or node.op == "_Retval"):
         self._copy_send_nodes.append(node.name)
 
       if inp.startswith("^"):
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 6c7cbbff9cb..00f6cc0d6d9 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -196,7 +196,7 @@ class ControlFlowTest(test.TestCase):
 
       with self.assertRaisesWithPredicateMatch(
           errors_impl.InvalidArgumentError,
-          lambda e: "The tensor returned for" in str(e)):
+          lambda e: "Retval[0] does not have value" in str(e)):
         dead_branch.eval()
 
   def testSwitchMergeLess(self):
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 8d1f19bf30b..466e61b42dc 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -147,9 +147,10 @@ Status FoldConstants(const GraphDef& input_graph_def,
   TF_RETURN_IF_ERROR(
       ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
   DeviceAttributes device_attributes;
+  subgraph::RewriteGraphMetadata metadata;
   TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
       &input_graph, context.input_names, context.output_names, {},
-      device_attributes));
+      device_attributes, false /* use_function_convention */, &metadata));
   bool was_mutated;
   TF_RETURN_IF_ERROR(DoConstantFoldingWithStatus(
       ConstantFoldingOptions(), nullptr, Env::Default(), nullptr, &input_graph,

From 8e0d6f12efec0aeef16a64c6422aa3e90bbf4058 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 06:35:18 -0800
Subject: [PATCH 10/27] Automated rollback of change 153736477 Change:
 153825726

---
 configure                           |   5 +-
 tensorflow/workspace.bzl            |   2 +
 third_party/py/BUILD                |   0
 third_party/py/BUILD.tpl            |  53 +++++++
 third_party/py/numpy/BUILD          |   6 +-
 third_party/py/python_configure.bzl | 206 ++++++++++++++++++++++++++++
 util/python/BUILD                   |  28 +---
 util/python/python_config.sh        |  76 +---------
 8 files changed, 274 insertions(+), 102 deletions(-)
 create mode 100644 third_party/py/BUILD
 create mode 100644 third_party/py/BUILD.tpl
 create mode 100644 third_party/py/python_configure.bzl

diff --git a/configure b/configure
index 48a4594da63..47bdd5d018e 100755
--- a/configure
+++ b/configure
@@ -86,6 +86,9 @@ while true; do
   PYTHON_BIN_PATH=""
   # Retry
 done
+export PYTHON_BIN_PATH
+write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
+# TODO(ngiraldo): allow the user to optionally set PYTHON_INCLUDE_PATH and NUMPY_INCLUDE_PATH
 
 ## Set up MKL related environment settings
 if false; then # Disable building with MKL for now
@@ -243,7 +246,7 @@ fi
 
 
 # Invoke python_config and set up symlinks to python includes
-./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
+./util/python/python_config.sh "$PYTHON_BIN_PATH"
 
 # Append CC optimization flags to bazel.rc
 echo >> tools/bazel.rc
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 8a858fb62a4..5ab91b69a15 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -5,6 +5,7 @@ load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
 load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external")
 load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external")
 load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles_external")
+load("//third_party/py:python_configure.bzl", "python_configure")
 
 
 # Parse the bazel version string from `native.bazel_version`.
@@ -119,6 +120,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   check_version("0.4.5")
   cuda_configure(name="local_config_cuda")
   sycl_configure(name="local_config_sycl")
+  python_configure(name="local_config_python")
   if path_prefix:
     print("path_prefix was specified to tf_workspace but is no longer used " +
           "and will be removed in the future.")
diff --git a/third_party/py/BUILD b/third_party/py/BUILD
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/third_party/py/BUILD.tpl b/third_party/py/BUILD.tpl
new file mode 100644
index 00000000000..157834df4b9
--- /dev/null
+++ b/third_party/py/BUILD.tpl
@@ -0,0 +1,53 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "python_headers",
+    hdrs = select({
+        "windows" : [
+            "python_include_windows",
+        ],
+        "//conditions:default" : [
+            "python_include",
+        ],
+    }),
+    includes = select({
+        "windows" : [
+            "python_include_windows",
+        ],
+        "//conditions:default" : [
+            "python_include",
+        ],
+    }),
+)
+
+cc_library(
+    name = "numpy_headers",
+    hdrs = select({
+        "windows" : [
+            "numpy_include_windows",
+        ],
+        "//conditions:default" : [
+            "numpy_include",
+        ],
+    }),
+    includes = select({
+        "windows" : [
+            "numpy_include_windows",
+        ],
+        "//conditions:default" : [
+            "numpy_include",
+        ],
+    }),
+)
+
+config_setting(
+    name = "windows",
+    values = {"cpu": "x64_windows"},
+    visibility = ["//visibility:public"],
+)
+
+%{PYTHON_INCLUDE_GENRULE}
+
+%{NUMPY_INCLUDE_GENRULE}
diff --git a/third_party/py/numpy/BUILD b/third_party/py/numpy/BUILD
index 1d461505a69..be8332572b1 100644
--- a/third_party/py/numpy/BUILD
+++ b/third_party/py/numpy/BUILD
@@ -8,11 +8,9 @@ py_library(
     srcs_version = "PY2AND3",
 )
 
-cc_library(
+alias(
     name = "headers",
-    hdrs = glob(["numpy_include/**/*.h"]),
-    data = ["//util/python:python_checked"],
-    includes = ["numpy_include"],
+    actual = "@local_config_python//:numpy_headers",
 )
 
 genrule(
diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl
new file mode 100644
index 00000000000..d49d4c17815
--- /dev/null
+++ b/third_party/py/python_configure.bzl
@@ -0,0 +1,206 @@
+# -*- Python -*-
+"""Repository rule for Python autoconfiguration.
+
+`python_configure` depends on the following environment variables:
+
+  * `NUMPY_INCLUDE_PATH`: Location of Numpy libraries.
+  * `PYTHON_BIN_PATH`: location of python binary.
+  * `PYTHON_INCLUDE_PATH`: Location of python binaries.
+"""
+
+_NUMPY_INCLUDE_PATH = "NUMPY_INCLUDE_PATH"
+_PYTHON_BIN_PATH = "PYTHON_BIN_PATH"
+_PYTHON_INCLUDE_PATH = "PYTHON_INCLUDE_PATH"
+
+
+def _tpl(repository_ctx, tpl, substitutions={}, out=None):
+  if not out:
+    out = tpl
+  repository_ctx.template(
+      out,
+      Label("//third_party/py:%s.tpl" % tpl),
+      substitutions)
+
+
+def _python_configure_warning(msg):
+  """Output warning message during auto configuration."""
+  yellow = "\033[1;33m"
+  no_color = "\033[0m"
+  print("\n%sPython Configuration Warning:%s %s\n" % (yellow, no_color, msg))
+
+
+def _python_configure_fail(msg):
+  """Output failure message when auto configuration fails."""
+  red = "\033[0;31m"
+  no_color = "\033[0m"
+  fail("\n%sPython Configuration Error:%s %s\n" % (red, no_color, msg))
+
+
+def _get_env_var(repository_ctx, name, default = None, enable_warning = True):
+  """Find an environment variable in system path."""
+  if name in repository_ctx.os.environ:
+    return repository_ctx.os.environ[name]
+  if default != None:
+    if enable_warning:
+      _python_configure_warning(
+          "'%s' environment variable is not set, using '%s' as default" % (name, default))
+    return default
+  _python_configure_fail("'%s' environment variable is not set" % name)
+
+
+def _is_windows(repository_ctx):
+  """Returns true if the host operating system is windows."""
+  os_name = repository_ctx.os.name.lower()
+  if os_name.find("windows") != -1:
+    return True
+  return False
+
+
+def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name):
+  """returns a genrule to symlink all files in a directory."""
+  # Get the list of files under this directory
+  find_result = None
+  if _is_windows(repository_ctx):
+    find_result = repository_ctx.execute([
+        "dir", src_dir, "/b", "/s", "/a-d",
+    ])
+  else:
+    find_result = repository_ctx.execute([
+        "find", src_dir, "-follow", "-type", "f",
+    ])
+  # Create a list with the src_dir stripped to use for outputs.
+  dest_files = find_result.stdout.replace(src_dir, '').splitlines()
+  src_files = find_result.stdout.splitlines()
+  command = []
+  command_windows = []
+  outs = []
+  outs_windows = []
+  for i in range(len(dest_files)):
+    if dest_files[i] != "":
+      command.append('ln -s ' + src_files[i] + ' $(@D)/' +
+                     dest_dir + dest_files[i])
+      # ln -sf is actually implemented as copying in msys since creating
+      # symbolic links is privileged on Windows. But copying is too slow, so
+      # invoke mklink to create junctions on Windows.
+      command_windows.append('mklink /J ' + src_files[i] + ' $(@D)/' +
+                             dest_dir + dest_files[i])
+      outs.append('      "' + dest_dir + dest_files[i] + '",')
+      outs_windows.append('      "' + dest_dir + '_windows' +
+                          dest_files[i] + '",')
+  genrule = _genrule(src_dir, genrule_name, ' && '.join(command),
+                     '\n'.join(outs))
+  genrule_windows = _genrule(src_dir, genrule_name + '_windows',
+                             "cmd /c \"" + ' && '.join(command_windows) + "\"",
+                             '\n'.join(outs_windows))
+  return genrule + '\n' + genrule_windows
+
+
+def _genrule(src_dir, genrule_name, command, outs):
+  """Returns a string with a genrule.
+
+  Genrule executes the given command and produces the given outputs.
+  """
+  return (
+      'genrule(\n' +
+      '    name = "' +
+      genrule_name + '",\n' +
+      '    outs = [\n' +
+      outs +
+      '    ],\n' +
+      '    cmd = """\n' +
+      command +
+      '    """,\n' +
+      ')\n'
+  )
+
+
+def _check_python_bin(repository_ctx, python_bin):
+  """Checks the python bin path."""
+  cmd =  '[[ -x "%s" ]] && [[ ! -d "%s" ]]' % (python_bin, python_bin)
+  result = repository_ctx.execute(["bash", "-c", cmd])
+  if result.return_code == 1:
+    _python_configure_fail(
+        "PYTHON_BIN_PATH is not executable.  Is it the python binary?")
+
+
+def _get_python_include(repository_ctx, python_bin):
+  """Gets the python include path."""
+  result = repository_ctx.execute([python_bin, "-c",
+                                   'from __future__ import print_function;' +
+                                   'from distutils import sysconfig;' +
+                                   'print(sysconfig.get_python_inc())'])
+  if result == "":
+    _python_configure_fail(
+        "Problem getting python include path.  Is distutils installed?")
+  return result.stdout.splitlines()[0]
+
+
+def _get_numpy_include(repository_ctx, python_bin):
+  """Gets the numpy include path."""
+  result = repository_ctx.execute([python_bin, "-c",
+                                   'from __future__ import print_function;' +
+                                   'import numpy;' +
+                                   ' print(numpy.get_include());'])
+  if result == "":
+    _python_configure_fail(
+        "Problem getting numpy include path.  Is numpy installed?")
+  return result.stdout.splitlines()[0]
+
+
+def _create_python_repository(repository_ctx):
+  """Creates the repository containing files set up to build with Python."""
+  python_include = None
+  numpy_include = None
+  # If local checks were requested, the python and numpy include will be auto
+  # detected on the host config (using _PYTHON_BIN_PATH).
+  if repository_ctx.attr.local_checks:
+    python_bin = _get_env_var(repository_ctx, _PYTHON_BIN_PATH)
+    _check_python_bin(repository_ctx, python_bin)
+    python_include = _get_python_include(repository_ctx, python_bin)
+    numpy_include = _get_numpy_include(repository_ctx, python_bin) + '/numpy'
+  else:
+    # Otherwise, we assume user provides all paths (via ENV or attrs)
+    python_include = _get_env_var(repository_ctx, _PYTHON_INCLUDE_PATH,
+                                  repository_ctx.attr.python_include)
+    numpy_include = _get_env_var(repository_ctx, _NUMPY_INCLUDE_PATH,
+                                 repository_ctx.attr.numpy_include) + '/numpy'
+
+  python_include_rule = _symlink_genrule_for_dir(
+      repository_ctx, python_include, 'python_include', 'python_include')
+  numpy_include_rule = _symlink_genrule_for_dir(
+      repository_ctx, numpy_include, 'numpy_include/numpy', 'numpy_include')
+  _tpl(repository_ctx, "BUILD", {
+      "%{PYTHON_INCLUDE_GENRULE}": python_include_rule,
+      "%{NUMPY_INCLUDE_GENRULE}": numpy_include_rule,
+  })
+
+
+def _python_autoconf_impl(repository_ctx):
+  """Implementation of the python_autoconf repository rule."""
+  _create_python_repository(repository_ctx)
+
+
+python_configure = repository_rule(
+    implementation = _python_autoconf_impl,
+    attrs = {
+        "local_checks": attr.bool(mandatory = False, default = True),
+        "python_include": attr.string(mandatory = False),
+        "numpy_include": attr.string(mandatory = False),
+    },
+    environ = [
+        _PYTHON_BIN_PATH,
+        _PYTHON_INCLUDE_PATH,
+        _NUMPY_INCLUDE_PATH,
+    ],
+)
+"""Detects and configures the local Python.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+python_configure(name = "local_config_python")
+```
+
+Args:
+  name: A unique name for this workspace rule.
+"""
diff --git a/util/python/BUILD b/util/python/BUILD
index 29688b875df..96daf9947ad 100644
--- a/util/python/BUILD
+++ b/util/python/BUILD
@@ -2,31 +2,7 @@ licenses(["restricted"])
 
 package(default_visibility = ["//visibility:public"])
 
-cc_library(
+alias(
     name = "python_headers",
-    hdrs = glob([
-        "python_include/**/*.h",
-    ]),
-    data = [":python_checked"],
-    includes = ["python_include"],
-)
-
-genrule(
-    name = "python_check",
-    srcs = [
-        "python_config.sh",
-        "configure_files",
-    ],
-    outs = [
-        "python_checked",
-    ],
-    cmd = "OUTPUTDIR=\"$(@D)/\"; $(location :python_config.sh) --check && touch $$OUTPUTDIR/python_checked",
-    local = 1,
-)
-
-filegroup(
-    name = "configure_files",
-    data = glob([
-        "*",
-    ]),
+    actual = "@local_config_python//:python_headers",
 )
diff --git a/util/python/python_config.sh b/util/python/python_config.sh
index 4b18bf3578d..d5762ad4561 100755
--- a/util/python/python_config.sh
+++ b/util/python/python_config.sh
@@ -26,23 +26,9 @@ else
   script_path=${script_path:-.}
 fi
 
-EXPECTED_PATHS="$script_path/util/python/python_include"\
-" $script_path/util/python/python_lib"\
-" $script_path/third_party/py/numpy/numpy_include"
-
 function main {
-  argument="$1"
-  shift
-  case $argument in
-    --check)
-      check_python
-      exit 0
-      ;;
-    --setup)
-      setup_python "$1"
-      exit 0
-      ;;
-  esac
+  setup_python "$1"
+  exit 0
 }
 
 function python_path {
@@ -93,6 +79,7 @@ END
 function setup_python {
   PYTHON_BIN_PATH="$1";
 
+  # TODO(ngiraldo): move most of these checks to root configure
   if [ -z "$PYTHON_BIN_PATH" ]; then
     echo "PYTHON_BIN_PATH was not provided.  Did you run configure?"
     exit 1
@@ -108,12 +95,7 @@ function setup_python {
     exit 1
   fi
 
-  local python_include="$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; from distutils import sysconfig; print(sysconfig.get_python_inc());')"
-  if [ "$python_include" == "" ]; then
-    echo -e "\n\nERROR: Problem getting python include path.  Is distutils installed?"
-    exit 1
-  fi
-
+  # TODO(ngiraldo): confirm if these checks are really necessary, remove if not
   if [ -z "$PYTHON_LIB_PATH" ]; then
     local python_lib_path
     # Split python_path into an array of paths, this allows path containing spaces
@@ -149,35 +131,12 @@ function setup_python {
     exit 1
   fi
 
-  local numpy_include=$("${PYTHON_BIN_PATH}" -c 'from __future__ import print_function; import numpy; print(numpy.get_include());')
-  if [ "$numpy_include" == "" ]; then
-    echo -e "\n\nERROR: Problem getting numpy include path.  Is numpy installed?"
-    exit 1
-  fi
-
-  for x in $EXPECTED_PATHS; do
-    if [ -e "$x" ]; then
-      rm -rf "$x"
-    fi
-  done
-
-# ln -sf is actually implemented as copying in msys since creating symbolic
-# links is privileged on Windows. But copying is too slow, so invoke mklink
-# to create junctions on Windows.
-  if is_windows; then
-    cmd /c "mklink /J util\\python\\python_include \"${python_include}\""
-    cmd /c "mklink /J util\\python\\python_lib \"${python_lib}\""
-    cmd /c "mklink /J third_party\\py\\numpy\\numpy_include \"${numpy_include}\""
-  else
-    ln -sf "${python_include}" util/python/python_include
-    ln -sf "${python_lib}" util/python/python_lib
-    ln -sf "${numpy_include}" third_party/py/numpy/numpy_include
-  fi
   # Convert python path to Windows style before writing into bazel.rc
   if is_windows; then
     PYTHON_BIN_PATH="$(cygpath -m "$PYTHON_BIN_PATH")"
   fi
 
+  # TODO(ngiraldo): move all below to root configure
   # Write tools/bazel.rc
   echo "# Autogenerated by configure: DO NOT EDIT" > tools/bazel.rc
   sed -e "s/\$PYTHON_MAJOR_VERSION/$python_major_version/g" \
@@ -197,29 +156,4 @@ function is_windows() {
   fi
 }
 
-function check_python {
-  for x in $EXPECTED_PATHS; do
-    if [ ! -e "$x" ]; then
-      echo -e "\n\nERROR: Cannot find '${x}'.  Did you run configure?\n\n" 1>&2
-      exit 1
-    fi
-    # Don't check symbolic link on Windows
-    if ! is_windows && [ ! -L "${x}" ]; then
-      echo -e "\n\nERROR: '${x}' is not a symbolic link.  Internal error.\n\n" 1>&2
-      exit 1
-    fi
-    if is_windows; then
-      # In msys, readlink <path> doesn't work, because no symbolic link on
-      # Windows. readlink -f <path> returns the real path of a junction.
-      true_path=$(readlink -f "${x}")
-    else
-      true_path=$(readlink "${x}")
-    fi
-    if [ ! -d "${true_path}" ]; then
-      echo -e "\n\nERROR: '${x}' does not refer to an existing directory: ${true_path}.  Do you need to rerun configure?\n\n" 1>&2
-      exit 1
-    fi
-  done
-}
-
 main "$@"

From 41135a11af38dff22676a0138336edf7ab9681bb Mon Sep 17 00:00:00 2001
From: Nikhil Thorat <nsthorat@google.com>
Date: Fri, 21 Apr 2017 08:55:29 -0800
Subject: [PATCH 11/27] Branch the projector so that we have a d3v4 version
 with the new typings. Change: 153837827

---
 tensorflow/BUILD                              |   1 +
 .../components/vz_projector_d3v4/BUILD        |  19 +
 .../vz_projector_d3v4/analyticsLogger.ts      |  67 ++
 .../components/vz_projector_d3v4/bh_tsne.ts   | 472 ++++++++++++
 .../vz_projector_d3v4/data-provider-demo.ts   | 127 +++
 .../vz_projector_d3v4/data-provider-proto.ts  |  88 +++
 .../vz_projector_d3v4/data-provider-server.ts | 137 ++++
 .../vz_projector_d3v4/data-provider.ts        | 429 +++++++++++
 .../vz_projector_d3v4/data-provider_test.ts   |  96 +++
 .../components/vz_projector_d3v4/data.ts      | 547 +++++++++++++
 .../components/vz_projector_d3v4/data_test.ts | 104 +++
 .../vz_projector_d3v4/external.d.ts           |  51 ++
 .../components/vz_projector_d3v4/heap.ts      | 146 ++++
 .../components/vz_projector_d3v4/knn.ts       | 235 ++++++
 .../components/vz_projector_d3v4/label.ts     | 151 ++++
 .../components/vz_projector_d3v4/logging.ts   | 103 +++
 .../projectorEventContext.ts                  |  45 ++
 .../projectorScatterPlotAdapter.ts            | 713 +++++++++++++++++
 .../vz_projector_d3v4/renderContext.ts        |  53 ++
 .../vz_projector_d3v4/scatterPlot.ts          | 723 ++++++++++++++++++
 .../scatterPlotRectangleSelector.ts           | 107 +++
 .../scatterPlotRectangleSelector_test.ts      |  69 ++
 .../scatterPlotVisualizer.ts                  |  51 ++
 .../scatterPlotVisualizer3DLabels.ts          | 367 +++++++++
 .../scatterPlotVisualizerCanvasLabels.ts      | 187 +++++
 .../scatterPlotVisualizerPolylines.ts         | 149 ++++
 .../scatterPlotVisualizerSprites.ts           | 435 +++++++++++
 .../components/vz_projector_d3v4/sptree.ts    | 175 +++++
 .../vz_projector_d3v4/sptree_test.ts          | 104 +++
 .../components/vz_projector_d3v4/styles.html  | 185 +++++
 .../components/vz_projector_d3v4/util.ts      | 252 ++++++
 .../components/vz_projector_d3v4/util_test.ts |  42 +
 .../components/vz_projector_d3v4/vector.ts    | 266 +++++++
 .../vz_projector_d3v4/vz-projector-app.html   | 105 +++
 .../vz-projector-bookmark-panel.html          | 205 +++++
 .../vz-projector-bookmark-panel.ts            | 283 +++++++
 .../vz_projector_d3v4/vz-projector-colab.html |  32 +
 .../vz-projector-dashboard.html               |  79 ++
 .../vz-projector-data-panel.html              | 399 ++++++++++
 .../vz-projector-data-panel.ts                | 497 ++++++++++++
 .../vz_projector_d3v4/vz-projector-input.html |  64 ++
 .../vz_projector_d3v4/vz-projector-input.ts   | 113 +++
 .../vz-projector-inspector-panel.html         | 240 ++++++
 .../vz-projector-inspector-panel.ts           | 337 ++++++++
 .../vz-projector-legend.html                  |  76 ++
 .../vz_projector_d3v4/vz-projector-legend.ts  |  98 +++
 .../vz-projector-metadata-card.html           |  97 +++
 .../vz-projector-metadata-card.ts             |  88 +++
 .../vz-projector-projections-panel.html       | 314 ++++++++
 .../vz-projector-projections-panel.ts         | 589 ++++++++++++++
 .../vz-projector-projections-panel_test.ts    | 109 +++
 .../vz_projector_d3v4/vz-projector-util.ts    |  34 +
 .../vz_projector_d3v4/vz-projector.html       | 343 +++++++++
 .../vz_projector_d3v4/vz-projector.ts         | 570 ++++++++++++++
 54 files changed, 11368 insertions(+)
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/BUILD
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/label.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/styles.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/util.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html
 create mode 100644 tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts

diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f0c1271e898..0f7f848cb1a 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -297,6 +297,7 @@ filegroup(
         "//tensorflow/tensorboard/components/vz_line_chart:all_files",
         "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
         "//tensorflow/tensorboard/components/vz_projector:all_files",
+        "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
         "//tensorflow/tensorboard/components/vz_sorting:all_files",
         "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
         "//tensorflow/tensorboard/lib:all_files",
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD b/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD
new file mode 100644
index 00000000000..8c222be10e9
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/BUILD
@@ -0,0 +1,19 @@
+# Description:
+# Package for the Embedding Projector component.
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts
new file mode 100644
index 00000000000..aa1f86927da
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/analyticsLogger.ts
@@ -0,0 +1,67 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import {ProjectionType} from './data';
+
+export class AnalyticsLogger {
+  private eventLogging: boolean;
+  private pageViewLogging: boolean;
+
+  /**
+   * Constructs an event logger using Google Analytics. It assumes there is a
+   * Google Analytics script added to the page elsewhere. If there is no such
+   * script, the logger acts as a no-op.
+   *
+   * @param pageViewLogging Whether to log page views.
+   * @param eventLogging Whether to log user interaction.
+   */
+  constructor(pageViewLogging: boolean, eventLogging: boolean) {
+    if (typeof ga === 'undefined' || ga == null) {
+      this.eventLogging = false;
+      this.pageViewLogging = false;
+      return;
+    }
+    this.eventLogging = eventLogging;
+    this.pageViewLogging = pageViewLogging;
+  }
+
+  logPageView(pageTitle: string) {
+    if (this.pageViewLogging) {
+      // Always send a page view.
+      ga('send', {hitType: 'pageview', page: `/v/${pageTitle}`});
+    }
+  }
+
+  logProjectionChanged(projection: ProjectionType) {
+    if (this.eventLogging) {
+      ga('send', {
+        hitType: 'event',
+        eventCategory: 'Projection',
+        eventAction: 'click',
+        eventLabel: projection
+      });
+    }
+  }
+
+  logWebGLDisabled() {
+    if (this.eventLogging) {
+      ga('send', {
+        hitType: 'event',
+        eventCategory: 'Error',
+        eventAction: 'PageLoad',
+        eventLabel: 'WebGL_disabled'
+      });
+    }
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts
new file mode 100644
index 00000000000..9d2df65f560
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/bh_tsne.ts
@@ -0,0 +1,472 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/**
+ * This is a fork of the Karpathy's TSNE.js (original license below).
+ * This fork implements Barnes-Hut approximation and runs in O(NlogN)
+ * time, as opposed to the Karpathy's O(N^2) version.
+ *
+ * @author smilkov@google.com (Daniel Smilkov)
+ */
+
+/**
+ * The MIT License (MIT)
+ * Copyright (c) 2015 Andrej Karpathy
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+import {SPNode, SPTree} from './sptree';
+
+type AugmSPNode = SPNode&{numCells: number, yCell: number[], rCell: number};
+
+/**
+ * Barnes-hut approximation level. Higher means more approximation and faster
+ * results. Recommended value mentioned in the paper is 0.8.
+ */
+const THETA = 0.8;
+
+const MIN_POSSIBLE_PROB = 1E-9;
+
+// Variables used for memorizing the second random number since running
+// gaussRandom() generates two random numbers at the cost of 1 atomic
+// computation. This optimization results in 2X speed-up of the generator.
+let return_v = false;
+let v_val = 0.0;
+
+/** Returns the square euclidean distance between two vectors. */
+export function dist2(a: number[], b: number[]): number {
+  if (a.length !== b.length) {
+    throw new Error('Vectors a and b must be of same length');
+  }
+
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    let diff = a[i] - b[i];
+    result += diff * diff;
+  }
+  return result;
+}
+
+/** Returns the square euclidean distance between two 2D points. */
+export function dist2_2D(a: number[], b: number[]): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  return dX * dX + dY * dY;
+}
+
+/** Returns the square euclidean distance between two 3D points. */
+export function dist2_3D(a: number[], b: number[]): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  let dZ = a[2] - b[2];
+  return dX * dX + dY * dY + dZ * dZ;
+}
+
+function gaussRandom(rng: () => number): number {
+  if (return_v) {
+    return_v = false;
+    return v_val;
+  }
+  let u = 2 * rng() - 1;
+  let v = 2 * rng() - 1;
+  let r = u * u + v * v;
+  if (r === 0 || r > 1) {
+    return gaussRandom(rng);
+  }
+  let c = Math.sqrt(-2 * Math.log(r) / r);
+  v_val = v * c;  // cache this for next function call for efficiency
+  return_v = true;
+  return u * c;
+};
+
+// return random normal number
+function randn(rng: () => number, mu: number, std: number) {
+  return mu + gaussRandom(rng) * std;
+};
+
+// utilitity that creates contiguous vector of zeros of size n
+function zeros(n: number): Float64Array {
+  return new Float64Array(n);
+};
+
+// utility that returns a matrix filled with random numbers
+// generated by the provided generator.
+function randnMatrix(n: number, d: number, rng: () => number) {
+  let nd = n * d;
+  let x = zeros(nd);
+  for (let i = 0; i < nd; ++i) {
+    x[i] = randn(rng, 0.0, 1E-4);
+  }
+  return x;
+};
+
+// utility that returns a matrix filled with the provided value.
+function arrayofs(n: number, d: number, val: number) {
+  let x: number[][] = [];
+  for (let i = 0; i < n; ++i) {
+    x.push(d === 3 ? [val, val, val] : [val, val]);
+  }
+  return x;
+};
+
+// compute (p_{i|j} + p_{j|i})/(2n)
+function nearest2P(
+    nearest: {index: number, dist: number}[][], perplexity: number,
+    tol: number) {
+  let N = nearest.length;
+  let Htarget = Math.log(perplexity);  // target entropy of distribution
+  let P = zeros(N * N);                // temporary probability matrix
+  let K = nearest[0].length;
+  let pRow: number[] = new Array(K);  // pij[].
+
+  for (let i = 0; i < N; ++i) {
+    let neighbors = nearest[i];
+    let betaMin = -Infinity;
+    let betaMax = Infinity;
+    let beta = 1;  // initial value of precision
+    let maxTries = 50;
+
+    // perform binary search to find a suitable precision beta
+    // so that the entropy of the distribution is appropriate
+    let numTries = 0;
+    while (true) {
+      // compute entropy and kernel row with beta precision
+      let psum = 0.0;
+      for (let k = 0; k < neighbors.length; ++k) {
+        let neighbor = neighbors[k];
+        let pij = (i === neighbor.index) ? 0 : Math.exp(-neighbor.dist * beta);
+        pij = Math.max(pij, MIN_POSSIBLE_PROB);
+        pRow[k] = pij;
+        psum += pij;
+      }
+      // normalize p and compute entropy
+      let Hhere = 0.0;
+      for (let k = 0; k < pRow.length; ++k) {
+        pRow[k] /= psum;
+        let pij = pRow[k];
+        if (pij > 1E-7) {
+          Hhere -= pij * Math.log(pij);
+        };
+      }
+
+      // adjust beta based on result
+      if (Hhere > Htarget) {
+        // entropy was too high (distribution too diffuse)
+        // so we need to increase the precision for more peaky distribution
+        betaMin = beta;  // move up the bounds
+        if (betaMax === Infinity) {
+          beta = beta * 2;
+        } else {
+          beta = (beta + betaMax) / 2;
+        }
+
+      } else {
+        // converse case. make distrubtion less peaky
+        betaMax = beta;
+        if (betaMin === -Infinity) {
+          beta = beta / 2;
+        } else {
+          beta = (beta + betaMin) / 2;
+        }
+      }
+      numTries++;
+      // stopping conditions: too many tries or got a good precision
+      if (numTries >= maxTries || Math.abs(Hhere - Htarget) < tol) {
+        break;
+      }
+    }
+
+    // copy over the final prow to P at row i
+    for (let k = 0; k < pRow.length; ++k) {
+      let pij = pRow[k];
+      let j = neighbors[k].index;
+      P[i * N + j] = pij;
+    }
+  }  // end loop over examples i
+
+  // symmetrize P and normalize it to sum to 1 over all ij
+  let N2 = N * 2;
+  for (let i = 0; i < N; ++i) {
+    for (let j = i + 1; j < N; ++j) {
+      let i_j = i * N + j;
+      let j_i = j * N + i;
+      let value = (P[i_j] + P[j_i]) / N2;
+      P[i_j] = value;
+      P[j_i] = value;
+    }
+  }
+  return P;
+};
+
+// helper function
+function sign(x: number) {
+  return x > 0 ? 1 : x < 0 ? -1 : 0;
+}
+
+function computeForce_2d(
+    force: number[], mult: number, pointA: number[], pointB: number[]) {
+  force[0] += mult * (pointA[0] - pointB[0]);
+  force[1] += mult * (pointA[1] - pointB[1]);
+}
+
+function computeForce_3d(
+    force: number[], mult: number, pointA: number[], pointB: number[]) {
+  force[0] += mult * (pointA[0] - pointB[0]);
+  force[1] += mult * (pointA[1] - pointB[1]);
+  force[2] += mult * (pointA[2] - pointB[2]);
+}
+
+export interface TSNEOptions {
+  /** How many dimensions. */
+  dim: number;
+  /** Roughly how many neighbors each point influences. */
+  perplexity?: number;
+  /** Learning rate. */
+  epsilon?: number;
+  /** A random number generator. */
+  rng?: () => number;
+}
+
+export class TSNE {
+  private perplexity: number;
+  private epsilon: number;
+  /** Random generator */
+  private rng: () => number;
+  private iter = 0;
+  private Y: Float64Array;
+  private N: number;
+  private P: Float64Array;
+  private gains: number[][];
+  private ystep: number[][];
+  private nearest: {index: number, dist: number}[][];
+  private dim: number;
+  private dist2: (a: number[], b: number[]) => number;
+  private computeForce:
+      (force: number[], mult: number, pointA: number[],
+       pointB: number[]) => void;
+
+  constructor(opt: TSNEOptions) {
+    opt = opt || {dim: 2};
+    this.perplexity = opt.perplexity || 30;
+    this.epsilon = opt.epsilon || 10;
+    this.rng = opt.rng || Math.random;
+    this.dim = opt.dim;
+    if (opt.dim === 2) {
+      this.dist2 = dist2_2D;
+      this.computeForce = computeForce_2d;
+    } else if (opt.dim === 3) {
+      this.dist2 = dist2_3D;
+      this.computeForce = computeForce_3d;
+    } else {
+      throw new Error('Only 2D and 3D is supported');
+    }
+  }
+
+  // this function takes a fattened distance matrix and creates
+  // matrix P from them.
+  // D is assumed to be provided as an array of size N^2.
+  initDataDist(nearest: {index: number, dist: number}[][]) {
+    let N = nearest.length;
+    this.nearest = nearest;
+    this.P = nearest2P(nearest, this.perplexity, 1E-4);
+    this.N = N;
+    this.initSolution();  // refresh this
+  }
+
+  // (re)initializes the solution to random
+  initSolution() {
+    // generate random solution to t-SNE
+    this.Y = randnMatrix(this.N, this.dim, this.rng);  // the solution
+    this.gains = arrayofs(this.N, this.dim, 1.0);      // step gains
+    // to accelerate progress in unchanging directions
+    this.ystep = arrayofs(this.N, this.dim, 0.0);  // momentum accumulator
+    this.iter = 0;
+  }
+
+  // return pointer to current solution
+  getSolution() { return this.Y; }
+
+  // perform a single step of optimization to improve the embedding
+  step() {
+    this.iter += 1;
+    let N = this.N;
+
+    let grad = this.costGrad(this.Y);  // evaluate gradient
+
+    // perform gradient step
+    let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0];
+    for (let i = 0; i < N; ++i) {
+      for (let d = 0; d < this.dim; ++d) {
+        let gid = grad[i][d];
+        let sid = this.ystep[i][d];
+        let gainid = this.gains[i][d];
+
+        // compute gain update
+        let newgain = sign(gid) === sign(sid) ? gainid * 0.8 : gainid + 0.2;
+        if (newgain < 0.01) {
+          newgain = 0.01;  // clamp
+        }
+        this.gains[i][d] = newgain;  // store for next turn
+
+        // compute momentum step direction
+        let momval = this.iter < 250 ? 0.5 : 0.8;
+        let newsid = momval * sid - this.epsilon * newgain * grad[i][d];
+        this.ystep[i][d] = newsid;  // remember the step we took
+
+        // step!
+        let i_d = i * this.dim + d;
+        this.Y[i_d] += newsid;
+        ymean[d] += this.Y[i_d];  // accumulate mean so that we
+                                  // can center later
+      }
+    }
+
+    // reproject Y to be zero mean
+    for (let i = 0; i < N; ++i) {
+      for (let d = 0; d < this.dim; ++d) {
+        this.Y[i * this.dim + d] -= ymean[d] / N;
+      }
+    }
+  }
+
+  // return cost and gradient, given an arrangement
+  costGrad(Y: Float64Array): number[][] {
+    let N = this.N;
+    let P = this.P;
+
+    // Trick that helps with local optima.
+    let alpha = this.iter < 100 ? 4 : 1;
+
+    // Make data for the SP tree.
+    let points: number[][] = new Array(N);  // (x, y)[]
+    for (let i = 0; i < N; ++i) {
+      let iTimesD = i * this.dim;
+      let row = new Array(this.dim);
+      for (let d = 0; d < this.dim; ++d) {
+        row[d] = Y[iTimesD + d];
+      }
+      points[i] = row;
+    }
+
+    // Make a tree.
+    let tree = new SPTree(points);
+    let root = tree.root as AugmSPNode;
+    // Annotate the tree.
+
+    let annotateTree =
+        (node: AugmSPNode): {numCells: number, yCell: number[]} => {
+          let numCells = 1;
+          if (node.children == null) {
+            // Update the current node and tell the parent.
+            node.numCells = numCells;
+            node.yCell = node.point;
+            return {numCells, yCell: node.yCell};
+          }
+          // node.point is a 2 or 3-dim number[], so slice() makes a copy.
+          let yCell = node.point.slice();
+          for (let i = 0; i < node.children.length; ++i) {
+            let child = node.children[i];
+            if (child == null) {
+              continue;
+            }
+            let result = annotateTree(child as AugmSPNode);
+            numCells += result.numCells;
+            for (let d = 0; d < this.dim; ++d) {
+              yCell[d] += result.yCell[d];
+            }
+          }
+          // Update the node and tell the parent.
+          node.numCells = numCells;
+          node.yCell = yCell.map(v => v / numCells);
+          return {numCells, yCell};
+        };
+
+    // Augment the tree with more info.
+    annotateTree(root);
+    tree.visit((node: AugmSPNode, low: number[], high: number[]) => {
+      node.rCell = high[0] - low[0];
+      return false;
+    });
+    // compute current Q distribution, unnormalized first
+    let grad: number[][] = [];
+    let Z = 0;
+    let forces: [number[], number[]][] = new Array(N);
+    for (let i = 0; i < N; ++i) {
+      let pointI = points[i];
+      // Compute the positive forces for the i-th node.
+      let Fpos = this.dim === 3 ? [0, 0, 0] : [0, 0];
+      let neighbors = this.nearest[i];
+      for (let k = 0; k < neighbors.length; ++k) {
+        let j = neighbors[k].index;
+        let pij = P[i * N + j];
+        let pointJ = points[j];
+        let squaredDistItoJ = this.dist2(pointI, pointJ);
+        let premult = pij / (1 + squaredDistItoJ);
+        this.computeForce(Fpos, premult, pointI, pointJ);
+      }
+      // Compute the negative forces for the i-th node.
+      let FnegZ = this.dim === 3 ? [0, 0, 0] : [0, 0];
+      tree.visit((node: AugmSPNode) => {
+        let squaredDistToCell = this.dist2(pointI, node.yCell);
+        // Squared distance from point i to cell.
+        if (node.children == null ||
+            (squaredDistToCell > 0 &&
+             node.rCell / Math.sqrt(squaredDistToCell) < THETA)) {
+          let qijZ = 1 / (1 + squaredDistToCell);
+          let dZ = node.numCells * qijZ;
+          Z += dZ;
+          dZ *= qijZ;
+          this.computeForce(FnegZ, dZ, pointI, node.yCell);
+          return true;
+        }
+        // Cell is too close to approximate.
+        let squaredDistToPoint = this.dist2(pointI, node.point);
+        let qijZ = 1 / (1 + squaredDistToPoint);
+        Z += qijZ;
+        qijZ *= qijZ;
+        this.computeForce(FnegZ, qijZ, pointI, node.point);
+        return false;
+      }, true);
+      forces[i] = [Fpos, FnegZ];
+    }
+    // Normalize the negative forces and compute the gradient.
+    const A = 4 * alpha;
+    const B = 4 / Z;
+    for (let i = 0; i < N; ++i) {
+      let [FPos, FNegZ] = forces[i];
+      let gsum = new Array(this.dim);
+      for (let d = 0; d < this.dim; ++d) {
+        gsum[d] = A * FPos[d] - B * FNegZ[d];
+      }
+      grad.push(gsum);
+    }
+    return grad;
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts
new file mode 100644
index 00000000000..1410a84a8e4
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-demo.ts
@@ -0,0 +1,127 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataSet, SpriteAndMetadataInfo, State} from './data';
+import {ProjectorConfig, DataProvider, EmbeddingInfo, TENSORS_MSG_ID} from './data-provider';
+import * as dataProvider from './data-provider';
+import * as logging from './logging';
+
+const BYTES_EXTENSION = '.bytes';
+
+/** Data provider that loads data from a demo folder. */
+export class DemoDataProvider implements DataProvider {
+  private projectorConfigPath: string;
+  private projectorConfig: ProjectorConfig;
+
+  constructor(projectorConfigPath: string) {
+    this.projectorConfigPath = projectorConfigPath;
+  }
+
+  private getEmbeddingInfo(tensorName: string): EmbeddingInfo {
+    let embeddings = this.projectorConfig.embeddings;
+    for (let i = 0; i < embeddings.length; i++) {
+      let embedding = embeddings[i];
+      if (embedding.tensorName === tensorName) {
+        return embedding;
+      }
+    }
+    return null;
+  }
+
+  retrieveRuns(callback: (runs: string[]) => void): void {
+    callback(['Demo']);
+  }
+
+  retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void)
+      : void {
+    const msgId = logging.setModalMessage('Fetching projector config...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open('GET', this.projectorConfigPath);
+    xhr.onerror = (err) => {
+      let errorMessage = err.message;
+      // If the error is a valid XMLHttpResponse, it's possible this is a
+      // cross-origin error.
+      if (xhr.responseText != null) {
+        errorMessage = 'Cannot fetch projector config, possibly a ' +
+            'Cross-Origin request error.';
+      }
+      logging.setErrorMessage(errorMessage, 'fetching projector config');
+    };
+    xhr.onload = () => {
+      const projectorConfig = JSON.parse(xhr.responseText) as ProjectorConfig;
+      logging.setModalMessage(null, msgId);
+      this.projectorConfig = projectorConfig;
+      callback(projectorConfig);
+    };
+    xhr.send();
+  }
+
+  retrieveTensor(run: string, tensorName: string,
+      callback: (ds: DataSet) => void) {
+    let embedding = this.getEmbeddingInfo(tensorName);
+    let url = `${embedding.tensorPath}`;
+    if (embedding.tensorPath.substr(-1 * BYTES_EXTENSION.length) ===
+        BYTES_EXTENSION) {
+      dataProvider.retrieveTensorAsBytes(
+          this, this.getEmbeddingInfo(tensorName), run, tensorName, url,
+          callback);
+    } else {
+      logging.setModalMessage('Fetching tensors...', TENSORS_MSG_ID);
+      const request = new XMLHttpRequest();
+      request.open('GET', url);
+      request.responseType = 'arraybuffer';
+
+      request.onerror = () => {
+        logging.setErrorMessage(request.responseText, 'fetching tensors');
+      };
+      request.onload = () => {
+        dataProvider.parseTensors(request.response).then(points => {
+          callback(new DataSet(points));
+        });
+      };
+      request.send();
+    }
+  }
+
+  retrieveSpriteAndMetadata(run: string, tensorName: string,
+      callback: (r: SpriteAndMetadataInfo) => void) {
+    let embedding = this.getEmbeddingInfo(tensorName);
+    let spriteImagePath = null;
+    if (embedding.sprite && embedding.sprite.imagePath) {
+      spriteImagePath = embedding.sprite.imagePath;
+    }
+    dataProvider.retrieveSpriteAndMetadataInfo(
+        embedding.metadataPath, spriteImagePath, embedding.sprite, callback);
+  }
+
+  getBookmarks(
+      run: string, tensorName: string, callback: (r: State[]) => void) {
+    let embedding = this.getEmbeddingInfo(tensorName);
+    let msgId = logging.setModalMessage('Fetching bookmarks...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open('GET', embedding.bookmarksPath);
+    xhr.onerror = (err) => {
+      logging.setErrorMessage(xhr.responseText);
+    };
+    xhr.onload = () => {
+      const bookmarks = JSON.parse(xhr.responseText) as State[];
+      logging.setModalMessage(null, msgId);
+      callback(bookmarks);
+    };
+    xhr.send();
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts
new file mode 100644
index 00000000000..67124a92323
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-proto.ts
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataPoint, DataProto, DataSet, SpriteAndMetadataInfo, PointMetadata, State} from './data';
+import {analyzeMetadata, ProjectorConfig, DataProvider} from './data-provider';
+
+
+export class ProtoDataProvider implements DataProvider {
+  private dataProto: DataProto;
+
+  constructor(dataProto: DataProto) {
+    this.dataProto = dataProto;
+  }
+
+  retrieveRuns(callback: (runs: string[]) => void): void {
+    callback(['proto']);
+  }
+
+  retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) {
+    callback({
+      modelCheckpointPath: 'proto',
+      embeddings: [{
+        tensorName: 'proto',
+        tensorShape: this.dataProto.shape,
+        metadataPath: 'proto'
+      }]
+    });
+  }
+
+  retrieveTensor(run: string, tensorName: string,
+      callback: (ds: DataSet) => void) {
+    callback(this.flatArrayToDataset(this.dataProto.tensor));
+  }
+
+  retrieveSpriteAndMetadata(run: string, tensorName: string,
+      callback: (r: SpriteAndMetadataInfo) => void): void {
+    let columnNames = this.dataProto.metadata.columns.map(c => c.name);
+    let n = this.dataProto.shape[0];
+    let pointsMetadata: PointMetadata[] = new Array(n);
+    this.dataProto.metadata.columns.forEach(c => {
+      let values = c.numericValues || c.stringValues;
+      for (let i = 0; i < n; i++) {
+        pointsMetadata[i] = pointsMetadata[i] || {};
+        pointsMetadata[i][c.name] = values[i];
+      }
+    });
+    callback({
+      stats: analyzeMetadata(columnNames, pointsMetadata),
+      pointsInfo: pointsMetadata
+    });
+  }
+
+  getBookmarks(run: string, tensorName: string,
+      callback: (r: State[]) => void): void {
+    return callback([]);
+  }
+
+  private flatArrayToDataset(tensor: number[]): DataSet {
+    let points: DataPoint[] = [];
+    let n = this.dataProto.shape[0];
+    let d = this.dataProto.shape[1];
+    if (n * d !== tensor.length) {
+      throw 'The shape doesn\'t match the length of the flattened array';
+    }
+    for (let i = 0; i < n; i++) {
+      let offset = i * d;
+      points.push({
+        vector: new Float32Array(tensor.slice(offset, offset + d)),
+        metadata: {},
+        projections: null,
+        index: i
+      });
+    }
+    return new DataSet(points);
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts
new file mode 100644
index 00000000000..02720ebf6a7
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider-server.ts
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataSet, SpriteAndMetadataInfo, State} from './data';
+import * as dataProvider from './data-provider';
+import {DataProvider, EmbeddingInfo, ProjectorConfig} from './data-provider';
+import * as logging from './logging';
+
+// Limit for the number of data points we receive from the server.
+export const LIMIT_NUM_POINTS = 100000;
+
+/**
+ * Data provider that loads data provided by a python server (usually backed
+ * by a checkpoint file).
+ */
+export class ServerDataProvider implements DataProvider {
+  private routePrefix: string;
+  private runProjectorConfigCache: {[run: string]: ProjectorConfig} = {};
+
+  constructor(routePrefix: string) {
+    this.routePrefix = routePrefix;
+  }
+
+  private getEmbeddingInfo(run: string, tensorName: string,
+      callback: (e: EmbeddingInfo) => void): void {
+    this.retrieveProjectorConfig(run, config => {
+      const embeddings = config.embeddings;
+      for (let i = 0; i < embeddings.length; i++) {
+        const embedding = embeddings[i];
+        if (embedding.tensorName === tensorName) {
+          callback(embedding);
+          return;
+        }
+      }
+      callback(null);
+    });
+  }
+
+  retrieveRuns(callback: (runs: string[]) => void): void {
+    const msgId = logging.setModalMessage('Fetching runs...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open('GET', `${this.routePrefix}/runs`);
+    xhr.onerror = (err) => {
+      logging.setErrorMessage(xhr.responseText, 'fetching runs');
+    };
+    xhr.onload = () => {
+      const runs = JSON.parse(xhr.responseText);
+      logging.setModalMessage(null, msgId);
+      callback(runs);
+    };
+    xhr.send();
+  }
+
+  retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void)
+      : void {
+    if (run in this.runProjectorConfigCache) {
+      callback(this.runProjectorConfigCache[run]);
+      return;
+    }
+
+    const msgId = logging.setModalMessage('Fetching projector config...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open('GET', `${this.routePrefix}/info?run=${run}`);
+    xhr.onerror = (err) => {
+      logging.setErrorMessage(xhr.responseText, 'fetching projector config');
+    };
+    xhr.onload = () => {
+      const config = JSON.parse(xhr.responseText) as ProjectorConfig;
+      logging.setModalMessage(null, msgId);
+      this.runProjectorConfigCache[run] = config;
+      callback(config);
+    };
+    xhr.send();
+  }
+
+  retrieveTensor(run: string, tensorName: string,
+      callback: (ds: DataSet) => void) {
+    this.getEmbeddingInfo(run, tensorName, embedding => {
+      dataProvider.retrieveTensorAsBytes(
+          this, embedding, run, tensorName,
+          `${this.routePrefix}/tensor?run=${run}&name=${tensorName}` +
+              `&num_rows=${LIMIT_NUM_POINTS}`,
+          callback);
+    });
+  }
+
+  retrieveSpriteAndMetadata(run: string, tensorName: string,
+      callback: (r: SpriteAndMetadataInfo) => void) {
+    this.getEmbeddingInfo(run, tensorName, embedding => {
+      let metadataPath = null;
+      if (embedding.metadataPath) {
+        metadataPath =
+            `${this.routePrefix}/metadata?` +
+            `run=${run}&name=${tensorName}&num_rows=${LIMIT_NUM_POINTS}`;
+      }
+      let spriteImagePath = null;
+      if (embedding.sprite && embedding.sprite.imagePath) {
+        spriteImagePath =
+            `${this.routePrefix}/sprite_image?run=${run}&name=${tensorName}`;
+      }
+      dataProvider.retrieveSpriteAndMetadataInfo(metadataPath, spriteImagePath,
+          embedding.sprite, callback);
+    });
+  }
+
+  getBookmarks(
+      run: string, tensorName: string, callback: (r: State[]) => void) {
+    const msgId = logging.setModalMessage('Fetching bookmarks...');
+
+    const xhr = new XMLHttpRequest();
+    xhr.open(
+        'GET', `${this.routePrefix}/bookmarks?run=${run}&name=${tensorName}`);
+    xhr.onerror = (err) => {
+      logging.setErrorMessage(xhr.responseText, 'fetching bookmarks');
+    };
+    xhr.onload = () => {
+      logging.setModalMessage(null, msgId);
+      const bookmarks = JSON.parse(xhr.responseText);
+      callback(bookmarks);
+    };
+    xhr.send();
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts
new file mode 100644
index 00000000000..c8eede798c6
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider.ts
@@ -0,0 +1,429 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {ColumnStats, DataPoint, DataSet, SpriteAndMetadataInfo, PointMetadata, State} from './data';
+import * as logging from './logging';
+import {runAsyncTask} from './util';
+
+/** Maximum number of colors supported in the color map. */
+const NUM_COLORS_COLOR_MAP = 50;
+const MAX_SPRITE_IMAGE_SIZE_PX = 8192;
+
+export const METADATA_MSG_ID = 'metadata';
+export const TENSORS_MSG_ID = 'tensors';
+
+/** Matches the json format of `projector_config.proto` */
+export interface SpriteMetadata {
+  imagePath: string;
+  singleImageDim: [number, number];
+}
+
+/** Matches the json format of `projector_config.proto` */
+export interface EmbeddingInfo {
+  /** Name of the tensor. */
+  tensorName: string;
+  /** The shape of the tensor. */
+  tensorShape: [number, number];
+  /**
+   * The path to the tensors TSV file. If empty, it is assumed that the tensor
+   * is stored in the checkpoint file.
+   */
+  tensorPath?: string;
+  /** The path to the metadata file associated with the tensor. */
+  metadataPath?: string;
+  /** The path to the bookmarks file associated with the tensor. */
+  bookmarksPath?: string;
+  sprite?: SpriteMetadata;
+}
+
+/**
+ * Matches the json format of `projector_config.proto`
+ * This should be kept in sync with the code in vz-projector-data-panel which
+ * holds a template for users to build a projector config JSON object from the
+ * projector UI.
+ */
+export interface ProjectorConfig {
+  embeddings: EmbeddingInfo[];
+  modelCheckpointPath?: string;
+}
+
+export type ServingMode = 'demo' | 'server' | 'proto';
+
+/** Interface between the data storage and the UI. */
+export interface DataProvider {
+  /** Returns a list of run names that have embedding config files. */
+  retrieveRuns(callback: (runs: string[]) => void): void;
+
+  /**
+   * Returns the projector configuration: number of tensors, their shapes,
+   * and their associated metadata files.
+   */
+  retrieveProjectorConfig(run: string,
+      callback: (d: ProjectorConfig) => void): void;
+
+  /** Fetches and returns the tensor with the specified name. */
+  retrieveTensor(run: string, tensorName: string,
+      callback: (ds: DataSet) => void);
+
+  /**
+   * Fetches the metadata for the specified tensor.
+   */
+  retrieveSpriteAndMetadata(run: string, tensorName: string,
+      callback: (r: SpriteAndMetadataInfo) => void): void;
+
+  getBookmarks(run: string, tensorName: string, callback: (r: State[]) => void):
+      void;
+}
+
+export function retrieveTensorAsBytes(
+    dp: DataProvider, embedding: EmbeddingInfo, run: string, tensorName: string,
+    tensorsPath: string, callback: (ds: DataSet) => void) {
+  // Get the tensor.
+  logging.setModalMessage('Fetching tensor values...', TENSORS_MSG_ID);
+  let xhr = new XMLHttpRequest();
+  xhr.open('GET', tensorsPath);
+  xhr.responseType = 'arraybuffer';
+  xhr.onprogress = (ev) => {
+    if (ev.lengthComputable) {
+      let percent = (ev.loaded * 100 / ev.total).toFixed(1);
+      logging.setModalMessage(
+          'Fetching tensor values: ' + percent + '%', TENSORS_MSG_ID);
+    }
+  };
+  xhr.onload = () => {
+    if (xhr.status !== 200) {
+      let msg = String.fromCharCode.apply(null, new Uint8Array(xhr.response));
+      logging.setErrorMessage(msg, 'fetching tensors');
+      return;
+    }
+    let data: Float32Array;
+    try {
+      data = new Float32Array(xhr.response);
+    } catch (e) {
+      logging.setErrorMessage(e, 'parsing tensor bytes');
+      return;
+    }
+
+    let dim = embedding.tensorShape[1];
+    let N = data.length / dim;
+    if (embedding.tensorShape[0] > N) {
+      logging.setWarningMessage(
+          `Showing the first ${N.toLocaleString()}` +
+          ` of ${embedding.tensorShape[0].toLocaleString()} data points`);
+    }
+    parseTensorsFromFloat32Array(data, dim).then(dataPoints => {
+      callback(new DataSet(dataPoints));
+    });
+  };
+  xhr.send();
+}
+
+export function parseRawTensors(
+    content: ArrayBuffer, callback: (ds: DataSet) => void) {
+  parseTensors(content).then(data => {
+    callback(new DataSet(data));
+  });
+}
+
+export function parseRawMetadata(
+    contents: ArrayBuffer, callback: (r: SpriteAndMetadataInfo) => void) {
+  parseMetadata(contents).then(result => callback(result));
+}
+
+/**
+ * Parse an ArrayBuffer in a streaming fashion line by line (or custom delim).
+ * Can handle very large files.
+ *
+ * @param content The array buffer.
+ * @param callback The callback called on each line.
+ * @param chunkSize The size of each read chunk, defaults to ~1MB. (optional)
+ * @param delim The delimiter used to split a line, defaults to '\n'. (optional)
+ * @returns A promise for when it is finished.
+ */
+function streamParse(
+    content: ArrayBuffer, callback: (line: string) => void, chunkSize = 1000000,
+    delim = '\n'): Promise<void> {
+  return new Promise<void>((resolve, reject) => {
+    let offset = 0;
+    let bufferSize = content.byteLength - 1;
+    let data = '';
+
+    function readHandler(str) {
+      offset += chunkSize;
+      let parts = str.split(delim);
+      let first = data + parts[0];
+      if (parts.length === 1) {
+        data = first;
+        readChunk(offset, chunkSize);
+        return;
+      }
+      data = parts[parts.length - 1];
+      callback(first);
+      for (let i = 1; i < parts.length - 1; i++) {
+        callback(parts[i]);
+      }
+      if (offset >= bufferSize) {
+        if (data) {
+          callback(data);
+        }
+        resolve();
+        return;
+      }
+      readChunk(offset, chunkSize);
+    }
+
+    function readChunk(offset: number, size: number) {
+      const contentChunk = content.slice(offset, offset + size);
+
+      const blob = new Blob([contentChunk]);
+      const file = new FileReader();
+      file.onload = (e: any) => readHandler(e.target.result);
+      file.readAsText(blob);
+    }
+
+    readChunk(offset, chunkSize);
+  });
+}
+
+/** Parses a tsv text file. */
+export function parseTensors(
+    content: ArrayBuffer, valueDelim = '\t'): Promise<DataPoint[]> {
+  logging.setModalMessage('Parsing tensors...', TENSORS_MSG_ID);
+
+  return new Promise<DataPoint[]>((resolve, reject) => {
+    const data: DataPoint[] = [];
+    let numDim: number;
+
+    streamParse(content, (line: string) => {
+      line = line.trim();
+      if (line === '') {
+        return;
+      }
+      const row = line.split(valueDelim);
+      const dataPoint: DataPoint = {
+        metadata: {},
+        vector: null,
+        index: data.length,
+        projections: null,
+      };
+      // If the first label is not a number, take it as the label.
+      if (isNaN(row[0] as any) || numDim === row.length - 1) {
+        dataPoint.metadata['label'] = row[0];
+        dataPoint.vector = new Float32Array(row.slice(1).map(Number));
+      } else {
+        dataPoint.vector = new Float32Array(row.map(Number));
+      }
+      data.push(dataPoint);
+      if (numDim == null) {
+        numDim = dataPoint.vector.length;
+      }
+      if (numDim !== dataPoint.vector.length) {
+        logging.setModalMessage(
+            'Parsing failed. Vector dimensions do not match');
+        throw Error('Parsing failed');
+      }
+      if (numDim <= 1) {
+        logging.setModalMessage(
+            'Parsing failed. Found a vector with only one dimension?');
+        throw Error('Parsing failed');
+      }
+    }).then(() => {
+      logging.setModalMessage(null, TENSORS_MSG_ID);
+      resolve(data);
+    });
+  });
+}
+
+/** Parses a tsv text file. */
+export function parseTensorsFromFloat32Array(data: Float32Array,
+    dim: number): Promise<DataPoint[]> {
+  return runAsyncTask('Parsing tensors...', () => {
+    const N = data.length / dim;
+    const dataPoints: DataPoint[] = [];
+    let offset = 0;
+    for (let i = 0; i < N; ++i) {
+      dataPoints.push({
+        metadata: {},
+        vector: data.subarray(offset, offset + dim),
+        index: i,
+        projections: null,
+      });
+      offset += dim;
+    }
+    return dataPoints;
+  }, TENSORS_MSG_ID).then(dataPoints => {
+    logging.setModalMessage(null, TENSORS_MSG_ID);
+    return dataPoints;
+  });
+}
+
+export function analyzeMetadata(
+    columnNames, pointsMetadata: PointMetadata[]): ColumnStats[] {
+  const columnStats: ColumnStats[] = columnNames.map(name => {
+    return {
+      name: name,
+      isNumeric: true,
+      tooManyUniqueValues: false,
+      min: Number.POSITIVE_INFINITY,
+      max: Number.NEGATIVE_INFINITY
+    };
+  });
+
+  const mapOfValues: [{[value: string]: number}] =
+      columnNames.map(() => new Object());
+
+  pointsMetadata.forEach(metadata => {
+    columnNames.forEach((name: string, colIndex: number) => {
+      const stats = columnStats[colIndex];
+      const map = mapOfValues[colIndex];
+      const value = metadata[name];
+
+      // Skip missing values.
+      if (value == null) {
+        return;
+      }
+
+      if (!stats.tooManyUniqueValues) {
+        if (value in map) {
+          map[value]++;
+        } else {
+          map[value] = 1;
+        }
+        if (Object.keys(map).length > NUM_COLORS_COLOR_MAP) {
+          stats.tooManyUniqueValues = true;
+        }
+      }
+      if (isNaN(value as any)) {
+        stats.isNumeric = false;
+      } else {
+        metadata[name] = +value;
+        stats.min = Math.min(stats.min, +value);
+        stats.max = Math.max(stats.max, +value);
+      }
+    });
+  });
+  columnStats.forEach((stats, colIndex) => {
+    stats.uniqueEntries = Object.keys(mapOfValues[colIndex]).map(label => {
+      return {label, count: mapOfValues[colIndex][label]};
+    });
+  });
+  return columnStats;
+}
+
+export function parseMetadata(content: ArrayBuffer):
+    Promise<SpriteAndMetadataInfo> {
+  logging.setModalMessage('Parsing metadata...', METADATA_MSG_ID);
+
+  return new Promise<SpriteAndMetadataInfo>((resolve, reject) => {
+    let pointsMetadata: PointMetadata[] = [];
+    let hasHeader = false;
+    let lineNumber = 0;
+    let columnNames = ['label'];
+    streamParse(content, (line: string) => {
+      if (line.trim().length === 0) {
+        return;
+      }
+      if (lineNumber === 0) {
+        hasHeader = line.indexOf('\t') >= 0;
+
+        // If the first row doesn't contain metadata keys, we assume that the
+        // values are labels.
+        if (hasHeader) {
+          columnNames = line.split('\t');
+          lineNumber++;
+          return;
+        }
+      }
+
+      lineNumber++;
+
+      let rowValues = line.split('\t');
+      let metadata: PointMetadata = {};
+      pointsMetadata.push(metadata);
+      columnNames.forEach((name: string, colIndex: number) => {
+        let value = rowValues[colIndex];
+        // Normalize missing values.
+        value = (value === '' ? null : value);
+        metadata[name] = value;
+      });
+    }).then(() => {
+      logging.setModalMessage(null, METADATA_MSG_ID);
+      resolve({
+        stats: analyzeMetadata(columnNames, pointsMetadata),
+        pointsInfo: pointsMetadata
+      });
+    });
+  });
+}
+
+export function fetchImage(url: string): Promise<HTMLImageElement> {
+  return new Promise<HTMLImageElement>((resolve, reject) => {
+    let image = new Image();
+    image.onload = () => resolve(image);
+    image.onerror = (err) => reject(err);
+    image.crossOrigin = '';
+    image.src = url;
+  });
+}
+
+export function retrieveSpriteAndMetadataInfo(metadataPath: string,
+    spriteImagePath: string, spriteMetadata: SpriteMetadata,
+    callback: (r: SpriteAndMetadataInfo) => void) {
+  let metadataPromise: Promise<SpriteAndMetadataInfo> = Promise.resolve({});
+  if (metadataPath) {
+    metadataPromise = new Promise<SpriteAndMetadataInfo>((resolve, reject) => {
+      logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID);
+
+      const request = new XMLHttpRequest();
+      request.open('GET', metadataPath);
+      request.responseType = 'arraybuffer';
+
+      request.onerror = () => {
+        logging.setErrorMessage(request.responseText, 'fetching metadata');
+        reject();
+      };
+      request.onload = () => {
+        resolve(parseMetadata(request.response));
+      };
+      request.send(null);
+    });
+  }
+  let spriteMsgId = null;
+  let spritesPromise: Promise<HTMLImageElement> = null;
+  if (spriteImagePath) {
+    spriteMsgId = logging.setModalMessage('Fetching sprite image...');
+    spritesPromise = fetchImage(spriteImagePath);
+  }
+
+  // Fetch the metadata and the image in parallel.
+  Promise.all([metadataPromise, spritesPromise]).then(values => {
+    if (spriteMsgId) {
+      logging.setModalMessage(null, spriteMsgId);
+    }
+    const [metadata, spriteImage] = values;
+
+    if (spriteImage && (spriteImage.height > MAX_SPRITE_IMAGE_SIZE_PX ||
+                        spriteImage.width > MAX_SPRITE_IMAGE_SIZE_PX)) {
+      logging.setModalMessage(
+          `Error: Sprite image of dimensions ${spriteImage.width}px x ` +
+          `${spriteImage.height}px exceeds maximum dimensions ` +
+          `${MAX_SPRITE_IMAGE_SIZE_PX}px x ${MAX_SPRITE_IMAGE_SIZE_PX}px`);
+    } else {
+      metadata.spriteImage = spriteImage;
+      metadata.spriteMetadata = spriteMetadata;
+      callback(metadata);
+    }
+  });
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts
new file mode 100644
index 00000000000..01b89ca7001
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data-provider_test.ts
@@ -0,0 +1,96 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataPoint, SpriteAndMetadataInfo} from './data';
+import * as data_provider from './data-provider';
+
+/**
+ * Converts a string to an ArrayBuffer.
+ */
+function stringToArrayBuffer(str: string): Promise<ArrayBuffer> {
+  return new Promise<ArrayBuffer>((resolve, reject) => {
+    let blob = new Blob([str]);
+    let file = new FileReader();
+    file.onload = (e: any) => {
+      resolve(e.target.result);
+    };
+    file.readAsArrayBuffer(blob);
+  });
+}
+
+/**
+ * Converts an data array to TSV format.
+ */
+function dataToTsv(data: string[][]|number[][]) {
+  let lines = [];
+  for (let i = 0; i < data.length; i++) {
+    lines.push(data[i].join('\t'));
+  }
+  return lines.join('\n');
+}
+
+describe('parse tensors', () => {
+  it('parseTensors', (doneFn) => {
+    let tensors = [[1.0, 2.0], [2.0, 3.0]];
+    stringToArrayBuffer(dataToTsv(tensors))
+        .then((tensorsArrayBuffer: ArrayBuffer) => {
+          data_provider.parseTensors(tensorsArrayBuffer)
+              .then((data: DataPoint[]) => {
+                expect(data.length).toBe(2);
+
+                expect(data[0].vector).toEqual(new Float32Array(tensors[0]));
+                expect(data[0].index).toEqual(0);
+                expect(data[0].projections).toBeNull();
+
+                expect(data[1].vector).toEqual(new Float32Array(tensors[1]));
+                expect(data[1].index).toEqual(1);
+                expect(data[1].projections).toBeNull();
+                doneFn();
+              });
+        });
+  });
+  it('parseMetadata', (doneFn) => {
+    let metadata = [['label', 'fakecol'], ['Г', '0'], ['label1', '1']];
+
+    stringToArrayBuffer(dataToTsv(metadata))
+        .then((metadataArrayBuffer: ArrayBuffer) => {
+          data_provider.parseMetadata(metadataArrayBuffer)
+              .then((spriteAndMetadataInfo: SpriteAndMetadataInfo) => {
+                expect(spriteAndMetadataInfo.stats.length).toBe(2);
+                expect(spriteAndMetadataInfo.stats[0].name)
+                    .toBe(metadata[0][0]);
+                expect(spriteAndMetadataInfo.stats[0].isNumeric).toBe(false);
+                expect(spriteAndMetadataInfo.stats[0].tooManyUniqueValues)
+                    .toBe(false);
+                expect(spriteAndMetadataInfo.stats[1].name)
+                    .toBe(metadata[0][1]);
+                expect(spriteAndMetadataInfo.stats[1].isNumeric).toBe(true);
+                expect(spriteAndMetadataInfo.stats[1].tooManyUniqueValues)
+                    .toBe(false);
+
+                expect(spriteAndMetadataInfo.pointsInfo.length).toBe(2);
+                expect(spriteAndMetadataInfo.pointsInfo[0]['label'])
+                    .toBe(metadata[1][0]);
+                expect(spriteAndMetadataInfo.pointsInfo[0]['fakecol'])
+                    .toBe(+metadata[1][1]);
+                expect(spriteAndMetadataInfo.pointsInfo[1]['label'])
+                    .toBe(metadata[2][0]);
+                expect(spriteAndMetadataInfo.pointsInfo[1]['fakecol'])
+                    .toBe(+metadata[2][1]);
+                doneFn();
+              });
+        });
+  });
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data.ts
new file mode 100644
index 00000000000..c4e81985fc8
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data.ts
@@ -0,0 +1,547 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {TSNE} from './bh_tsne';
+import {SpriteMetadata} from './data-provider';
+import * as knn from './knn';
+import * as logging from './logging';
+import * as scatterPlot from './scatterPlot';
+import * as util from './util';
+import * as vector from './vector';
+
+export type DistanceFunction = (a: number[], b: number[]) => number;
+export type ProjectionComponents3D = [string, string, string];
+
+export interface PointMetadata { [key: string]: number|string; }
+
+export interface DataProto {
+  shape: [number, number];
+  tensor: number[];
+  metadata: {
+    columns: Array<
+        {name: string; stringValues: string[]; numericValues: number[];}>;
+  };
+}
+
+/** Statistics for a metadata column. */
+export interface ColumnStats {
+  name: string;
+  isNumeric: boolean;
+  tooManyUniqueValues: boolean;
+  uniqueEntries?: Array<{label: string, count: number}>;
+  min: number;
+  max: number;
+}
+
+export interface SpriteAndMetadataInfo {
+  stats?: ColumnStats[];
+  pointsInfo?: PointMetadata[];
+  spriteImage?: HTMLImageElement;
+  spriteMetadata?: SpriteMetadata;
+}
+
+/** A single collection of points which make up a sequence through space. */
+export interface Sequence {
+  /** Indices into the DataPoints array in the Data object. */
+  pointIndices: number[];
+}
+
+export interface DataPoint {
+  /** The point in the original space. */
+  vector: Float32Array;
+
+  /*
+   * Metadata for each point. Each metadata is a set of key/value pairs
+   * where the value can be a string or a number.
+   */
+  metadata: PointMetadata;
+
+  /** index of the sequence, used for highlighting on click */
+  sequenceIndex?: number;
+
+  /** index in the original data source */
+  index: number;
+
+  /** This is where the calculated projections space are cached */
+  projections: {[key: string]: number};
+}
+
+const IS_FIREFOX = navigator.userAgent.toLowerCase().indexOf('firefox') >= 0;
+/** Controls whether nearest neighbors computation is done on the GPU or CPU. */
+const KNN_GPU_ENABLED = util.hasWebGLSupport() && !IS_FIREFOX;
+
+export const TSNE_SAMPLE_SIZE = 10000;
+export const PCA_SAMPLE_SIZE = 50000;
+/** Number of dimensions to sample when doing approximate PCA. */
+export const PCA_SAMPLE_DIM = 200;
+/** Number of pca components to compute. */
+const NUM_PCA_COMPONENTS = 10;
+/**
+ * Reserved metadata attributes used for sequence information
+ * NOTE: Use "__seq_next__" as "__next__" is deprecated.
+ */
+const SEQUENCE_METADATA_ATTRS = ['__next__', '__seq_next__'];
+
+function getSequenceNextPointIndex(pointMetadata: PointMetadata): number|null {
+  let sequenceAttr = null;
+  for (let metadataAttr of SEQUENCE_METADATA_ATTRS) {
+    if (metadataAttr in pointMetadata && pointMetadata[metadataAttr] !== '') {
+      sequenceAttr = pointMetadata[metadataAttr];
+      break;
+    }
+  }
+  if (sequenceAttr == null) {
+    return null;
+  }
+  return +sequenceAttr;
+}
+
+/**
+ * Dataset contains a DataPoints array that should be treated as immutable. This
+ * acts as a working subset of the original data, with cached properties
+ * from computationally expensive operations. Because creating a subset
+ * requires normalizing and shifting the vector space, we make a copy of the
+ * data so we can still always create new subsets based on the original data.
+ */
+export class DataSet {
+  points: DataPoint[];
+  sequences: Sequence[];
+
+  shuffledDataIndices: number[] = [];
+
+  /**
+   * This keeps a list of all current projections so you can easily test to see
+   * if it's been calculated already.
+   */
+  projections: {[projection: string]: boolean} = {};
+  nearest: knn.NearestEntry[][];
+  nearestK: number;
+  tSNEIteration: number = 0;
+  tSNEShouldStop = true;
+  dim: [number, number] = [0, 0];
+  hasTSNERun: boolean = false;
+  spriteAndMetadataInfo: SpriteAndMetadataInfo;
+  fracVariancesExplained: number[];
+
+  private tsne: TSNE;
+
+  /** Creates a new Dataset */
+  constructor(
+      points: DataPoint[], spriteAndMetadataInfo?: SpriteAndMetadataInfo) {
+    this.points = points;
+    this.shuffledDataIndices = util.shuffle(util.range(this.points.length));
+    this.sequences = this.computeSequences(points);
+    this.dim = [this.points.length, this.points[0].vector.length];
+    this.spriteAndMetadataInfo = spriteAndMetadataInfo;
+  }
+
+  private computeSequences(points: DataPoint[]) {
+    // Keep a list of indices seen so we don't compute sequences for a given
+    // point twice.
+    let indicesSeen = new Int8Array(points.length);
+    // Compute sequences.
+    let indexToSequence: {[index: number]: Sequence} = {};
+    let sequences: Sequence[] = [];
+    for (let i = 0; i < points.length; i++) {
+      if (indicesSeen[i]) {
+        continue;
+      }
+      indicesSeen[i] = 1;
+
+      // Ignore points without a sequence attribute.
+      let next = getSequenceNextPointIndex(points[i].metadata);
+      if (next == null) {
+        continue;
+      }
+      if (next in indexToSequence) {
+        let existingSequence = indexToSequence[next];
+        // Pushing at the beginning of the array.
+        existingSequence.pointIndices.unshift(i);
+        indexToSequence[i] = existingSequence;
+        continue;
+      }
+      // The current point is pointing to a new/unseen sequence.
+      let newSequence: Sequence = {pointIndices: []};
+      indexToSequence[i] = newSequence;
+      sequences.push(newSequence);
+      let currentIndex = i;
+      while (points[currentIndex]) {
+        newSequence.pointIndices.push(currentIndex);
+        let next = getSequenceNextPointIndex(points[currentIndex].metadata);
+        if (next != null) {
+          indicesSeen[next] = 1;
+          currentIndex = next;
+        } else {
+          currentIndex = -1;
+        }
+      }
+    }
+    return sequences;
+  }
+
+  projectionCanBeRendered(projection: ProjectionType): boolean {
+    if (projection !== 'tsne') {
+      return true;
+    }
+    return this.tSNEIteration > 0;
+  }
+
+  /**
+   * Returns a new subset dataset by copying out data. We make a copy because
+   * we have to modify the vectors by normalizing them.
+   *
+   * @param subset Array of indices of points that we want in the subset.
+   *
+   * @return A subset of the original dataset.
+   */
+  getSubset(subset?: number[]): DataSet {
+    const pointsSubset = ((subset != null) && (subset.length > 0)) ?
+        subset.map(i => this.points[i]) :
+        this.points;
+    let points = pointsSubset.map(dp => {
+      return {
+        metadata: dp.metadata,
+        index: dp.index,
+        vector: dp.vector.slice(),
+        projections: {} as {[key: string]: number}
+      };
+    });
+    return new DataSet(points, this.spriteAndMetadataInfo);
+  }
+
+  /**
+   * Computes the centroid, shifts all points to that centroid,
+   * then makes them all unit norm.
+   */
+  normalize() {
+    // Compute the centroid of all data points.
+    let centroid = vector.centroid(this.points, a => a.vector);
+    if (centroid == null) {
+      throw Error('centroid should not be null');
+    }
+    // Shift all points by the centroid and make them unit norm.
+    for (let id = 0; id < this.points.length; ++id) {
+      let dataPoint = this.points[id];
+      dataPoint.vector = vector.sub(dataPoint.vector, centroid);
+      vector.unit(dataPoint.vector);
+    }
+  }
+
+  /** Projects the dataset onto a given vector and caches the result. */
+  projectLinear(dir: vector.Vector, label: string) {
+    this.projections[label] = true;
+    this.points.forEach(dataPoint => {
+      dataPoint.projections[label] = vector.dot(dataPoint.vector, dir);
+    });
+  }
+
+  /** Projects the dataset along the top 10 principal components. */
+  projectPCA(): Promise<void> {
+    if (this.projections['pca-0'] != null) {
+      return Promise.resolve<void>(null);
+    }
+    return util.runAsyncTask('Computing PCA...', () => {
+      // Approximate pca vectors by sampling the dimensions.
+      let dim = this.points[0].vector.length;
+      let vectors = this.shuffledDataIndices.map(i => this.points[i].vector);
+      if (dim > PCA_SAMPLE_DIM) {
+        vectors = vector.projectRandom(vectors, PCA_SAMPLE_DIM);
+      }
+      let sampledVectors = vectors.slice(0, PCA_SAMPLE_SIZE);
+
+      let sigma = numeric.div(
+          numeric.dot(numeric.transpose(sampledVectors), sampledVectors),
+          sampledVectors.length);
+      let svd = numeric.svd(sigma);
+
+      let variances: number[] = svd.S;
+      let totalVariance = 0;
+      for (let i = 0; i < variances.length; ++i) {
+        totalVariance += variances[i];
+      }
+      for (let i = 0; i < variances.length; ++i) {
+        variances[i] /= totalVariance;
+      }
+      this.fracVariancesExplained = variances;
+
+      let U: number[][] = svd.U;
+      let pcaVectors = vectors.map(vector => {
+        let newV = new Float32Array(NUM_PCA_COMPONENTS);
+        for (let newDim = 0; newDim < NUM_PCA_COMPONENTS; newDim++) {
+          let dot = 0;
+          for (let oldDim = 0; oldDim < vector.length; oldDim++) {
+            dot += vector[oldDim] * U[oldDim][newDim];
+          }
+          newV[newDim] = dot;
+        }
+        return newV;
+      });
+      for (let d = 0; d < NUM_PCA_COMPONENTS; d++) {
+        let label = 'pca-' + d;
+        this.projections[label] = true;
+        for (let i = 0; i < pcaVectors.length; i++) {
+          let pointIndex = this.shuffledDataIndices[i];
+          this.points[pointIndex].projections[label] = pcaVectors[i][d];
+        }
+      }
+    });
+  }
+
+  /** Runs tsne on the data. */
+  projectTSNE(
+      perplexity: number, learningRate: number, tsneDim: number,
+      stepCallback: (iter: number) => void) {
+    this.hasTSNERun = true;
+    let k = Math.floor(3 * perplexity);
+    let opt = {epsilon: learningRate, perplexity: perplexity, dim: tsneDim};
+    this.tsne = new TSNE(opt);
+    this.tSNEShouldStop = false;
+    this.tSNEIteration = 0;
+
+    let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE);
+    let step = () => {
+      if (this.tSNEShouldStop) {
+        stepCallback(null);
+        this.tsne = null;
+        return;
+      }
+      this.tsne.step();
+      let result = this.tsne.getSolution();
+      sampledIndices.forEach((index, i) => {
+        let dataPoint = this.points[index];
+
+        dataPoint.projections['tsne-0'] = result[i * tsneDim + 0];
+        dataPoint.projections['tsne-1'] = result[i * tsneDim + 1];
+        if (tsneDim === 3) {
+          dataPoint.projections['tsne-2'] = result[i * tsneDim + 2];
+        }
+      });
+      this.tSNEIteration++;
+      stepCallback(this.tSNEIteration);
+      requestAnimationFrame(step);
+    };
+
+    // Nearest neighbors calculations.
+    let knnComputation: Promise<knn.NearestEntry[][]>;
+
+    if (this.nearest != null && k === this.nearestK) {
+      // We found the nearest neighbors before and will reuse them.
+      knnComputation = Promise.resolve(this.nearest);
+    } else {
+      let sampledData = sampledIndices.map(i => this.points[i]);
+      this.nearestK = k;
+      knnComputation = KNN_GPU_ENABLED ?
+          knn.findKNNGPUCosine(sampledData, k, (d => d.vector)) :
+          knn.findKNN(
+              sampledData, k, (d => d.vector),
+              (a, b, limit) => vector.cosDistNorm(a, b));
+    }
+    knnComputation.then(nearest => {
+      this.nearest = nearest;
+      util.runAsyncTask('Initializing T-SNE...', () => {
+            this.tsne.initDataDist(this.nearest);
+          }).then(step);
+    });
+  }
+
+  /**
+   * Merges metadata to the dataset and returns whether it succeeded.
+   */
+  mergeMetadata(metadata: SpriteAndMetadataInfo): boolean {
+    if (metadata.pointsInfo.length !== this.points.length) {
+      let errorMessage = `Number of tensors (${this.points.length}) do not` +
+          ` match the number of lines in metadata` +
+          ` (${metadata.pointsInfo.length}).`;
+
+      if (metadata.stats.length === 1 &&
+          this.points.length + 1 === metadata.pointsInfo.length) {
+        // If there is only one column of metadata and the number of points is
+        // exactly one less than the number of metadata lines, this is due to an
+        // unnecessary header line in the metadata and we can show a meaningful
+        // error.
+        logging.setErrorMessage(
+            errorMessage + ' Single column metadata should not have a header ' +
+                'row.',
+            'merging metadata');
+        return false;
+      } else if (
+          metadata.stats.length > 1 &&
+          this.points.length - 1 === metadata.pointsInfo.length) {
+        // If there are multiple columns of metadata and the number of points is
+        // exactly one greater than the number of lines in the metadata, this
+        // means there is a missing metadata header.
+        logging.setErrorMessage(
+            errorMessage + ' Multi-column metadata should have a header ' +
+                'row with column labels.',
+            'merging metadata');
+        return false;
+      }
+
+      logging.setWarningMessage(errorMessage);
+    }
+    this.spriteAndMetadataInfo = metadata;
+    metadata.pointsInfo.slice(0, this.points.length)
+        .forEach((m, i) => this.points[i].metadata = m);
+    return true;
+  }
+
+  stopTSNE() {
+    this.tSNEShouldStop = true;
+  }
+
+  /**
+   * Finds the nearest neighbors of the query point using a
+   * user-specified distance metric.
+   */
+  findNeighbors(pointIndex: number, distFunc: DistanceFunction, numNN: number):
+      knn.NearestEntry[] {
+    // Find the nearest neighbors of a particular point.
+    let neighbors = knn.findKNNofPoint(
+        this.points, pointIndex, numNN, (d => d.vector), distFunc);
+    // TODO(smilkov): Figure out why we slice.
+    let result = neighbors.slice(0, numNN);
+    return result;
+  }
+
+  /**
+   * Search the dataset based on a metadata field.
+   */
+  query(query: string, inRegexMode: boolean, fieldName: string): number[] {
+    let predicate = util.getSearchPredicate(query, inRegexMode, fieldName);
+    let matches: number[] = [];
+    this.points.forEach((point, id) => {
+      if (predicate(point)) {
+        matches.push(id);
+      }
+    });
+    return matches;
+  }
+}
+
+export type ProjectionType = 'tsne' | 'pca' | 'custom';
+
+export class Projection {
+  constructor(
+      public projectionType: ProjectionType,
+      public projectionComponents: ProjectionComponents3D,
+      public dimensionality: number, public dataSet: DataSet) {}
+}
+
+export interface ColorOption {
+  name: string;
+  desc?: string;
+  map?: (value: string|number) => string;
+  /** List of items for the color map. Defined only for categorical map. */
+  items?: {label: string, count: number}[];
+  /** Threshold values and their colors. Defined for gradient color map. */
+  thresholds?: {value: number, color: string}[];
+  isSeparator?: boolean;
+  tooManyUniqueValues?: boolean;
+}
+
+/**
+ * An interface that holds all the data for serializing the current state of
+ * the world.
+ */
+export class State {
+  /** A label identifying this state. */
+  label: string = '';
+
+  /** Whether this State is selected in the bookmarks pane. */
+  isSelected: boolean = false;
+
+  /** The selected projection tab. */
+  selectedProjection: ProjectionType;
+
+  /** Dimensions of the DataSet. */
+  dataSetDimensions: [number, number];
+
+  /** t-SNE parameters */
+  tSNEIteration: number = 0;
+  tSNEPerplexity: number = 0;
+  tSNELearningRate: number = 0;
+  tSNEis3d: boolean = true;
+
+  /** PCA projection component dimensions */
+  pcaComponentDimensions: number[] = [];
+
+  /** Custom projection parameters */
+  customSelectedSearchByMetadataOption: string;
+  customXLeftText: string;
+  customXLeftRegex: boolean;
+  customXRightText: string;
+  customXRightRegex: boolean;
+  customYUpText: string;
+  customYUpRegex: boolean;
+  customYDownText: string;
+  customYDownRegex: boolean;
+
+  /** The computed projections of the tensors. */
+  projections: Array<{[key: string]: number}> = [];
+
+  /** Filtered dataset indices. */
+  filteredPoints: number[];
+
+  /** The indices of selected points. */
+  selectedPoints: number[] = [];
+
+  /** Camera state (2d/3d, position, target, zoom, etc). */
+  cameraDef: scatterPlot.CameraDef;
+
+  /** Color by option. */
+  selectedColorOptionName: string;
+  forceCategoricalColoring: boolean;
+
+  /** Label by option. */
+  selectedLabelOption: string;
+}
+
+export function getProjectionComponents(
+    projection: ProjectionType,
+    components: (number|string)[]): ProjectionComponents3D {
+  if (components.length > 3) {
+    throw new RangeError('components length must be <= 3');
+  }
+  const projectionComponents: [string, string, string] = [null, null, null];
+  const prefix = (projection === 'custom') ? 'linear' : projection;
+  for (let i = 0; i < components.length; ++i) {
+    if (components[i] == null) {
+      continue;
+    }
+    projectionComponents[i] = prefix + '-' + components[i];
+  }
+  return projectionComponents;
+}
+
+export function stateGetAccessorDimensions(state: State): Array<number|string> {
+  let dimensions: Array<number|string>;
+  switch (state.selectedProjection) {
+    case 'pca':
+      dimensions = state.pcaComponentDimensions.slice();
+      break;
+    case 'tsne':
+      dimensions = [0, 1];
+      if (state.tSNEis3d) {
+        dimensions.push(2);
+      }
+      break;
+    case 'custom':
+      dimensions = ['x', 'y'];
+      break;
+    default:
+      throw new Error('Unexpected fallthrough');
+  }
+  return dimensions;
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts
new file mode 100644
index 00000000000..924ae3a929f
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/data_test.ts
@@ -0,0 +1,104 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataPoint, DataSet, State, stateGetAccessorDimensions} from './data';
+
+/**
+ * Helper method that makes a list of points given an array of
+ * sequence indexes.
+ *
+ * @param sequences The i-th entry holds the 'next' attribute for the i-th
+ * point.
+ */
+function makePointsWithSequences(
+    sequences: number[], nextAttr = '__seq_next__') {
+  let points: DataPoint[] = [];
+  sequences.forEach((t, i) => {
+    let metadata: {[key: string]: any} = {};
+    metadata[nextAttr] = t >= 0 ? t : null;
+    points.push({
+      vector: new Float32Array(0),
+      metadata: metadata,
+      projections: {},
+      index: i
+    });
+  });
+  return points;
+}
+
+describe('constructor_with_sequences', () => {
+  it('Simple forward pointing sequences, __seq_next__ metadata format', () => {
+    // The input is: 0->2, 1->None, 2->3, 3->None. This should return
+    // one sequence 0->2->3.
+    const points = makePointsWithSequences([2, -1, 3, -1]);
+    let dataset = new DataSet(points);
+    expect(dataset.sequences.length).toEqual(1);
+    expect(dataset.sequences[0].pointIndices).toEqual([0, 2, 3]);
+  });
+
+  it('Simple forward pointing sequences, __next__ metadata format', () => {
+    // The input is: 0->2, 1->None, 2->3, 3->None. This should return
+    // one sequence 0->2->3.
+    const points = makePointsWithSequences([2, -1, 3, -1], '__next__');
+    let dataset = new DataSet(points);
+    expect(dataset.sequences.length).toEqual(1);
+    expect(dataset.sequences[0].pointIndices).toEqual([0, 2, 3]);
+  });
+
+  it('No sequences', () => {
+    let points = makePointsWithSequences([-1, -1, -1, -1]);
+    let dataset = new DataSet(points);
+    expect(dataset.sequences.length).toEqual(0);
+  });
+
+  it('A sequence that goes backwards and forward in the array', () => {
+    // The input is: 0->2, 1->0, 2->nothing, 3->1. This should return
+    // one sequence 3->1->0->2.
+    let points = makePointsWithSequences([2, 0, -1, 1]);
+    let dataset = new DataSet(points);
+    expect(dataset.sequences.length).toEqual(1);
+    expect(dataset.sequences[0].pointIndices).toEqual([3, 1, 0, 2]);
+  });
+});
+
+describe('stateGetAccessorDimensions', () => {
+  it('returns [0, 1] for 2d t-SNE', () => {
+    const state = new State();
+    state.selectedProjection = 'tsne';
+    state.tSNEis3d = false;
+    expect(stateGetAccessorDimensions(state)).toEqual([0, 1]);
+  });
+
+  it('returns [0, 1, 2] for 3d t-SNE', () => {
+    const state = new State();
+    state.selectedProjection = 'tsne';
+    state.tSNEis3d = true;
+    expect(stateGetAccessorDimensions(state)).toEqual([0, 1, 2]);
+  });
+
+  it('returns pca component dimensions array for pca', () => {
+    const state = new State();
+    state.selectedProjection = 'pca';
+    state.pcaComponentDimensions = [13, 12, 11, 10];
+    expect(stateGetAccessorDimensions(state))
+        .toEqual(state.pcaComponentDimensions);
+  });
+
+  it('returns ["x", "y"] for custom projections', () => {
+    const state = new State();
+    state.selectedProjection = 'custom';
+    expect(stateGetAccessorDimensions(state)).toEqual(['x', 'y']);
+  });
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts
new file mode 100644
index 00000000000..cbc1512c215
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/external.d.ts
@@ -0,0 +1,51 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// TODO(smilkov): Split into weblas.d.ts and numeric.d.ts and write
+// typings for numeric.
+interface Tensor {
+  new(size: [number, number], data: Float32Array);
+  transfer(): Float32Array;
+  delete(): void;
+}
+
+interface Weblas {
+  sgemm(M: number, N: number, K: number, alpha: number,
+      A: Float32Array, B: Float32Array, beta: number, C: Float32Array):
+      Float32Array;
+  pipeline: {
+     Tensor: Tensor;
+     sgemm(alpha: number, A: Tensor, B: Tensor, beta: number,
+         C: Tensor): Tensor;
+  };
+  util: {
+    transpose(M: number, N: number, data: Float32Array): Tensor;
+  };
+
+}
+
+declare let numeric: any;
+declare let weblas: Weblas;
+
+interface AnalyticsEventType {
+  hitType: string;
+  page?: string;
+  eventCategory?: string;
+  eventAction?: string;
+  eventLabel?: string;
+  eventValue?: number;
+}
+
+declare let ga: (command: string, eventObj: AnalyticsEventType) => void;
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts
new file mode 100644
index 00000000000..ac3144e6493
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/heap.ts
@@ -0,0 +1,146 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** Min key heap. */
+export type HeapItem<T> = {
+  key: number,
+  value: T
+};
+
+/**
+ * Min-heap data structure. Provides O(1) for peek, returning the smallest key.
+ */
+// TODO(jart): Rename to Heap and use Comparator.
+export class MinHeap<T> {
+  private arr: HeapItem<T>[] = [];
+
+  /** Push an element with the provided key. */
+  push(key: number, value: T): void {
+    this.arr.push({key, value});
+    this.bubbleUp(this.arr.length - 1);
+  }
+
+  /** Pop the element with the smallest key. */
+  pop(): HeapItem<T> {
+    if (this.arr.length === 0) {
+      throw new Error('pop() called on empty binary heap');
+    }
+    let item = this.arr[0];
+    let last = this.arr.length - 1;
+    this.arr[0] = this.arr[last];
+    this.arr.pop();
+    if (last > 0) {
+      this.bubbleDown(0);
+    }
+    return item;
+  };
+
+  /** Returns, but doesn't remove the element with the smallest key */
+  peek(): HeapItem<T> { return this.arr[0]; }
+
+  /**
+   * Pops the element with the smallest key and at the same time
+   * adds the newly provided element. This is faster than calling
+   * pop() and push() separately.
+   */
+  popPush(key: number, value: T): HeapItem<T> {
+    if (this.arr.length === 0) {
+      throw new Error('pop() called on empty binary heap');
+    }
+    let item = this.arr[0];
+    this.arr[0] = {key, value};
+    if (this.arr.length > 0) {
+      this.bubbleDown(0);
+    }
+    return item;
+  }
+
+  /** Returns the number of elements in the heap. */
+  size(): number { return this.arr.length; }
+
+  /** Returns all the items in the heap. */
+  items(): HeapItem<T>[] { return this.arr; }
+
+  private swap(a: number, b: number) {
+    let temp = this.arr[a];
+    this.arr[a] = this.arr[b];
+    this.arr[b] = temp;
+  }
+
+  private bubbleDown(pos: number) {
+    let left = (pos << 1) + 1;
+    let right = left + 1;
+    let largest = pos;
+    if (left < this.arr.length && this.arr[left].key < this.arr[largest].key) {
+      largest = left;
+    }
+    if (right < this.arr.length &&
+        this.arr[right].key < this.arr[largest].key) {
+      largest = right;
+    }
+    if (largest !== pos) {
+      this.swap(largest, pos);
+      this.bubbleDown(largest);
+    }
+  }
+
+  private bubbleUp(pos: number) {
+    if (pos <= 0) {
+      return;
+    }
+    let parent = ((pos - 1) >> 1);
+    if (this.arr[pos].key < this.arr[parent].key) {
+      this.swap(pos, parent);
+      this.bubbleUp(parent);
+    }
+  }
+}
+
+/** List that keeps the K elements with the smallest keys. */
+export class KMin<T> {
+  private k: number;
+  private maxHeap = new MinHeap<T>();
+
+  /** Constructs a new k-min data structure with the provided k. */
+  constructor(k: number) { this.k = k; }
+
+  /** Adds an element to the list. */
+  add(key: number, value: T) {
+    if (this.maxHeap.size() < this.k) {
+      this.maxHeap.push(-key, value);
+      return;
+    }
+    let largest = this.maxHeap.peek();
+    // If the new element is smaller, replace the largest with the new element.
+    if (key < -largest.key) {
+      this.maxHeap.popPush(-key, value);
+    }
+  }
+
+  /** Returns the k items with the smallest keys. */
+  getMinKItems(): T[] {
+    let items = this.maxHeap.items();
+    items.sort((a, b) => b.key - a.key);
+    return items.map(a => a.value);
+  }
+
+  /** Returns the size of the list. */
+  getSize(): number { return this.maxHeap.size(); }
+
+  /** Returns the largest key in the list. */
+  getLargestKey(): number {
+    return this.maxHeap.size() === 0 ? null : -this.maxHeap.peek().key;
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts
new file mode 100644
index 00000000000..906e077b5d7
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/knn.ts
@@ -0,0 +1,235 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {runAsyncTask} from './util';
+import * as logging from './logging';
+import {KMin} from './heap';
+import {Vector} from './vector';
+import * as vector from './vector';
+
+export type NearestEntry = {
+  index: number,
+  dist: number
+};
+
+/**
+ * Optimal size for the height of the matrix when doing computation on the GPU
+ * using WebGL. This was found experimentally.
+ *
+ * This also guarantees that for computing pair-wise distance for up to 10K
+ * vectors, no more than 40MB will be allocated in the GPU. Without the
+ * allocation limit, we can freeze the graphics of the whole OS.
+ */
+const OPTIMAL_GPU_BLOCK_SIZE = 256;
+/** Id of message box used for knn gpu progress bar. */
+const KNN_GPU_MSG_ID = 'knn-gpu';
+
+/**
+ * Returns the K nearest neighbors for each vector where the distance
+ * computation is done on the GPU (WebGL) using cosine distance.
+ *
+ * @param dataPoints List of data points, where each data point holds an
+ *   n-dimensional vector.
+ * @param k Number of nearest neighbors to find.
+ * @param accessor A method that returns the vector, given the data point.
+ */
+export function findKNNGPUCosine<T>(
+    dataPoints: T[], k: number,
+    accessor: (dataPoint: T) => Float32Array): Promise<NearestEntry[][]> {
+  let N = dataPoints.length;
+  let dim = accessor(dataPoints[0]).length;
+
+  // The goal is to compute a large matrix multiplication A*A.T where A is of
+  // size NxD and A.T is its transpose. This results in a NxN matrix which
+  // could be too big to store on the GPU memory. To avoid memory overflow, we
+  // compute multiple A*partial_A.T where partial_A is of size BxD (B is much
+  // smaller than N). This results in storing only NxB size matrices on the GPU
+  // at a given time.
+
+  // A*A.T will give us NxN matrix holding the cosine distance between every
+  // pair of points, which we sort using KMin data structure to obtain the
+  // K nearest neighbors for each point.
+  let typedArray = vector.toTypedArray(dataPoints, accessor);
+  let bigMatrix = new weblas.pipeline.Tensor([N, dim], typedArray);
+  let nearest: NearestEntry[][] = new Array(N);
+  let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE);
+  let M = Math.floor(N / numPieces);
+  let modulo = N % numPieces;
+  let offset = 0;
+  let progress = 0;
+  let progressDiff = 1 / (2 * numPieces);
+  let piece = 0;
+
+  function step(resolve: (result: NearestEntry[][]) => void) {
+    let progressMsg =
+        'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%';
+    runAsyncTask(progressMsg, () => {
+      let B = piece < modulo ? M + 1 : M;
+      let typedB = new Float32Array(B * dim);
+      for (let i = 0; i < B; ++i) {
+        let vector = accessor(dataPoints[offset + i]);
+        for (let d = 0; d < dim; ++d) {
+          typedB[i * dim + d] = vector[d];
+        }
+      }
+      let partialMatrix = new weblas.pipeline.Tensor([B, dim], typedB);
+      // Result is N x B matrix.
+      let result =
+          weblas.pipeline.sgemm(1, bigMatrix, partialMatrix, null, null);
+      let partial = result.transfer();
+      partialMatrix.delete();
+      result.delete();
+      progress += progressDiff;
+      for (let i = 0; i < B; i++) {
+        let kMin = new KMin<NearestEntry>(k);
+        let iReal = offset + i;
+        for (let j = 0; j < N; j++) {
+          if (j === iReal) {
+            continue;
+          }
+          let cosDist = 1 - partial[j * B + i];  // [j, i];
+          kMin.add(cosDist, {index: j, dist: cosDist});
+        }
+        nearest[iReal] = kMin.getMinKItems();
+      }
+      progress += progressDiff;
+      offset += B;
+      piece++;
+    }, KNN_GPU_MSG_ID).then(() => {
+      if (piece < numPieces) {
+        step(resolve);
+      } else {
+        logging.setModalMessage(null, KNN_GPU_MSG_ID);
+        bigMatrix.delete();
+        resolve(nearest);
+      }
+    }, error => {
+      // GPU failed. Reverting back to CPU.
+      logging.setModalMessage(null, KNN_GPU_MSG_ID);
+      let distFunc = (a, b, limit) => vector.cosDistNorm(a, b);
+      findKNN(dataPoints, k, accessor, distFunc).then(nearest => {
+        resolve(nearest);
+      });
+    });
+  }
+  return new Promise<NearestEntry[][]>(resolve => step(resolve));
+}
+
+/**
+ * Returns the K nearest neighbors for each vector where the distance
+ * computation is done on the CPU using a user-specified distance method.
+ *
+ * @param dataPoints List of data points, where each data point holds an
+ *   n-dimensional vector.
+ * @param k Number of nearest neighbors to find.
+ * @param accessor A method that returns the vector, given the data point.
+ * @param dist Method that takes two vectors and a limit, and computes the
+ *   distance between two vectors, with the ability to stop early if the
+ *   distance is above the limit.
+ */
+export function findKNN<T>(
+    dataPoints: T[], k: number, accessor: (dataPoint: T) => Float32Array,
+    dist: (a: Vector, b: Vector, limit: number) =>
+        number): Promise<NearestEntry[][]> {
+  return runAsyncTask<NearestEntry[][]>('Finding nearest neighbors...', () => {
+    let N = dataPoints.length;
+    let nearest: NearestEntry[][] = new Array(N);
+    // Find the distances from node i.
+    let kMin: KMin<NearestEntry>[] = new Array(N);
+    for (let i = 0; i < N; i++) {
+      kMin[i] = new KMin<NearestEntry>(k);
+    }
+    for (let i = 0; i < N; i++) {
+      let a = accessor(dataPoints[i]);
+      let kMinA = kMin[i];
+      for (let j = i + 1; j < N; j++) {
+        let kMinB = kMin[j];
+        let limitI = kMinA.getSize() === k ?
+            kMinA.getLargestKey() || Number.MAX_VALUE :
+            Number.MAX_VALUE;
+        let limitJ = kMinB.getSize() === k ?
+            kMinB.getLargestKey() || Number.MAX_VALUE :
+            Number.MAX_VALUE;
+        let limit = Math.max(limitI, limitJ);
+        let dist2ItoJ = dist(a, accessor(dataPoints[j]), limit);
+        if (dist2ItoJ >= 0) {
+          kMinA.add(dist2ItoJ, {index: j, dist: dist2ItoJ});
+          kMinB.add(dist2ItoJ, {index: i, dist: dist2ItoJ});
+        }
+      }
+    }
+    for (let i = 0; i < N; i++) {
+      nearest[i] = kMin[i].getMinKItems();
+    }
+    return nearest;
+  });
+}
+
+/** Calculates the minimum distance between a search point and a rectangle. */
+function minDist(
+    point: [number, number], x1: number, y1: number, x2: number, y2: number) {
+  let x = point[0];
+  let y = point[1];
+  let dx1 = x - x1;
+  let dx2 = x - x2;
+  let dy1 = y - y1;
+  let dy2 = y - y2;
+
+  if (dx1 * dx2 <= 0) {    // x is between x1 and x2
+    if (dy1 * dy2 <= 0) {  // (x,y) is inside the rectangle
+      return 0;            // return 0 as point is in rect
+    }
+    return Math.min(Math.abs(dy1), Math.abs(dy2));
+  }
+  if (dy1 * dy2 <= 0) {  // y is between y1 and y2
+    // We know it is already inside the rectangle
+    return Math.min(Math.abs(dx1), Math.abs(dx2));
+  }
+  let corner: [number, number];
+  if (x > x2) {
+    // Upper-right vs lower-right.
+    corner = y > y2 ? [x2, y2] : [x2, y1];
+  } else {
+    // Upper-left vs lower-left.
+    corner = y > y2 ? [x1, y2] : [x1, y1];
+  }
+  return Math.sqrt(vector.dist22D([x, y], corner));
+}
+
+/**
+ * Returns the nearest neighbors of a particular point.
+ *
+ * @param dataPoints List of data points.
+ * @param pointIndex The index of the point we need the nearest neighbors of.
+ * @param k Number of nearest neighbors to search for.
+ * @param accessor Method that maps a data point => vector (array of numbers).
+ * @param distance Method that takes two vectors and returns their distance.
+ */
+export function findKNNofPoint<T>(
+    dataPoints: T[], pointIndex: number, k: number,
+    accessor: (dataPoint: T) => Float32Array,
+    distance: (a: Vector, b: Vector) => number) {
+  let kMin = new KMin<NearestEntry>(k);
+  let a = accessor(dataPoints[pointIndex]);
+  for (let i = 0; i < dataPoints.length; ++i) {
+    if (i === pointIndex) {
+      continue;
+    }
+    let b = accessor(dataPoints[i]);
+    let dist = distance(a, b);
+    kMin.add(dist, {index: i, dist: dist});
+  }
+  return kMin.getMinKItems();
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/label.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/label.ts
new file mode 100644
index 00000000000..67987f06ea3
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/label.ts
@@ -0,0 +1,151 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+export interface BoundingBox {
+  loX: number;
+  loY: number;
+  hiX: number;
+  hiY: number;
+}
+
+/**
+ * Accelerates label placement by dividing the view into a uniform grid.
+ * Labels only need to be tested for collision with other labels that overlap
+ * the same grid cells. This is a fork of {@code amoeba.CollisionGrid}.
+ */
+export class CollisionGrid {
+  private numHorizCells: number;
+  private numVertCells: number;
+  private grid: BoundingBox[][];
+  private bound: BoundingBox;
+  private cellWidth: number;
+  private cellHeight: number;
+
+  /**
+   * Constructs a new Collision grid.
+   *
+   * @param bound The bound of the grid. Labels out of bounds will be rejected.
+   * @param cellWidth Width of a cell in the grid.
+   * @param cellHeight Height of a cell in the grid.
+   */
+  constructor(bound: BoundingBox, cellWidth: number, cellHeight: number) {
+    /** The bound of the grid. Labels out of bounds will be rejected. */
+    this.bound = bound;
+
+    /** Width of a cell in the grid. */
+    this.cellWidth = cellWidth;
+
+    /** Height of a cell in the grid. */
+    this.cellHeight = cellHeight;
+
+    /** Number of grid cells along the x axis. */
+    this.numHorizCells = Math.ceil(this.boundWidth(bound) / cellWidth);
+
+    /** Number of grid cells along the y axis. */
+    this.numVertCells = Math.ceil(this.boundHeight(bound) / cellHeight);
+
+    /**
+     * The 2d grid (stored as a 1d array.) Each cell consists of an array of
+     * BoundingBoxes for objects that are in the cell.
+     */
+    this.grid = new Array(this.numHorizCells * this.numVertCells);
+  }
+
+  private boundWidth(bound: BoundingBox) { return bound.hiX - bound.loX; }
+
+  private boundHeight(bound: BoundingBox) { return bound.hiY - bound.loY; }
+
+  private boundsIntersect(a: BoundingBox, b: BoundingBox) {
+    return !(a.loX > b.hiX || a.loY > b.hiY || a.hiX < b.loX || a.hiY < b.loY);
+  }
+
+  /**
+   * Checks if a given bounding box has any conflicts in the grid and inserts it
+   * if none are found.
+   *
+   * @param bound The bound to insert.
+   * @param justTest If true, just test if it conflicts, without inserting.
+   * @return True if the bound was successfully inserted; false if it
+   *         could not be inserted due to a conflict.
+   */
+  insert(bound: BoundingBox, justTest = false): boolean {
+    // Reject if the label is out of bounds.
+    if ((bound.hiX < this.bound.loX) || (bound.loX > this.bound.hiX) ||
+        (bound.hiY < this.bound.loY) || (bound.loY > this.bound.hiY)) {
+      return false;
+    }
+
+    let minCellX = this.getCellX(bound.loX);
+    let maxCellX = this.getCellX(bound.hiX);
+    let minCellY = this.getCellY(bound.loY);
+    let maxCellY = this.getCellY(bound.hiY);
+
+    // Check all overlapped cells to verify that we can insert.
+    let baseIdx = minCellY * this.numHorizCells + minCellX;
+    let idx = baseIdx;
+    for (let j = minCellY; j <= maxCellY; j++) {
+      for (let i = minCellX; i <= maxCellX; i++) {
+        let cell = this.grid[idx++];
+        if (cell) {
+          for (let k = 0; k < cell.length; k++) {
+            if (this.boundsIntersect(bound, cell[k])) {
+              return false;
+            }
+          }
+        }
+      }
+      idx += this.numHorizCells - (maxCellX - minCellX + 1);
+    }
+
+    if (justTest) {
+      return true;
+    }
+
+    // Insert into the overlapped cells.
+    idx = baseIdx;
+    for (let j = minCellY; j <= maxCellY; j++) {
+      for (let i = minCellX; i <= maxCellX; i++) {
+        if (!this.grid[idx]) {
+          this.grid[idx] = [bound];
+        } else {
+          this.grid[idx].push(bound);
+        }
+        idx++;
+      }
+      idx += this.numHorizCells - (maxCellX - minCellX + 1);
+    }
+    return true;
+  }
+
+  /**
+   * Returns the x index of the grid cell where the given x coordinate falls.
+   *
+   * @param x the coordinate, in world space.
+   * @return the x index of the cell.
+   */
+  private getCellX(x: number) {
+    return Math.floor((x - this.bound.loX) / this.cellWidth);
+  };
+
+  /**
+   * Returns the y index of the grid cell where the given y coordinate falls.
+   *
+   * @param y the coordinate, in world space.
+   * @return the y index of the cell.
+   */
+  private getCellY(y: number) {
+    return Math.floor((y - this.bound.loY) / this.cellHeight);
+  };
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts
new file mode 100644
index 00000000000..59f37206012
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/logging.ts
@@ -0,0 +1,103 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** Duration in ms for showing warning messages to the user */
+const WARNING_DURATION_MS = 10000;
+
+let dom: HTMLElement = null;
+let msgId = 0;
+let numActiveMessages = 0;
+
+export function setDomContainer(domElement: HTMLElement) {
+  dom = domElement;
+}
+
+/**
+ * Updates the user message with the provided id.
+ *
+ * @param msg The message shown to the user. If null, the message is removed.
+ * @param id The id of an existing message. If no id is provided, a unique id
+ *     is assigned.
+ * @param title The title of the notification.
+ * @param isErrorMsg If true, the message is error and the dialog will have a
+ *                   close button.
+ * @return The id of the message.
+ */
+export function setModalMessage(
+    msg: string, id: string = null, title = null, isErrorMsg = false): string {
+  if (dom == null) {
+    console.warn('Can\'t show modal message before the dom is initialized');
+    return;
+  }
+  if (id == null) {
+    id = (msgId++).toString();
+  }
+  let dialog = dom.querySelector('#notification-dialog') as any;
+  dialog.querySelector('.close-button').style.display =
+      isErrorMsg ? null : 'none';
+  let spinner = dialog.querySelector('.progress');
+  spinner.style.display = isErrorMsg ? 'none' : null;
+  spinner.active = isErrorMsg ? null : true;
+  dialog.querySelector('#notification-title').innerHTML = title;
+  let msgsContainer = dialog.querySelector('#notify-msgs') as HTMLElement;
+  if (isErrorMsg) {
+    msgsContainer.innerHTML = '';
+  } else {
+    const errors = msgsContainer.querySelectorAll('.error');
+    for (let i = 0; i < errors.length; i++) {
+      msgsContainer.removeChild(errors[i]);
+    }
+  }
+  let divId = `notify-msg-${id}`;
+  let msgDiv = dialog.querySelector('#' + divId) as HTMLDivElement;
+  if (msgDiv == null) {
+    msgDiv = document.createElement('div');
+    msgDiv.className = 'notify-msg ' + (isErrorMsg ? 'error' : '');
+    msgDiv.id = divId;
+
+    msgsContainer.insertBefore(msgDiv, msgsContainer.firstChild);
+
+    if (!isErrorMsg) {
+      numActiveMessages++;
+    } else {
+      numActiveMessages = 0;
+    }
+  }
+  if (msg == null) {
+    numActiveMessages--;
+    if (numActiveMessages === 0) {
+      dialog.close();
+    }
+    msgDiv.remove();
+  } else {
+    msgDiv.innerText = msg;
+    dialog.open();
+  }
+  return id;
+}
+
+export function setErrorMessage(errMsg: string, task?: string) {
+  setModalMessage(errMsg, null, 'Error ' + (task != null ? task : ''), true);
+}
+
+/**
+ * Shows a warning message to the user for a certain amount of time.
+ */
+export function setWarningMessage(msg: string): void {
+  let toast = dom.querySelector('#toast') as any;
+  toast.text = msg;
+  toast.duration = WARNING_DURATION_MS;
+  toast.open();
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts
new file mode 100644
index 00000000000..36f5c4c5841
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorEventContext.ts
@@ -0,0 +1,45 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DistanceFunction, Projection} from './data';
+import {NearestEntry} from './knn';
+
+export type HoverListener = (index: number) => void;
+export type SelectionChangedListener =
+    (selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) =>
+        void;
+export type ProjectionChangedListener = (projection: Projection) => void;
+export type DistanceMetricChangedListener =
+    (distanceMetric: DistanceFunction) => void;
+export interface ProjectorEventContext {
+  /** Register a callback to be invoked when the mouse hovers over a point. */
+  registerHoverListener(listener: HoverListener);
+  /** Notify the hover system that a point is under the mouse. */
+  notifyHoverOverPoint(pointIndex: number);
+  /** Registers a callback to be invoked when the selection changes. */
+  registerSelectionChangedListener(listener: SelectionChangedListener);
+  /**
+   * Notify the selection system that a client has changed the selected point
+   * set.
+   */
+  notifySelectionChanged(newSelectedPointIndices: number[]);
+  /** Registers a callback to be invoked when the projection changes. */
+  registerProjectionChangedListener(listener: ProjectionChangedListener);
+  /** Notify listeners that a reprojection occurred. */
+  notifyProjectionChanged(projection: Projection);
+  registerDistanceMetricChangedListener(listener:
+                                            DistanceMetricChangedListener);
+  notifyDistanceMetricChanged(distMetric: DistanceFunction);
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts
new file mode 100644
index 00000000000..bb09e2b153a
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/projectorScatterPlotAdapter.ts
@@ -0,0 +1,713 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import * as d3 from 'd3';  // from //third_party/javascript/typings/d3_v4
+
+import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data';
+import {NearestEntry} from './knn';
+import {ProjectorEventContext} from './projectorEventContext';
+import {LabelRenderParams} from './renderContext';
+import {ScatterPlot} from './scatterPlot';
+import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels';
+import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels';
+import {ScatterPlotVisualizerPolylines} from './scatterPlotVisualizerPolylines';
+import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites';
+import * as vector from './vector';
+
+const LABEL_FONT_SIZE = 10;
+const LABEL_SCALE_DEFAULT = 1.0;
+const LABEL_SCALE_LARGE = 2;
+const LABEL_FILL_COLOR_SELECTED = 0x000000;
+const LABEL_FILL_COLOR_HOVER = 0x000000;
+const LABEL_FILL_COLOR_NEIGHBOR = 0x000000;
+const LABEL_STROKE_COLOR_SELECTED = 0xFFFFFF;
+const LABEL_STROKE_COLOR_HOVER = 0xFFFFFF;
+const LABEL_STROKE_COLOR_NEIGHBOR = 0xFFFFFF;
+
+const POINT_COLOR_UNSELECTED = 0xE3E3E3;
+const POINT_COLOR_NO_SELECTION = 0x7575D9;
+const POINT_COLOR_SELECTED = 0xFA6666;
+const POINT_COLOR_HOVER = 0x760B4F;
+
+const POINT_SCALE_DEFAULT = 1.0;
+const POINT_SCALE_SELECTED = 1.2;
+const POINT_SCALE_NEIGHBOR = 1.2;
+const POINT_SCALE_HOVER = 1.2;
+
+const LABELS_3D_COLOR_UNSELECTED = 0xFFFFFF;
+const LABELS_3D_COLOR_NO_SELECTION = 0xFFFFFF;
+
+const SPRITE_IMAGE_COLOR_UNSELECTED = 0xFFFFFF;
+const SPRITE_IMAGE_COLOR_NO_SELECTION = 0xFFFFFF;
+
+const POLYLINE_START_HUE = 60;
+const POLYLINE_END_HUE = 360;
+const POLYLINE_SATURATION = 1;
+const POLYLINE_LIGHTNESS = .3;
+
+const POLYLINE_DEFAULT_OPACITY = .2;
+const POLYLINE_DEFAULT_LINEWIDTH = 2;
+const POLYLINE_SELECTED_OPACITY = .9;
+const POLYLINE_SELECTED_LINEWIDTH = 3;
+const POLYLINE_DESELECTED_OPACITY = .05;
+
+const SCATTER_PLOT_CUBE_LENGTH = 2;
+
+/** Color scale for nearest neighbors. */
+const NN_COLOR_SCALE =
+    d3.scaleLinear<string, string>()
+        .domain([1, 0.7, 0.4])
+        .range(['hsl(285, 80%, 40%)', 'hsl(0, 80%, 65%)', 'hsl(40, 70%, 60%)'])
+        .clamp(true);
+
+/**
+ * Interprets projector events and assembes the arrays and commands necessary
+ * to use the ScatterPlot to render the current projected data set.
+ */
+export class ProjectorScatterPlotAdapter {
+  public scatterPlot: ScatterPlot;
+  private projection: Projection;
+  private hoverPointIndex: number;
+  private selectedPointIndices: number[];
+  private neighborsOfFirstSelectedPoint: NearestEntry[];
+  private renderLabelsIn3D: boolean = false;
+  private labelPointAccessor: string;
+  private legendPointColorer: (ds: DataSet, index: number) => string;
+  private distanceMetric: DistanceFunction;
+
+  private spriteVisualizer: ScatterPlotVisualizerSprites;
+  private labels3DVisualizer: ScatterPlotVisualizer3DLabels;
+  private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels;
+  private polylineVisualizer: ScatterPlotVisualizerPolylines;
+
+  constructor(
+      private scatterPlotContainer: HTMLElement,
+      projectorEventContext: ProjectorEventContext) {
+    this.scatterPlot =
+        new ScatterPlot(scatterPlotContainer, projectorEventContext);
+    projectorEventContext.registerProjectionChangedListener(projection => {
+      this.projection = projection;
+      this.updateScatterPlotWithNewProjection(projection);
+    });
+    projectorEventContext.registerSelectionChangedListener(
+        (selectedPointIndices, neighbors) => {
+          this.selectedPointIndices = selectedPointIndices;
+          this.neighborsOfFirstSelectedPoint = neighbors;
+          this.updateScatterPlotPositions();
+          this.updateScatterPlotAttributes();
+          this.scatterPlot.render();
+        });
+    projectorEventContext.registerHoverListener(hoverPointIndex => {
+      this.hoverPointIndex = hoverPointIndex;
+      this.updateScatterPlotAttributes();
+      this.scatterPlot.render();
+    });
+    projectorEventContext.registerDistanceMetricChangedListener(
+        distanceMetric => {
+          this.distanceMetric = distanceMetric;
+          this.updateScatterPlotAttributes();
+          this.scatterPlot.render();
+        });
+    this.createVisualizers(false);
+  }
+
+  notifyProjectionPositionsUpdated() {
+    this.updateScatterPlotPositions();
+    this.scatterPlot.render();
+  }
+
+  setDataSet(dataSet: DataSet) {
+    if (this.projection != null) {
+      // TODO(nicholsonc): setDataSet needs to go away, the projection is the
+      // atomic unit of update.
+      this.projection.dataSet = dataSet;
+    }
+    if (this.polylineVisualizer != null) {
+      this.polylineVisualizer.setDataSet(dataSet);
+    }
+    if (this.labels3DVisualizer != null) {
+      this.labels3DVisualizer.setLabelStrings(
+          this.generate3DLabelsArray(dataSet, this.labelPointAccessor));
+    }
+    if (this.spriteVisualizer == null) {
+      return;
+    }
+    this.spriteVisualizer.clearSpriteAtlas();
+    if ((dataSet == null) || (dataSet.spriteAndMetadataInfo == null)) {
+      return;
+    }
+    const metadata = dataSet.spriteAndMetadataInfo;
+    if ((metadata.spriteImage == null) || (metadata.spriteMetadata == null)) {
+      return;
+    }
+    const n = dataSet.points.length;
+    const spriteIndices = new Float32Array(n);
+    for (let i = 0; i < n; ++i) {
+      spriteIndices[i] = dataSet.points[i].index;
+    }
+    this.spriteVisualizer.setSpriteAtlas(
+        metadata.spriteImage, metadata.spriteMetadata.singleImageDim,
+        spriteIndices);
+  }
+
+  set3DLabelMode(renderLabelsIn3D: boolean) {
+    this.renderLabelsIn3D = renderLabelsIn3D;
+    this.createVisualizers(renderLabelsIn3D);
+    this.updateScatterPlotAttributes();
+    this.scatterPlot.render();
+  }
+
+  setLegendPointColorer(
+      legendPointColorer: (ds: DataSet, index: number) => string) {
+    this.legendPointColorer = legendPointColorer;
+  }
+
+  setLabelPointAccessor(labelPointAccessor: string) {
+    this.labelPointAccessor = labelPointAccessor;
+    if (this.labels3DVisualizer != null) {
+      const ds = (this.projection == null) ? null : this.projection.dataSet;
+      this.labels3DVisualizer.setLabelStrings(
+          this.generate3DLabelsArray(ds, labelPointAccessor));
+    }
+  }
+
+  resize() {
+    this.scatterPlot.resize();
+  }
+
+  populateBookmarkFromUI(state: State) {
+    state.cameraDef = this.scatterPlot.getCameraDef();
+  }
+
+  restoreUIFromBookmark(state: State) {
+    this.scatterPlot.setCameraParametersForNextCameraCreation(
+        state.cameraDef, false);
+  }
+
+  updateScatterPlotPositions() {
+    const ds = (this.projection == null) ? null : this.projection.dataSet;
+    const projectionComponents =
+        (this.projection == null) ? null : this.projection.projectionComponents;
+    const newPositions =
+        this.generatePointPositionArray(ds, projectionComponents);
+    this.scatterPlot.setPointPositions(newPositions);
+  }
+
+  updateScatterPlotAttributes() {
+    if (this.projection == null) {
+      return;
+    }
+    const dataSet = this.projection.dataSet;
+    const selectedSet = this.selectedPointIndices;
+    const hoverIndex = this.hoverPointIndex;
+    const neighbors = this.neighborsOfFirstSelectedPoint;
+    const pointColorer = this.legendPointColorer;
+
+    const pointColors = this.generatePointColorArray(
+        dataSet, pointColorer, this.distanceMetric, selectedSet, neighbors,
+        hoverIndex, this.renderLabelsIn3D, this.getSpriteImageMode());
+    const pointScaleFactors = this.generatePointScaleFactorArray(
+        dataSet, selectedSet, neighbors, hoverIndex);
+    const labels = this.generateVisibleLabelRenderParams(
+        dataSet, selectedSet, neighbors, hoverIndex);
+    const polylineColors =
+        this.generateLineSegmentColorMap(dataSet, pointColorer);
+    const polylineOpacities =
+        this.generateLineSegmentOpacityArray(dataSet, selectedSet);
+    const polylineWidths =
+        this.generateLineSegmentWidthArray(dataSet, selectedSet);
+
+    this.scatterPlot.setPointColors(pointColors);
+    this.scatterPlot.setPointScaleFactors(pointScaleFactors);
+    this.scatterPlot.setLabels(labels);
+    this.scatterPlot.setPolylineColors(polylineColors);
+    this.scatterPlot.setPolylineOpacities(polylineOpacities);
+    this.scatterPlot.setPolylineWidths(polylineWidths);
+  }
+
+  render() {
+    this.scatterPlot.render();
+  }
+
+  generatePointPositionArray(
+      ds: DataSet, projectionComponents: ProjectionComponents3D): Float32Array {
+    if (ds == null) {
+      return null;
+    }
+
+    const xScaler = d3.scaleLinear();
+    const yScaler = d3.scaleLinear();
+    let zScaler = null;
+    {
+      // Determine max and min of each axis of our data.
+      const xExtent = d3.extent(
+          ds.points,
+          (p, i) => ds.points[i].projections[projectionComponents[0]]);
+      const yExtent = d3.extent(
+          ds.points,
+          (p, i) => ds.points[i].projections[projectionComponents[1]]);
+
+      const range =
+          [-SCATTER_PLOT_CUBE_LENGTH / 2, SCATTER_PLOT_CUBE_LENGTH / 2];
+
+      xScaler.domain(xExtent).range(range);
+      yScaler.domain(yExtent).range(range);
+
+      if (projectionComponents[2] != null) {
+        const zExtent = d3.extent(
+            ds.points,
+            (p, i) => ds.points[i].projections[projectionComponents[2]]);
+        zScaler = d3.scaleLinear();
+        zScaler.domain(zExtent).range(range);
+      }
+    }
+
+    const positions = new Float32Array(ds.points.length * 3);
+    let dst = 0;
+
+    ds.points.forEach((d, i) => {
+      positions[dst++] =
+          xScaler(ds.points[i].projections[projectionComponents[0]]);
+      positions[dst++] =
+          yScaler(ds.points[i].projections[projectionComponents[1]]);
+      positions[dst++] = 0.0;
+    });
+
+    if (zScaler) {
+      dst = 2;
+      ds.points.forEach((d, i) => {
+        positions[dst] =
+            zScaler(ds.points[i].projections[projectionComponents[2]]);
+        dst += 3;
+      });
+    }
+
+    return positions;
+  }
+
+  generateVisibleLabelRenderParams(
+      ds: DataSet, selectedPointIndices: number[],
+      neighborsOfFirstPoint: NearestEntry[],
+      hoverPointIndex: number): LabelRenderParams {
+    if (ds == null) {
+      return null;
+    }
+
+    const selectedPointCount =
+        (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+    const neighborCount =
+        (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
+    const n = selectedPointCount + neighborCount +
+        ((hoverPointIndex != null) ? 1 : 0);
+
+    const visibleLabels = new Uint32Array(n);
+    const scale = new Float32Array(n);
+    const opacityFlags = new Int8Array(n);
+    const fillColors = new Uint8Array(n * 3);
+    const strokeColors = new Uint8Array(n * 3);
+    const labelStrings: string[] = [];
+
+    scale.fill(LABEL_SCALE_DEFAULT);
+    opacityFlags.fill(1);
+
+    let dst = 0;
+
+    if (hoverPointIndex != null) {
+      labelStrings.push(
+          this.getLabelText(ds, hoverPointIndex, this.labelPointAccessor));
+      visibleLabels[dst] = hoverPointIndex;
+      scale[dst] = LABEL_SCALE_LARGE;
+      opacityFlags[dst] = 0;
+      const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER);
+      packRgbIntoUint8Array(
+          fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
+      const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER);
+      packRgbIntoUint8Array(
+          strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[1]);
+      ++dst;
+    }
+
+    // Selected points
+    {
+      const n = selectedPointCount;
+      const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED);
+      const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED);
+      for (let i = 0; i < n; ++i) {
+        const labelIndex = selectedPointIndices[i];
+        labelStrings.push(
+            this.getLabelText(ds, labelIndex, this.labelPointAccessor));
+        visibleLabels[dst] = labelIndex;
+        scale[dst] = LABEL_SCALE_LARGE;
+        opacityFlags[dst] = (n === 1) ? 0 : 1;
+        packRgbIntoUint8Array(
+            fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
+        packRgbIntoUint8Array(
+            strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
+        ++dst;
+      }
+    }
+
+    // Neighbors
+    {
+      const n = neighborCount;
+      const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR);
+      const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR);
+      for (let i = 0; i < n; ++i) {
+        const labelIndex = neighborsOfFirstPoint[i].index;
+        labelStrings.push(
+            this.getLabelText(ds, labelIndex, this.labelPointAccessor));
+        visibleLabels[dst] = labelIndex;
+        packRgbIntoUint8Array(
+            fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
+        packRgbIntoUint8Array(
+            strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
+        ++dst;
+      }
+    }
+
+    return new LabelRenderParams(
+        visibleLabels, labelStrings, scale, opacityFlags, LABEL_FONT_SIZE,
+        fillColors, strokeColors);
+  }
+
+  generatePointScaleFactorArray(
+      ds: DataSet, selectedPointIndices: number[],
+      neighborsOfFirstPoint: NearestEntry[],
+      hoverPointIndex: number): Float32Array {
+    if (ds == null) {
+      return new Float32Array(0);
+    }
+
+    const scale = new Float32Array(ds.points.length);
+    scale.fill(POINT_SCALE_DEFAULT);
+
+    const selectedPointCount =
+        (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+    const neighborCount =
+        (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
+
+    // Scale up all selected points.
+    {
+      const n = selectedPointCount;
+      for (let i = 0; i < n; ++i) {
+        const p = selectedPointIndices[i];
+        scale[p] = POINT_SCALE_SELECTED;
+      }
+    }
+
+    // Scale up the neighbor points.
+    {
+      const n = neighborCount;
+      for (let i = 0; i < n; ++i) {
+        const p = neighborsOfFirstPoint[i].index;
+        scale[p] = POINT_SCALE_NEIGHBOR;
+      }
+    }
+
+    // Scale up the hover point.
+    if (hoverPointIndex != null) {
+      scale[hoverPointIndex] = POINT_SCALE_HOVER;
+    }
+
+    return scale;
+  }
+
+  generateLineSegmentColorMap(
+      ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string):
+      {[polylineIndex: number]: Float32Array} {
+    let polylineColorArrayMap: {[polylineIndex: number]: Float32Array} = {};
+    if (ds == null) {
+      return polylineColorArrayMap;
+    }
+
+    for (let i = 0; i < ds.sequences.length; i++) {
+      let sequence = ds.sequences[i];
+      let colors = new Float32Array(2 * (sequence.pointIndices.length - 1) * 3);
+      let colorIndex = 0;
+
+      if (legendPointColorer) {
+        for (let j = 0; j < sequence.pointIndices.length - 1; j++) {
+          const c1 =
+              new THREE.Color(legendPointColorer(ds, sequence.pointIndices[j]));
+          const c2 = new THREE.Color(
+              legendPointColorer(ds, sequence.pointIndices[j + 1]));
+          colors[colorIndex++] = c1.r;
+          colors[colorIndex++] = c1.g;
+          colors[colorIndex++] = c1.b;
+          colors[colorIndex++] = c2.r;
+          colors[colorIndex++] = c2.g;
+          colors[colorIndex++] = c2.b;
+        }
+      } else {
+        for (let j = 0; j < sequence.pointIndices.length - 1; j++) {
+          const c1 =
+              getDefaultPointInPolylineColor(j, sequence.pointIndices.length);
+          const c2 = getDefaultPointInPolylineColor(
+              j + 1, sequence.pointIndices.length);
+          colors[colorIndex++] = c1.r;
+          colors[colorIndex++] = c1.g;
+          colors[colorIndex++] = c1.b;
+          colors[colorIndex++] = c2.r;
+          colors[colorIndex++] = c2.g;
+          colors[colorIndex++] = c2.b;
+        }
+      }
+
+      polylineColorArrayMap[i] = colors;
+    }
+
+    return polylineColorArrayMap;
+  }
+
+  generateLineSegmentOpacityArray(ds: DataSet, selectedPoints: number[]):
+      Float32Array {
+    if (ds == null) {
+      return new Float32Array(0);
+    }
+    const opacities = new Float32Array(ds.sequences.length);
+    const selectedPointCount =
+        (selectedPoints == null) ? 0 : selectedPoints.length;
+    if (selectedPointCount > 0) {
+      opacities.fill(POLYLINE_DESELECTED_OPACITY);
+      const i = ds.points[selectedPoints[0]].sequenceIndex;
+      opacities[i] = POLYLINE_SELECTED_OPACITY;
+    } else {
+      opacities.fill(POLYLINE_DEFAULT_OPACITY);
+    }
+    return opacities;
+  }
+
+  generateLineSegmentWidthArray(ds: DataSet, selectedPoints: number[]):
+      Float32Array {
+    if (ds == null) {
+      return new Float32Array(0);
+    }
+    const widths = new Float32Array(ds.sequences.length);
+    widths.fill(POLYLINE_DEFAULT_LINEWIDTH);
+    const selectedPointCount =
+        (selectedPoints == null) ? 0 : selectedPoints.length;
+    if (selectedPointCount > 0) {
+      const i = ds.points[selectedPoints[0]].sequenceIndex;
+      widths[i] = POLYLINE_SELECTED_LINEWIDTH;
+    }
+    return widths;
+  }
+
+  generatePointColorArray(
+      ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string,
+      distFunc: DistanceFunction, selectedPointIndices: number[],
+      neighborsOfFirstPoint: NearestEntry[], hoverPointIndex: number,
+      label3dMode: boolean, spriteImageMode: boolean): Float32Array {
+    if (ds == null) {
+      return new Float32Array(0);
+    }
+
+    const selectedPointCount =
+        (selectedPointIndices == null) ? 0 : selectedPointIndices.length;
+    const neighborCount =
+        (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length;
+    const colors = new Float32Array(ds.points.length * 3);
+
+    let unselectedColor = POINT_COLOR_UNSELECTED;
+    let noSelectionColor = POINT_COLOR_NO_SELECTION;
+
+    if (label3dMode) {
+      unselectedColor = LABELS_3D_COLOR_UNSELECTED;
+      noSelectionColor = LABELS_3D_COLOR_NO_SELECTION;
+    }
+
+    if (spriteImageMode) {
+      unselectedColor = SPRITE_IMAGE_COLOR_UNSELECTED;
+      noSelectionColor = SPRITE_IMAGE_COLOR_NO_SELECTION;
+    }
+
+    // Give all points the unselected color.
+    {
+      const n = ds.points.length;
+      let dst = 0;
+      if (selectedPointCount > 0) {
+        const c = new THREE.Color(unselectedColor);
+        for (let i = 0; i < n; ++i) {
+          colors[dst++] = c.r;
+          colors[dst++] = c.g;
+          colors[dst++] = c.b;
+        }
+      } else {
+        if (legendPointColorer != null) {
+          for (let i = 0; i < n; ++i) {
+            const c = new THREE.Color(legendPointColorer(ds, i));
+            colors[dst++] = c.r;
+            colors[dst++] = c.g;
+            colors[dst++] = c.b;
+          }
+        } else {
+          const c = new THREE.Color(noSelectionColor);
+          for (let i = 0; i < n; ++i) {
+            colors[dst++] = c.r;
+            colors[dst++] = c.g;
+            colors[dst++] = c.b;
+          }
+        }
+      }
+    }
+
+    // Color the selected points.
+    {
+      const n = selectedPointCount;
+      const c = new THREE.Color(POINT_COLOR_SELECTED);
+      for (let i = 0; i < n; ++i) {
+        let dst = selectedPointIndices[i] * 3;
+        colors[dst++] = c.r;
+        colors[dst++] = c.g;
+        colors[dst++] = c.b;
+      }
+    }
+
+    // Color the neighbors.
+    {
+      const n = neighborCount;
+      let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0;
+      for (let i = 0; i < n; ++i) {
+        const c = new THREE.Color(
+            dist2color(distFunc, neighborsOfFirstPoint[i].dist, minDist));
+        let dst = neighborsOfFirstPoint[i].index * 3;
+        colors[dst++] = c.r;
+        colors[dst++] = c.g;
+        colors[dst++] = c.b;
+      }
+    }
+
+    // Color the hover point.
+    if (hoverPointIndex != null) {
+      const c = new THREE.Color(POINT_COLOR_HOVER);
+      let dst = hoverPointIndex * 3;
+      colors[dst++] = c.r;
+      colors[dst++] = c.g;
+      colors[dst++] = c.b;
+    }
+
+    return colors;
+  }
+
+  generate3DLabelsArray(ds: DataSet, accessor: string) {
+    if ((ds == null) || (accessor == null)) {
+      return null;
+    }
+    let labels: string[] = [];
+    const n = ds.points.length;
+    for (let i = 0; i < n; ++i) {
+      labels.push(this.getLabelText(ds, i, accessor));
+    }
+    return labels;
+  }
+
+  private getLabelText(ds: DataSet, i: number, accessor: string) {
+    return ds.points[i].metadata[accessor].toString();
+  }
+
+  private updateScatterPlotWithNewProjection(projection: Projection) {
+    if (projection == null) {
+      this.createVisualizers(this.renderLabelsIn3D);
+      this.scatterPlot.render();
+      return;
+    }
+    this.setDataSet(projection.dataSet);
+    this.scatterPlot.setDimensions(projection.dimensionality);
+    if (projection.dataSet.projectionCanBeRendered(projection.projectionType)) {
+      this.updateScatterPlotAttributes();
+      this.notifyProjectionPositionsUpdated();
+    }
+    this.scatterPlot.setCameraParametersForNextCameraCreation(null, false);
+  }
+
+  private createVisualizers(inLabels3DMode: boolean) {
+    const ds = (this.projection == null) ? null : this.projection.dataSet;
+    const scatterPlot = this.scatterPlot;
+    scatterPlot.removeAllVisualizers();
+    this.labels3DVisualizer = null;
+    this.canvasLabelsVisualizer = null;
+    this.spriteVisualizer = null;
+    this.polylineVisualizer = null;
+    if (inLabels3DMode) {
+      this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels();
+      this.labels3DVisualizer.setLabelStrings(
+          this.generate3DLabelsArray(ds, this.labelPointAccessor));
+    } else {
+      this.spriteVisualizer = new ScatterPlotVisualizerSprites();
+      scatterPlot.addVisualizer(this.spriteVisualizer);
+      this.canvasLabelsVisualizer =
+          new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer);
+    }
+    this.polylineVisualizer = new ScatterPlotVisualizerPolylines();
+    this.setDataSet(ds);
+    if (this.spriteVisualizer) {
+      scatterPlot.addVisualizer(this.spriteVisualizer);
+    }
+    if (this.labels3DVisualizer) {
+      scatterPlot.addVisualizer(this.labels3DVisualizer);
+    }
+    if (this.canvasLabelsVisualizer) {
+      scatterPlot.addVisualizer(this.canvasLabelsVisualizer);
+    }
+    scatterPlot.addVisualizer(this.polylineVisualizer);
+  }
+
+  private getSpriteImageMode(): boolean {
+    if (this.projection == null) {
+      return false;
+    }
+    const ds = this.projection.dataSet;
+    if ((ds == null) || (ds.spriteAndMetadataInfo == null)) {
+      return false;
+    }
+    return ds.spriteAndMetadataInfo.spriteImage != null;
+  }
+}
+
+function packRgbIntoUint8Array(
+    rgbArray: Uint8Array, labelIndex: number, r: number, g: number, b: number) {
+  rgbArray[labelIndex * 3] = r;
+  rgbArray[labelIndex * 3 + 1] = g;
+  rgbArray[labelIndex * 3 + 2] = b;
+}
+
+function styleRgbFromHexColor(hex: number): [number, number, number] {
+  const c = new THREE.Color(hex);
+  return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0];
+}
+
+function getDefaultPointInPolylineColor(
+    index: number, totalPoints: number): THREE.Color {
+  let hue = POLYLINE_START_HUE +
+      (POLYLINE_END_HUE - POLYLINE_START_HUE) * index / totalPoints;
+
+  let rgb = d3.hsl(hue, POLYLINE_SATURATION, POLYLINE_LIGHTNESS).rgb();
+  return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255);
+}
+
+/**
+ * Normalizes the distance so it can be visually encoded with color.
+ * The normalization depends on the distance metric (cosine vs euclidean).
+ */
+export function normalizeDist(
+    distFunc: DistanceFunction, d: number, minDist: number): number {
+  return (distFunc === vector.dist) ? (minDist / d) : (1 - d);
+}
+
+/** Normalizes and encodes the provided distance with color. */
+export function dist2color(
+    distFunc: DistanceFunction, d: number, minDist: number): string {
+  return NN_COLOR_SCALE(normalizeDist(distFunc, d, minDist));
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts
new file mode 100644
index 00000000000..8d5232a8048
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/renderContext.ts
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http:www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/**
+ * LabelRenderParams describes the set of points that should have labels
+ * rendered next to them.
+ */
+export class LabelRenderParams {
+  constructor(
+      public pointIndices: Float32Array, public labelStrings: string[],
+      public scaleFactors: Float32Array, public useSceneOpacityFlags: Int8Array,
+      public defaultFontSize: number, public fillColors: Uint8Array,
+      public strokeColors: Uint8Array) {}
+}
+
+/** Details about the camera projection being used to render the scene. */
+export enum CameraType {
+  Perspective,
+  Orthographic
+}
+
+/**
+ * RenderContext contains all of the state required to color and render the data
+ * set. ScatterPlot passes this to every attached visualizer as part of the
+ * render callback.
+ * TODO(nicholsonc): This should only contain the data that's changed between
+ * each frame. Data like colors / scale factors / labels should be reapplied
+ * only when they change.
+ */
+export class RenderContext {
+  constructor(
+      public camera: THREE.Camera, public cameraType: CameraType,
+      public cameraTarget: THREE.Vector3, public screenWidth: number,
+      public screenHeight: number, public nearestCameraSpacePointZ: number,
+      public farthestCameraSpacePointZ: number, public backgroundColor: number,
+      public pointColors: Float32Array, public pointScaleFactors: Float32Array,
+      public labels: LabelRenderParams,
+      public polylineColors: {[polylineIndex: number]: Float32Array},
+      public polylineOpacities: Float32Array,
+      public polylineWidths: Float32Array) {}
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts
new file mode 100644
index 00000000000..283b608e836
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlot.ts
@@ -0,0 +1,723 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {ProjectorEventContext} from './projectorEventContext';
+import {CameraType, LabelRenderParams, RenderContext} from './renderContext';
+import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector';
+import {ScatterPlotVisualizer} from './scatterPlotVisualizer';
+import * as util from './util';
+import {Point2D, Point3D} from './vector';
+
+const BACKGROUND_COLOR = 0xffffff;
+
+/**
+ * The length of the cube (diameter of the circumscribing sphere) where all the
+ * points live.
+ */
+const CUBE_LENGTH = 2;
+const MAX_ZOOM = 5 * CUBE_LENGTH;
+const MIN_ZOOM = 0.025 * CUBE_LENGTH;
+
+// Constants relating to the camera parameters.
+const PERSP_CAMERA_FOV_VERTICAL = 70;
+const PERSP_CAMERA_NEAR_CLIP_PLANE = 0.01;
+const PERSP_CAMERA_FAR_CLIP_PLANE = 100;
+const ORTHO_CAMERA_FRUSTUM_HALF_EXTENT = 1.2;
+
+// Key presses.
+const SHIFT_KEY = 16;
+const CTRL_KEY = 17;
+
+const START_CAMERA_POS_3D = new THREE.Vector3(0.45, 0.9, 1.6);
+const START_CAMERA_TARGET_3D = new THREE.Vector3(0, 0, 0);
+const START_CAMERA_POS_2D = new THREE.Vector3(0, 0, 4);
+const START_CAMERA_TARGET_2D = new THREE.Vector3(0, 0, 0);
+
+const ORBIT_MOUSE_ROTATION_SPEED = 1;
+const ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 7;
+
+export type OnCameraMoveListener =
+    (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => void;
+
+/** Supported modes of interaction. */
+export enum MouseMode {
+  AREA_SELECT,
+  CAMERA_AND_CLICK_SELECT
+}
+
+/** Defines a camera, suitable for serialization. */
+export class CameraDef {
+  orthographic: boolean = false;
+  position: Point3D;
+  target: Point3D;
+  zoom: number;
+}
+
+/**
+ * Maintains a three.js instantiation and context,
+ * animation state, and all other logic that's
+ * independent of how a 3D scatter plot is actually rendered. Also holds an
+ * array of visualizers and dispatches application events to them.
+ */
+export class ScatterPlot {
+  private visualizers: ScatterPlotVisualizer[] = [];
+
+  private onCameraMoveListeners: OnCameraMoveListener[] = [];
+
+  private height: number;
+  private width: number;
+
+  private mouseMode: MouseMode;
+  private backgroundColor: number = BACKGROUND_COLOR;
+
+  private dimensionality: number = 3;
+  private renderer: THREE.WebGLRenderer;
+
+  private scene: THREE.Scene;
+  private pickingTexture: THREE.WebGLRenderTarget;
+  private light: THREE.PointLight;
+
+  private cameraDef: CameraDef = null;
+  private camera: THREE.Camera;
+  private orbitAnimationOnNextCameraCreation: boolean = false;
+  private orbitCameraControls: any;
+  private orbitAnimationId: number;
+
+  private worldSpacePointPositions: Float32Array;
+  private pointColors: Float32Array;
+  private pointScaleFactors: Float32Array;
+  private labels: LabelRenderParams;
+  private polylineColors: {[polylineIndex: number]: Float32Array};
+  private polylineOpacities: Float32Array;
+  private polylineWidths: Float32Array;
+
+  private selecting = false;
+  private nearestPoint: number;
+  private mouseIsDown = false;
+  private isDragSequence = false;
+  private rectangleSelector: ScatterPlotRectangleSelector;
+
+  constructor(
+      private container: HTMLElement,
+      private projectorEventContext: ProjectorEventContext) {
+    this.getLayoutValues();
+
+    this.scene = new THREE.Scene();
+    this.renderer = new THREE.WebGLRenderer(
+        {alpha: true, premultipliedAlpha: false, antialias: false});
+    this.renderer.setClearColor(BACKGROUND_COLOR, 1);
+    this.container.appendChild(this.renderer.domElement);
+    this.light = new THREE.PointLight(0xFFECBF, 1, 0);
+    this.scene.add(this.light);
+
+    this.setDimensions(3);
+    this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality));
+    this.renderer.render(this.scene, this.camera);
+
+    this.rectangleSelector = new ScatterPlotRectangleSelector(
+        this.container,
+        (boundingBox: BoundingBox) => this.selectBoundingBox(boundingBox));
+    this.addInteractionListeners();
+  }
+
+  private addInteractionListeners() {
+    this.container.addEventListener('mousemove', this.onMouseMove.bind(this));
+    this.container.addEventListener('mousedown', this.onMouseDown.bind(this));
+    this.container.addEventListener('mouseup', this.onMouseUp.bind(this));
+    this.container.addEventListener('click', this.onClick.bind(this));
+    window.addEventListener('keydown', this.onKeyDown.bind(this), false);
+    window.addEventListener('keyup', this.onKeyUp.bind(this), false);
+  }
+
+  private addCameraControlsEventListeners(cameraControls: any) {
+    // Start is called when the user stars interacting with
+    // controls.
+    cameraControls.addEventListener('start', () => {
+      this.stopOrbitAnimation();
+      this.onCameraMoveListeners.forEach(
+          l => l(this.camera.position, cameraControls.target));
+    });
+
+    // Change is called everytime the user interacts with the controls.
+    cameraControls.addEventListener('change', () => {
+      this.render();
+    });
+
+    // End is called when the user stops interacting with the
+    // controls (e.g. on mouse up, after dragging).
+    cameraControls.addEventListener('end', () => {});
+  }
+
+  private makeOrbitControls(
+      camera: THREE.Camera, cameraDef: CameraDef, cameraIs3D: boolean) {
+    if (this.orbitCameraControls != null) {
+      this.orbitCameraControls.dispose();
+    }
+    const occ =
+        new (THREE as any).OrbitControls(camera, this.renderer.domElement);
+    occ.target0 = new THREE.Vector3(
+        cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]);
+    occ.position0 = new THREE.Vector3().copy(camera.position);
+    occ.zoom0 = cameraDef.zoom;
+    occ.enableRotate = cameraIs3D;
+    occ.autoRotate = false;
+    occ.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED;
+    if (cameraIs3D) {
+      occ.mouseButtons.ORBIT = THREE.MOUSE.LEFT;
+      occ.mouseButtons.PAN = THREE.MOUSE.RIGHT;
+    } else {
+      occ.mouseButtons.ORBIT = null;
+      occ.mouseButtons.PAN = THREE.MOUSE.LEFT;
+    }
+    occ.reset();
+
+    this.camera = camera;
+    this.orbitCameraControls = occ;
+    this.addCameraControlsEventListeners(this.orbitCameraControls);
+  }
+
+  private makeCamera3D(cameraDef: CameraDef, w: number, h: number) {
+    let camera: THREE.PerspectiveCamera;
+    {
+      const aspectRatio = w / h;
+      camera = new THREE.PerspectiveCamera(
+          PERSP_CAMERA_FOV_VERTICAL, aspectRatio, PERSP_CAMERA_NEAR_CLIP_PLANE,
+          PERSP_CAMERA_FAR_CLIP_PLANE);
+      camera.position.set(
+          cameraDef.position[0], cameraDef.position[1], cameraDef.position[2]);
+      const at = new THREE.Vector3(
+          cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]);
+      camera.lookAt(at);
+      camera.zoom = cameraDef.zoom;
+      camera.updateProjectionMatrix();
+    }
+    this.camera = camera;
+    this.makeOrbitControls(camera, cameraDef, true);
+  }
+
+  private makeCamera2D(cameraDef: CameraDef, w: number, h: number) {
+    let camera: THREE.OrthographicCamera;
+    const target = new THREE.Vector3(
+        cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]);
+    {
+      const aspectRatio = w / h;
+      let left = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      let right = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      let bottom = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      let top = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT;
+      // Scale up the larger of (w, h) to match the aspect ratio.
+      if (aspectRatio > 1) {
+        left *= aspectRatio;
+        right *= aspectRatio;
+      } else {
+        top /= aspectRatio;
+        bottom /= aspectRatio;
+      }
+      camera =
+          new THREE.OrthographicCamera(left, right, top, bottom, -1000, 1000);
+      camera.position.set(
+          cameraDef.position[0], cameraDef.position[1], cameraDef.position[2]);
+      camera.up = new THREE.Vector3(0, 1, 0);
+      camera.lookAt(target);
+      camera.zoom = cameraDef.zoom;
+      camera.updateProjectionMatrix();
+    }
+    this.camera = camera;
+    this.makeOrbitControls(camera, cameraDef, false);
+  }
+
+  private makeDefaultCameraDef(dimensionality: number): CameraDef {
+    const def = new CameraDef();
+    def.orthographic = (dimensionality === 2);
+    def.zoom = 1.0;
+    if (def.orthographic) {
+      def.position =
+          [START_CAMERA_POS_2D.x, START_CAMERA_POS_2D.y, START_CAMERA_POS_2D.z];
+      def.target = [
+        START_CAMERA_TARGET_2D.x, START_CAMERA_TARGET_2D.y,
+        START_CAMERA_TARGET_2D.z
+      ];
+    } else {
+      def.position =
+          [START_CAMERA_POS_3D.x, START_CAMERA_POS_3D.y, START_CAMERA_POS_3D.z];
+      def.target = [
+        START_CAMERA_TARGET_3D.x, START_CAMERA_TARGET_3D.y,
+        START_CAMERA_TARGET_3D.z
+      ];
+    }
+    return def;
+  }
+
+  /** Recreate the scatter plot camera from a definition structure. */
+  recreateCamera(cameraDef: CameraDef) {
+    if (cameraDef.orthographic) {
+      this.makeCamera2D(cameraDef, this.width, this.height);
+    } else {
+      this.makeCamera3D(cameraDef, this.width, this.height);
+    }
+    this.orbitCameraControls.minDistance = MIN_ZOOM;
+    this.orbitCameraControls.maxDistance = MAX_ZOOM;
+    this.orbitCameraControls.update();
+    if (this.orbitAnimationOnNextCameraCreation) {
+      this.startOrbitAnimation();
+    }
+  }
+
+  private onClick(e?: MouseEvent, notify = true) {
+    if (e && this.selecting) {
+      return;
+    }
+    // Only call event handlers if the click originated from the scatter plot.
+    if (!this.isDragSequence && notify) {
+      const selection = (this.nearestPoint != null) ? [this.nearestPoint] : [];
+      this.projectorEventContext.notifySelectionChanged(selection);
+    }
+    this.isDragSequence = false;
+    this.render();
+  }
+
+  private onMouseDown(e: MouseEvent) {
+    this.isDragSequence = false;
+    this.mouseIsDown = true;
+    if (this.selecting) {
+      this.orbitCameraControls.enabled = false;
+      this.rectangleSelector.onMouseDown(e.offsetX, e.offsetY);
+      this.setNearestPointToMouse(e);
+    } else if (
+        !e.ctrlKey && this.sceneIs3D() &&
+        this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.RIGHT) {
+      // The user happened to press the ctrl key when the tab was active,
+      // unpressed the ctrl when the tab was inactive, and now he/she
+      // is back to the projector tab.
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT;
+    } else if (
+        e.ctrlKey && this.sceneIs3D() &&
+        this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.LEFT) {
+      // Similarly to the situation above.
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT;
+    }
+  }
+
+  /** When we stop dragging/zooming, return to normal behavior. */
+  private onMouseUp(e: any) {
+    if (this.selecting) {
+      this.orbitCameraControls.enabled = true;
+      this.rectangleSelector.onMouseUp();
+      this.render();
+    }
+    this.mouseIsDown = false;
+  }
+
+  /**
+   * When the mouse moves, find the nearest point (if any) and send it to the
+   * hoverlisteners (usually called from embedding.ts)
+   */
+  private onMouseMove(e: MouseEvent) {
+    this.isDragSequence = this.mouseIsDown;
+    // Depending if we're selecting or just navigating, handle accordingly.
+    if (this.selecting && this.mouseIsDown) {
+      this.rectangleSelector.onMouseMove(e.offsetX, e.offsetY);
+      this.render();
+    } else if (!this.mouseIsDown) {
+      this.setNearestPointToMouse(e);
+      this.projectorEventContext.notifyHoverOverPoint(this.nearestPoint);
+    }
+  }
+
+  /** For using ctrl + left click as right click, and for circle select */
+  private onKeyDown(e: any) {
+    // If ctrl is pressed, use left click to orbit
+    if (e.keyCode === CTRL_KEY && this.sceneIs3D()) {
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT;
+    }
+
+    // If shift is pressed, start selecting
+    if (e.keyCode === SHIFT_KEY) {
+      this.selecting = true;
+      this.container.style.cursor = 'crosshair';
+    }
+  }
+
+  /** For using ctrl + left click as right click, and for circle select */
+  private onKeyUp(e: any) {
+    if (e.keyCode === CTRL_KEY && this.sceneIs3D()) {
+      this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT;
+      this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT;
+    }
+
+    // If shift is released, stop selecting
+    if (e.keyCode === SHIFT_KEY) {
+      this.selecting = (this.getMouseMode() === MouseMode.AREA_SELECT);
+      if (!this.selecting) {
+        this.container.style.cursor = 'default';
+      }
+      this.render();
+    }
+  }
+
+  /**
+   * Returns a list of indices of points in a bounding box from the picking
+   * texture.
+   * @param boundingBox The bounding box to select from.
+   */
+  private getPointIndicesFromPickingTexture(boundingBox: BoundingBox):
+      number[] {
+    if (this.worldSpacePointPositions == null) {
+      return null;
+    }
+    const pointCount = this.worldSpacePointPositions.length / 3;
+    const dpr = window.devicePixelRatio || 1;
+    const x = Math.floor(boundingBox.x * dpr);
+    const y = Math.floor(boundingBox.y * dpr);
+    const width = Math.floor(boundingBox.width * dpr);
+    const height = Math.floor(boundingBox.height * dpr);
+
+    // Create buffer for reading all of the pixels from the texture.
+    let pixelBuffer = new Uint8Array(width * height * 4);
+
+    // Read the pixels from the bounding box.
+    this.renderer.readRenderTargetPixels(
+        this.pickingTexture, x, this.pickingTexture.height - y, width, height,
+        pixelBuffer);
+
+    // Keep a flat list of each point and whether they are selected or not. This
+    // approach is more efficient than using an object keyed by the index.
+    let pointIndicesSelection =
+        new Uint8Array(this.worldSpacePointPositions.length);
+    for (let i = 0; i < width * height; i++) {
+      const id = (pixelBuffer[i * 4] << 16) | (pixelBuffer[i * 4 + 1] << 8) |
+          pixelBuffer[i * 4 + 2];
+      if (id !== 0xffffff && (id < pointCount)) {
+        pointIndicesSelection[id] = 1;
+      }
+    }
+    let pointIndices: number[] = [];
+    for (let i = 0; i < pointIndicesSelection.length; i++) {
+      if (pointIndicesSelection[i] === 1) {
+        pointIndices.push(i);
+      }
+    }
+
+    return pointIndices;
+  }
+
+
+  private selectBoundingBox(boundingBox: BoundingBox) {
+    let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox);
+    this.projectorEventContext.notifySelectionChanged(pointIndices);
+  }
+
+  private setNearestPointToMouse(e: MouseEvent) {
+    if (this.pickingTexture == null) {
+      this.nearestPoint = null;
+      return;
+    }
+    const boundingBox:
+        BoundingBox = {x: e.offsetX, y: e.offsetY, width: 1, height: 1};
+    const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox);
+    this.nearestPoint = (pointIndices != null) ? pointIndices[0] : null;
+  }
+
+  private getLayoutValues(): Point2D {
+    this.width = this.container.offsetWidth;
+    this.height = Math.max(1, this.container.offsetHeight);
+    return [this.width, this.height];
+  }
+
+  private sceneIs3D(): boolean {
+    return this.dimensionality === 3;
+  }
+
+  private remove3dAxisFromScene(): THREE.Object3D {
+    const axes = this.scene.getObjectByName('axes');
+    if (axes != null) {
+      this.scene.remove(axes);
+    }
+    return axes;
+  }
+
+  private add3dAxis() {
+    const axes = new THREE.AxisHelper();
+    axes.name = 'axes';
+    this.scene.add(axes);
+  }
+
+  /** Set 2d vs 3d mode. */
+  setDimensions(dimensionality: number) {
+    if ((dimensionality !== 2) && (dimensionality !== 3)) {
+      throw new RangeError('dimensionality must be 2 or 3');
+    }
+    this.dimensionality = dimensionality;
+
+    const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality);
+    this.recreateCamera(def);
+
+    this.remove3dAxisFromScene();
+    if (dimensionality === 3) {
+      this.add3dAxis();
+    }
+  }
+
+  /** Gets the current camera information, suitable for serialization. */
+  getCameraDef(): CameraDef {
+    const def = new CameraDef();
+    const pos = this.camera.position;
+    const tgt = this.orbitCameraControls.target;
+    def.orthographic = !this.sceneIs3D();
+    def.position = [pos.x, pos.y, pos.z];
+    def.target = [tgt.x, tgt.y, tgt.z];
+    def.zoom = (this.camera as any).zoom;
+    return def;
+  }
+
+  /** Sets parameters for the next camera recreation. */
+  setCameraParametersForNextCameraCreation(
+      def: CameraDef, orbitAnimation: boolean) {
+    this.cameraDef = def;
+    this.orbitAnimationOnNextCameraCreation = orbitAnimation;
+  }
+
+  /** Gets the current camera position. */
+  getCameraPosition(): Point3D {
+    const currPos = this.camera.position;
+    return [currPos.x, currPos.y, currPos.z];
+  }
+
+  /** Gets the current camera target. */
+  getCameraTarget(): Point3D {
+    let currTarget = this.orbitCameraControls.target;
+    return [currTarget.x, currTarget.y, currTarget.z];
+  }
+
+  /** Sets up the camera from given position and target coordinates. */
+  setCameraPositionAndTarget(position: Point3D, target: Point3D) {
+    this.stopOrbitAnimation();
+    this.camera.position.set(position[0], position[1], position[2]);
+    this.orbitCameraControls.target.set(target[0], target[1], target[2]);
+    this.orbitCameraControls.update();
+    this.render();
+  }
+
+  /** Starts orbiting the camera around its current lookat target. */
+  startOrbitAnimation() {
+    if (!this.sceneIs3D()) {
+      return;
+    }
+    if (this.orbitAnimationId != null) {
+      this.stopOrbitAnimation();
+    }
+    this.orbitCameraControls.autoRotate = true;
+    this.orbitCameraControls.rotateSpeed =
+        ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS;
+    this.updateOrbitAnimation();
+  }
+
+  private updateOrbitAnimation() {
+    this.orbitCameraControls.update();
+    this.orbitAnimationId =
+        requestAnimationFrame(() => this.updateOrbitAnimation());
+  }
+
+  /** Stops the orbiting animation on the camera. */
+  stopOrbitAnimation() {
+    this.orbitCameraControls.autoRotate = false;
+    this.orbitCameraControls.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED;
+    if (this.orbitAnimationId != null) {
+      cancelAnimationFrame(this.orbitAnimationId);
+      this.orbitAnimationId = null;
+    }
+  }
+
+  /** Adds a visualizer to the set, will start dispatching events to it */
+  addVisualizer(visualizer: ScatterPlotVisualizer) {
+    if (this.scene) {
+      visualizer.setScene(this.scene);
+    }
+    visualizer.onResize(this.width, this.height);
+    visualizer.onPointPositionsChanged(this.worldSpacePointPositions);
+    this.visualizers.push(visualizer);
+  }
+
+  /** Removes all visualizers attached to this scatter plot. */
+  removeAllVisualizers() {
+    this.visualizers.forEach(v => v.dispose());
+    this.visualizers = [];
+  }
+
+  /** Update scatter plot with a new array of packed xyz point positions. */
+  setPointPositions(worldSpacePointPositions: Float32Array) {
+    this.worldSpacePointPositions = worldSpacePointPositions;
+    this.visualizers.forEach(
+        v => v.onPointPositionsChanged(worldSpacePointPositions));
+  }
+
+  render() {
+    {
+      const lightPos = this.camera.position.clone();
+      lightPos.x += 1;
+      lightPos.y += 1;
+      this.light.position.set(lightPos.x, lightPos.y, lightPos.z);
+    }
+
+    const cameraType = (this.camera instanceof THREE.PerspectiveCamera) ?
+        CameraType.Perspective :
+        CameraType.Orthographic;
+
+    let cameraSpacePointExtents: [number, number] = [0, 0];
+    if (this.worldSpacePointPositions != null) {
+      cameraSpacePointExtents = util.getNearFarPoints(
+          this.worldSpacePointPositions, this.camera.position,
+          this.orbitCameraControls.target);
+    }
+
+    const rc = new RenderContext(
+        this.camera, cameraType, this.orbitCameraControls.target, this.width,
+        this.height, cameraSpacePointExtents[0], cameraSpacePointExtents[1],
+        this.backgroundColor, this.pointColors, this.pointScaleFactors,
+        this.labels, this.polylineColors, this.polylineOpacities,
+        this.polylineWidths);
+
+    // Render first pass to picking target. This render fills pickingTexture
+    // with colors that are actually point ids, so that sampling the texture at
+    // the mouse's current x,y coordinates will reveal the data point that the
+    // mouse is over.
+    this.visualizers.forEach(v => v.onPickingRender(rc));
+
+    {
+      const axes = this.remove3dAxisFromScene();
+      this.renderer.render(this.scene, this.camera, this.pickingTexture);
+      if (axes != null) {
+        this.scene.add(axes);
+      }
+    }
+
+    // Render second pass to color buffer, to be displayed on the canvas.
+    this.visualizers.forEach(v => v.onRender(rc));
+
+    this.renderer.render(this.scene, this.camera);
+  }
+
+  setMouseMode(mouseMode: MouseMode) {
+    this.mouseMode = mouseMode;
+    if (mouseMode === MouseMode.AREA_SELECT) {
+      this.selecting = true;
+      this.container.style.cursor = 'crosshair';
+    } else {
+      this.selecting = false;
+      this.container.style.cursor = 'default';
+    }
+  }
+
+  /** Set the colors for every data point. (RGB triplets) */
+  setPointColors(colors: Float32Array) {
+    this.pointColors = colors;
+  }
+
+  /** Set the scale factors for every data point. (scalars) */
+  setPointScaleFactors(scaleFactors: Float32Array) {
+    this.pointScaleFactors = scaleFactors;
+  }
+
+  /** Set the labels to rendered */
+  setLabels(labels: LabelRenderParams) {
+    this.labels = labels;
+  }
+
+  /** Set the colors for every data polyline. (RGB triplets) */
+  setPolylineColors(colors: {[polylineIndex: number]: Float32Array}) {
+    this.polylineColors = colors;
+  }
+
+  setPolylineOpacities(opacities: Float32Array) {
+    this.polylineOpacities = opacities;
+  }
+
+  setPolylineWidths(widths: Float32Array) {
+    this.polylineWidths = widths;
+  }
+
+  getMouseMode(): MouseMode {
+    return this.mouseMode;
+  }
+
+  resetZoom() {
+    this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality));
+    this.render();
+  }
+
+  setDayNightMode(isNight: boolean) {
+    const canvases = this.container.querySelectorAll('canvas');
+    const filterValue = isNight ? 'invert(100%)' : null;
+    for (let i = 0; i < canvases.length; i++) {
+      canvases[i].style.filter = filterValue;
+    }
+  }
+
+  resize(render = true) {
+    const [oldW, oldH] = [this.width, this.height];
+    const [newW, newH] = this.getLayoutValues();
+
+    if (this.dimensionality === 3) {
+      const camera = (this.camera as THREE.PerspectiveCamera);
+      camera.aspect = newW / newH;
+      camera.updateProjectionMatrix();
+    } else {
+      const camera = (this.camera as THREE.OrthographicCamera);
+      // Scale the ortho frustum by however much the window changed.
+      const scaleW = newW / oldW;
+      const scaleH = newH / oldH;
+      const newCamHalfWidth = ((camera.right - camera.left) * scaleW) / 2;
+      const newCamHalfHeight = ((camera.top - camera.bottom) * scaleH) / 2;
+      camera.top = newCamHalfHeight;
+      camera.bottom = -newCamHalfHeight;
+      camera.left = -newCamHalfWidth;
+      camera.right = newCamHalfWidth;
+      camera.updateProjectionMatrix();
+    }
+
+    // Accouting for retina displays.
+    const dpr = window.devicePixelRatio || 1;
+    this.renderer.setPixelRatio(dpr);
+    this.renderer.setSize(newW, newH);
+
+    // the picking texture needs to be exactly the same as the render texture.
+    {
+      const renderCanvasSize = this.renderer.getSize();
+      const pixelRatio = this.renderer.getPixelRatio();
+      this.pickingTexture = new THREE.WebGLRenderTarget(
+          renderCanvasSize.width * pixelRatio,
+          renderCanvasSize.height * pixelRatio);
+      this.pickingTexture.texture.minFilter = THREE.LinearFilter;
+    }
+
+    this.visualizers.forEach(v => v.onResize(newW, newH));
+
+    if (render) {
+      this.render();
+    };
+  }
+
+  onCameraMove(listener: OnCameraMoveListener) {
+    this.onCameraMoveListeners.push(listener);
+  }
+
+  clickOnPoint(pointIndex: number) {
+    this.nearestPoint = pointIndex;
+    this.onClick(null, false);
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts
new file mode 100644
index 00000000000..a781877014e
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector.ts
@@ -0,0 +1,107 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+const FILL = '#dddddd';
+const FILL_OPACITY = .2;
+const STROKE = '#aaaaaa';
+const STROKE_WIDTH = 2;
+const STROKE_DASHARRAY = '10 5';
+
+export interface BoundingBox {
+  // The bounding box (x, y) position refers to the bottom left corner of the
+  // rect.
+  x: number;
+  y: number;
+  width: number;
+  height: number;
+}
+
+/**
+ * A class that manages and renders a data selection rectangle.
+ */
+export class ScatterPlotRectangleSelector {
+  private svgElement: SVGElement;
+  private rectElement: SVGRectElement;
+
+  private isMouseDown: boolean;
+  private startCoordinates: [number, number];
+  private lastBoundingBox: BoundingBox;
+
+  private selectionCallback: (boundingBox: BoundingBox) => void;
+
+  /**
+   * @param container The container HTML element that the selection SVG rect
+   *     will be a child of.
+   * @param selectionCallback The callback that accepts a bounding box to be
+   *     called when selection changes. Currently, we only call the callback on
+   *     mouseUp.
+   */
+  constructor(
+      container: HTMLElement,
+      selectionCallback: (boundingBox: BoundingBox) => void) {
+    this.svgElement = container.querySelector('#selector') as SVGElement;
+    this.rectElement =
+        document.createElementNS('http://www.w3.org/2000/svg', 'rect');
+    this.rectElement.style.stroke = STROKE;
+    this.rectElement.style.strokeDasharray = STROKE_DASHARRAY;
+    this.rectElement.style.strokeWidth = '' + STROKE_WIDTH;
+    this.rectElement.style.fill = FILL;
+    this.rectElement.style.fillOpacity = '' + FILL_OPACITY;
+    this.svgElement.appendChild(this.rectElement);
+
+    this.selectionCallback = selectionCallback;
+    this.isMouseDown = false;
+  }
+
+  onMouseDown(offsetX: number, offsetY: number) {
+    this.isMouseDown = true;
+    this.rectElement.style.display = 'block';
+
+    this.startCoordinates = [offsetX, offsetY];
+    this.lastBoundingBox = {
+      x: this.startCoordinates[0],
+      y: this.startCoordinates[1],
+      width: 1,
+      height: 1
+    };
+  }
+
+  onMouseMove(offsetX: number, offsetY: number) {
+    if (!this.isMouseDown) {
+      return;
+    }
+
+    this.lastBoundingBox.x = Math.min(offsetX, this.startCoordinates[0]);
+    this.lastBoundingBox.y = Math.max(offsetY, this.startCoordinates[1]);
+    this.lastBoundingBox.width =
+        Math.max(offsetX, this.startCoordinates[0]) - this.lastBoundingBox.x;
+    this.lastBoundingBox.height =
+        this.lastBoundingBox.y - Math.min(offsetY, this.startCoordinates[1]);
+
+    this.rectElement.setAttribute('x', '' + this.lastBoundingBox.x);
+    this.rectElement.setAttribute(
+        'y', '' + (this.lastBoundingBox.y - this.lastBoundingBox.height));
+    this.rectElement.setAttribute('width', '' + this.lastBoundingBox.width);
+    this.rectElement.setAttribute('height', '' + this.lastBoundingBox.height);
+  }
+
+  onMouseUp() {
+    this.isMouseDown = false;
+    this.rectElement.style.display = 'none';
+    this.rectElement.setAttribute('width', '0');
+    this.rectElement.setAttribute('height', '0');
+    this.selectionCallback(this.lastBoundingBox);
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts
new file mode 100644
index 00000000000..91cb10a97eb
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotRectangleSelector_test.ts
@@ -0,0 +1,69 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector';
+
+describe('selector callbacks make bounding box start bottom left', () => {
+  let containerElement: HTMLElement;
+  let selectionCallback: (boundingBox: BoundingBox) => void;
+  let selection: ScatterPlotRectangleSelector;
+
+  beforeEach(() => {
+    containerElement = document.createElement('div');
+    const selector = document.createElement('svg');
+    selector.id = 'selector';
+    containerElement.appendChild(selector);
+
+    selectionCallback = jasmine.createSpy('selectionCallback');
+    selection =
+        new ScatterPlotRectangleSelector(containerElement, selectionCallback);
+  });
+
+  it('Simple mouse event starting top left', () => {
+    selection.onMouseDown(0, 0);
+    selection.onMouseMove(10, 10);
+    selection.onMouseUp();
+
+    expect(selectionCallback)
+        .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10});
+  });
+
+  it('Simple mouse event starting bottom left', () => {
+    selection.onMouseDown(0, 10);
+    selection.onMouseMove(10, 0);
+    selection.onMouseUp();
+
+    expect(selectionCallback)
+        .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10});
+  });
+
+  it('Simple mouse event starting top right', () => {
+    selection.onMouseDown(10, 0);
+    selection.onMouseMove(0, 10);
+    selection.onMouseUp();
+
+    expect(selectionCallback)
+        .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10});
+  });
+
+  it('Simple mouse event starting bottom right', () => {
+    selection.onMouseDown(10, 10);
+    selection.onMouseMove(0, 0);
+    selection.onMouseUp();
+
+    expect(selectionCallback)
+        .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10});
+  });
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts
new file mode 100644
index 00000000000..b0974a20538
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer.ts
@@ -0,0 +1,51 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {RenderContext} from './renderContext';
+
+/**
+ * ScatterPlotVisualizer is an interface used by ScatterPlotContainer
+ * to manage and aggregate any number of concurrent visualization behaviors.
+ * To add a new visualization to the 3D scatter plot, create a new class that
+ * implements this interface and attach it to the ScatterPlotContainer.
+ */
+export interface ScatterPlotVisualizer {
+  /** Called to initialize the visualizer with the primary scene. */
+  setScene(scene: THREE.Scene);
+  /**
+   * Called when the main scatter plot tears down the visualizer. Remove all
+   * objects from the scene, and dispose any heavy resources.
+   */
+  dispose();
+  /**
+   * Called when the positions of the scatter plot points have changed.
+   */
+  onPointPositionsChanged(newWorldSpacePointPositions: Float32Array);
+  /**
+   * Called immediately before the main scatter plot performs a picking
+   * (selection) render. Set up render state for any geometry to use picking IDs
+   * instead of visual colors.
+   */
+  onPickingRender(renderContext: RenderContext);
+  /**
+   * Called immediately before the main scatter plot performs a color (visual)
+   * render. Set up render state, lights, etc here.
+   */
+  onRender(renderContext: RenderContext);
+  /**
+   * Called when the canvas size changes.
+   */
+  onResize(newWidth: number, newHeight: number);
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts
new file mode 100644
index 00000000000..cbd9785e2f6
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizer3DLabels.ts
@@ -0,0 +1,367 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {RenderContext} from './renderContext';
+import {ScatterPlotVisualizer} from './scatterPlotVisualizer';
+import * as util from './util';
+
+const FONT_SIZE = 80;
+const ONE_OVER_FONT_SIZE = 1 / FONT_SIZE;
+const LABEL_SCALE = 2.2;  // at 1:1 texel/pixel ratio
+const LABEL_COLOR = 'black';
+const LABEL_BACKGROUND = 'white';
+const MAX_CANVAS_DIMENSION = 8192;
+const NUM_GLYPHS = 256;
+const RGB_ELEMENTS_PER_ENTRY = 3;
+const XYZ_ELEMENTS_PER_ENTRY = 3;
+const UV_ELEMENTS_PER_ENTRY = 2;
+const VERTICES_PER_GLYPH = 2 * 3;  // 2 triangles, 3 verts per triangle
+
+/**
+ * Each label is made up of triangles (two per letter.) Each vertex, then, is
+ * the corner of one of these triangles (and thus the corner of a letter
+ * rectangle.)
+ * Each has the following attributes:
+ *    posObj: The (x, y) position of the vertex within the label, where the
+ *            bottom center of the word is positioned at (0, 0);
+ *    position: The position of the label in worldspace.
+ *    vUv: The (u, v) coordinates that index into the glyphs sheet (range 0, 1.)
+ *    color: The color of the label (matches the cooresponding point's color.)
+ *    wordShown: Boolean. Whether or not the label is visible.
+ */
+
+const VERTEX_SHADER = `
+    attribute vec2 posObj;
+    attribute vec3 color;
+    varying vec2 vUv;
+    varying vec3 vColor;
+
+    void main() {
+      vUv = uv;
+      vColor = color;
+
+      // Rotate label to face camera.
+
+      vec4 vRight = vec4(
+        modelViewMatrix[0][0], modelViewMatrix[1][0], modelViewMatrix[2][0], 0);
+
+      vec4 vUp = vec4(
+        modelViewMatrix[0][1], modelViewMatrix[1][1], modelViewMatrix[2][1], 0);
+
+      vec4 vAt = -vec4(
+        modelViewMatrix[0][2], modelViewMatrix[1][2], modelViewMatrix[2][2], 0);
+
+      mat4 pointToCamera = mat4(vRight, vUp, vAt, vec4(0, 0, 0, 1));
+
+      vec2 scaledPos = posObj * ${ONE_OVER_FONT_SIZE} * ${LABEL_SCALE};
+
+      vec4 posRotated = pointToCamera * vec4(scaledPos, 0, 1);
+      vec4 mvPosition = modelViewMatrix * (vec4(position, 0) + posRotated);
+      gl_Position = projectionMatrix * mvPosition;
+    }`;
+
+const FRAGMENT_SHADER = `
+    uniform sampler2D texture;
+    uniform bool picking;
+    varying vec2 vUv;
+    varying vec3 vColor;
+
+    void main() {
+      if (picking) {
+        gl_FragColor = vec4(vColor, 1.0);
+      } else {
+        vec4 fromTexture = texture2D(texture, vUv);
+        gl_FragColor = vec4(vColor, 1.0) * fromTexture;
+      }
+    }`;
+
+type GlyphTexture = {
+  texture: THREE.Texture; lengths: Float32Array; offsets: Float32Array;
+};
+
+/**
+ * Renders the text labels as 3d geometry in the world.
+ */
+export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer {
+  private scene: THREE.Scene;
+  private labelStrings: string[];
+  private geometry: THREE.BufferGeometry;
+  private worldSpacePointPositions: Float32Array;
+  private pickingColors: Float32Array;
+  private renderColors: Float32Array;
+  private material: THREE.ShaderMaterial;
+  private uniforms: Object;
+  private labelsMesh: THREE.Mesh;
+  private positions: THREE.BufferAttribute;
+  private totalVertexCount: number;
+  private labelVertexMap: number[][];
+  private glyphTexture: GlyphTexture;
+
+  private createGlyphTexture(): GlyphTexture {
+    let canvas = document.createElement('canvas');
+    canvas.width = MAX_CANVAS_DIMENSION;
+    canvas.height = FONT_SIZE;
+    let ctx = canvas.getContext('2d');
+    ctx.font = 'bold ' + FONT_SIZE * 0.75 + 'px roboto';
+    ctx.textBaseline = 'top';
+    ctx.fillStyle = LABEL_BACKGROUND;
+    ctx.rect(0, 0, canvas.width, canvas.height);
+    ctx.fill();
+    ctx.fillStyle = LABEL_COLOR;
+    let spaceOffset = ctx.measureText(' ').width;
+    // For each letter, store length, position at the encoded index.
+    let glyphLengths = new Float32Array(NUM_GLYPHS);
+    let glyphOffset = new Float32Array(NUM_GLYPHS);
+    let leftCoord = 0;
+    for (let i = 0; i < NUM_GLYPHS; i++) {
+      let text = ' ' + String.fromCharCode(i);
+      let textLength = ctx.measureText(text).width;
+      glyphLengths[i] = textLength - spaceOffset;
+      glyphOffset[i] = leftCoord;
+      ctx.fillText(text, leftCoord - spaceOffset, 0);
+      leftCoord += textLength;
+    }
+    const tex = util.createTexture(canvas);
+    return {texture: tex, lengths: glyphLengths, offsets: glyphOffset};
+  }
+
+  private processLabelVerts(pointCount: number) {
+    let numTotalLetters = 0;
+    this.labelVertexMap = [];
+    for (let i = 0; i < pointCount; i++) {
+      const label = this.labelStrings[i];
+      let vertsArray: number[] = [];
+      for (let j = 0; j < label.length; j++) {
+        for (let k = 0; k < VERTICES_PER_GLYPH; k++) {
+          vertsArray.push(numTotalLetters * VERTICES_PER_GLYPH + k);
+        }
+        numTotalLetters++;
+      }
+      this.labelVertexMap.push(vertsArray);
+    }
+    this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH;
+  }
+
+  private createColorBuffers(pointCount: number) {
+    this.pickingColors =
+        new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY);
+    this.renderColors =
+        new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY);
+    for (let i = 0; i < pointCount; i++) {
+      let color = new THREE.Color(i);
+      this.labelVertexMap[i].forEach((j) => {
+        this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r;
+        this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = color.g;
+        this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = color.b;
+        this.renderColors[RGB_ELEMENTS_PER_ENTRY * j] = 1.0;
+        this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = 1.0;
+        this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = 1.0;
+      });
+    }
+  }
+
+  private createLabels() {
+    if ((this.labelStrings == null) ||
+        (this.worldSpacePointPositions == null)) {
+      return;
+    }
+    const pointCount =
+        this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY;
+    if (pointCount !== this.labelStrings.length) {
+      return;
+    }
+    this.glyphTexture = this.createGlyphTexture();
+
+    this.uniforms = {
+      texture: {type: 't'},
+      picking: {type: 'bool'},
+    };
+
+    this.material = new THREE.ShaderMaterial({
+      uniforms: this.uniforms,
+      transparent: true,
+      vertexShader: VERTEX_SHADER,
+      fragmentShader: FRAGMENT_SHADER,
+    });
+
+    this.processLabelVerts(pointCount);
+    this.createColorBuffers(pointCount);
+
+    let positionArray =
+        new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY);
+    this.positions =
+        new THREE.BufferAttribute(positionArray, XYZ_ELEMENTS_PER_ENTRY);
+
+    let posArray =
+        new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY);
+    let uvArray =
+        new Float32Array(this.totalVertexCount * UV_ELEMENTS_PER_ENTRY);
+    let colorsArray =
+        new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY);
+    let positionObject = new THREE.BufferAttribute(posArray, 2);
+    let uv = new THREE.BufferAttribute(uvArray, UV_ELEMENTS_PER_ENTRY);
+    let colors = new THREE.BufferAttribute(colorsArray, RGB_ELEMENTS_PER_ENTRY);
+
+    this.geometry = new THREE.BufferGeometry();
+    this.geometry.addAttribute('posObj', positionObject);
+    this.geometry.addAttribute('position', this.positions);
+    this.geometry.addAttribute('uv', uv);
+    this.geometry.addAttribute('color', colors);
+
+    let lettersSoFar = 0;
+    for (let i = 0; i < pointCount; i++) {
+      const label = this.labelStrings[i];
+      let leftOffset = 0;
+      // Determine length of word in pixels.
+      for (let j = 0; j < label.length; j++) {
+        let letterCode = label.charCodeAt(j);
+        leftOffset += this.glyphTexture.lengths[letterCode];
+      }
+      leftOffset /= -2;  // centers text horizontally around the origin
+      for (let j = 0; j < label.length; j++) {
+        let letterCode = label.charCodeAt(j);
+        let letterWidth = this.glyphTexture.lengths[letterCode];
+        let scale = FONT_SIZE;
+        let right = (leftOffset + letterWidth) / scale;
+        let left = (leftOffset) / scale;
+        let top = FONT_SIZE / scale;
+
+        // First triangle
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, left, 0);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, right, 0);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, left, top);
+
+        // Second triangle
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, left, top);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, right, 0);
+        positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, right, top);
+
+        // Set UVs based on letter.
+        let uLeft = (this.glyphTexture.offsets[letterCode]);
+        let uRight = (this.glyphTexture.offsets[letterCode] + letterWidth);
+        // Scale so that uvs lie between 0 and 1 on the texture.
+        uLeft /= MAX_CANVAS_DIMENSION;
+        uRight /= MAX_CANVAS_DIMENSION;
+        let vTop = 1;
+        let vBottom = 0;
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, uLeft, vTop);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, uRight, vTop);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, uLeft, vBottom);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, uLeft, vBottom);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, uRight, vTop);
+        uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, uRight, vBottom);
+
+        lettersSoFar++;
+        leftOffset += letterWidth;
+      }
+    }
+
+    for (let i = 0; i < pointCount; i++) {
+      const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i);
+      this.labelVertexMap[i].forEach((j) => {
+        this.positions.setXYZ(j, p.x, p.y, p.z);
+      });
+    };
+
+    this.labelsMesh = new THREE.Mesh(this.geometry, this.material);
+    this.labelsMesh.frustumCulled = false;
+    this.scene.add(this.labelsMesh);
+  }
+
+  private colorLabels(pointColors: Float32Array) {
+    if (this.labelStrings == null || this.geometry == null ||
+        pointColors == null) {
+      return;
+    }
+
+    const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute;
+    colors.array = this.renderColors;
+
+    const n = pointColors.length / XYZ_ELEMENTS_PER_ENTRY;
+    let src = 0;
+    for (let i = 0; i < n; ++i) {
+      const c = new THREE.Color(
+          pointColors[src], pointColors[src + 1], pointColors[src + 2]);
+      const m = this.labelVertexMap[i].length;
+      for (let j = 0; j < m; ++j) {
+        colors.setXYZ(this.labelVertexMap[i][j], c.r, c.g, c.b);
+      }
+      src += RGB_ELEMENTS_PER_ENTRY;
+    }
+    colors.needsUpdate = true;
+  }
+
+  setScene(scene: THREE.Scene) {
+    this.scene = scene;
+  }
+
+  dispose() {
+    if (this.labelsMesh) {
+      if (this.scene) {
+        this.scene.remove(this.labelsMesh);
+      }
+      this.labelsMesh = null;
+    }
+    if (this.geometry) {
+      this.geometry.dispose();
+      this.geometry = null;
+    }
+    if ((this.glyphTexture != null) && (this.glyphTexture.texture != null)) {
+      this.glyphTexture.texture.dispose();
+      this.glyphTexture.texture = null;
+    }
+  }
+
+  onPickingRender(rc: RenderContext) {
+    if (this.geometry == null) {
+      this.createLabels();
+    }
+    if (this.geometry == null) {
+      return;
+    }
+    this.material.uniforms.texture.value = this.glyphTexture.texture;
+    this.material.uniforms.picking.value = true;
+    const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute;
+    colors.array = this.pickingColors;
+    colors.needsUpdate = true;
+  }
+
+  onRender(rc: RenderContext) {
+    if (this.geometry == null) {
+      this.createLabels();
+    }
+    if (this.geometry == null) {
+      return;
+    }
+    this.colorLabels(rc.pointColors);
+    this.material.uniforms.texture.value = this.glyphTexture.texture;
+    this.material.uniforms.picking.value = false;
+    const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute;
+    colors.array = this.renderColors;
+    colors.needsUpdate = true;
+  }
+
+  onPointPositionsChanged(newPositions: Float32Array) {
+    this.worldSpacePointPositions = newPositions;
+    this.dispose();
+  }
+
+  setLabelStrings(labelStrings: string[]) {
+    this.labelStrings = labelStrings;
+    this.dispose();
+  }
+
+  onResize(newWidth: number, newHeight: number) {}
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts
new file mode 100644
index 00000000000..ece4d84ef28
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerCanvasLabels.ts
@@ -0,0 +1,187 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import * as d3 from 'd3';  // from //third_party/javascript/typings/d3_v4
+import {BoundingBox, CollisionGrid} from './label';
+import {CameraType, RenderContext} from './renderContext';
+import {ScatterPlotVisualizer} from './scatterPlotVisualizer';
+import * as util from './util';
+
+const MAX_LABELS_ON_SCREEN = 10000;
+const LABEL_STROKE_WIDTH = 3;
+const LABEL_FILL_WIDTH = 6;
+
+/**
+ * Creates and maintains a 2d canvas on top of the GL canvas. All labels, when
+ * active, are rendered to the 2d canvas as part of the visible render pass.
+ */
+export class ScatterPlotVisualizerCanvasLabels implements
+    ScatterPlotVisualizer {
+  private worldSpacePointPositions: Float32Array;
+  private gc: CanvasRenderingContext2D;
+  private canvas: HTMLCanvasElement;
+  private labelsActive: boolean = true;
+
+  constructor(container: HTMLElement) {
+    this.canvas = document.createElement('canvas');
+    container.appendChild(this.canvas);
+
+    this.gc = this.canvas.getContext('2d');
+    this.canvas.style.position = 'absolute';
+    this.canvas.style.left = '0';
+    this.canvas.style.top = '0';
+    this.canvas.style.pointerEvents = 'none';
+  }
+
+  private removeAllLabels() {
+    const pixelWidth = this.canvas.width * window.devicePixelRatio;
+    const pixelHeight = this.canvas.height * window.devicePixelRatio;
+    this.gc.clearRect(0, 0, pixelWidth, pixelHeight);
+  }
+
+  /** Render all of the non-overlapping visible labels to the canvas. */
+  private makeLabels(rc: RenderContext) {
+    if ((rc.labels == null) || (rc.labels.pointIndices.length === 0)) {
+      return;
+    }
+    if (this.worldSpacePointPositions == null) {
+      return;
+    }
+
+    const lrc = rc.labels;
+    const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective);
+    const labelHeight = parseInt(this.gc.font, 10);
+    const dpr = window.devicePixelRatio;
+
+    let grid: CollisionGrid;
+    {
+      const pixw = this.canvas.width * dpr;
+      const pixh = this.canvas.height * dpr;
+      const bb: BoundingBox = {loX: 0, hiX: pixw, loY: 0, hiY: pixh};
+      grid = new CollisionGrid(bb, pixw / 25, pixh / 50);
+    }
+
+    let opacityMap =
+        d3.scalePow()
+            .exponent(Math.E)
+            .domain([rc.farthestCameraSpacePointZ, rc.nearestCameraSpacePointZ])
+            .range([0.1, 1]);
+
+    const camPos = rc.camera.position;
+    const camToTarget = camPos.clone().sub(rc.cameraTarget);
+    let camToPoint = new THREE.Vector3();
+
+    this.gc.textBaseline = 'middle';
+    this.gc.miterLimit = 2;
+
+    // Have extra space between neighboring labels. Don't pack too tightly.
+    const labelMargin = 2;
+    // Shift the label to the right of the point circle.
+    const xShift = 4;
+
+    const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length);
+    for (let i = 0; i < n; ++i) {
+      let point: THREE.Vector3;
+      {
+        const pi = lrc.pointIndices[i];
+        point = util.vector3FromPackedArray(this.worldSpacePointPositions, pi);
+      }
+
+      // discard points that are behind the camera
+      camToPoint.copy(camPos).sub(point);
+      if (camToTarget.dot(camToPoint) < 0) {
+        continue;
+      }
+
+      let [x, y] = util.vector3DToScreenCoords(
+          rc.camera, rc.screenWidth, rc.screenHeight, point);
+      x += xShift;
+
+      // Computing the width of the font is expensive,
+      // so we assume width of 1 at first. Then, if the label doesn't
+      // conflict with other labels, we measure the actual width.
+      const textBoundingBox: BoundingBox = {
+        loX: x - labelMargin,
+        hiX: x + 1 + labelMargin,
+        loY: y - labelHeight / 2 - labelMargin,
+        hiY: y + labelHeight / 2 + labelMargin
+      };
+
+      if (grid.insert(textBoundingBox, true)) {
+        const text = lrc.labelStrings[i];
+        const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr;
+        this.gc.font = fontSize + 'px roboto';
+
+        // Now, check with properly computed width.
+        textBoundingBox.hiX += this.gc.measureText(text).width - 1;
+        if (grid.insert(textBoundingBox)) {
+          let opacity = 1;
+          if (sceneIs3D && (lrc.useSceneOpacityFlags[i] === 1)) {
+            opacity = opacityMap(camToPoint.length());
+          }
+          this.gc.fillStyle =
+              this.styleStringFromPackedRgba(lrc.fillColors, i, opacity);
+          this.gc.strokeStyle =
+              this.styleStringFromPackedRgba(lrc.strokeColors, i, opacity);
+          this.gc.lineWidth = LABEL_STROKE_WIDTH;
+          this.gc.strokeText(text, x, y);
+          this.gc.lineWidth = LABEL_FILL_WIDTH;
+          this.gc.fillText(text, x, y);
+        }
+      }
+    }
+  }
+
+  private styleStringFromPackedRgba(
+      packedRgbaArray: Uint8Array, colorIndex: number,
+      opacity: number): string {
+    const offset = colorIndex * 3;
+    const r = packedRgbaArray[offset];
+    const g = packedRgbaArray[offset + 1];
+    const b = packedRgbaArray[offset + 2];
+    return 'rgba(' + r + ',' + g + ',' + b + ',' + opacity + ')';
+  }
+
+  onResize(newWidth: number, newHeight: number) {
+    let dpr = window.devicePixelRatio;
+    this.canvas.width = newWidth * dpr;
+    this.canvas.height = newHeight * dpr;
+    this.canvas.style.width = newWidth + 'px';
+    this.canvas.style.height = newHeight + 'px';
+  }
+
+  dispose() {
+    this.removeAllLabels();
+    this.canvas = null;
+    this.gc = null;
+  }
+
+  onPointPositionsChanged(newPositions: Float32Array) {
+    this.worldSpacePointPositions = newPositions;
+    this.removeAllLabels();
+  }
+
+  onRender(rc: RenderContext) {
+    if (!this.labelsActive) {
+      return;
+    }
+
+    this.removeAllLabels();
+    this.makeLabels(rc);
+  }
+
+  setScene(scene: THREE.Scene) {}
+  onPickingRender(renderContext: RenderContext) {}
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts
new file mode 100644
index 00000000000..e6d4aeda28b
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerPolylines.ts
@@ -0,0 +1,149 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataSet} from './data';
+import {RenderContext} from './renderContext';
+import {ScatterPlotVisualizer} from './scatterPlotVisualizer';
+import * as util from './util';
+
+const RGB_NUM_ELEMENTS = 3;
+const XYZ_NUM_ELEMENTS = 3;
+
+/**
+ * Renders polylines that connect multiple points in the dataset.
+ */
+export class ScatterPlotVisualizerPolylines implements ScatterPlotVisualizer {
+  private dataSet: DataSet;
+  private scene: THREE.Scene;
+  private polylines: THREE.Line[];
+  private polylinePositionBuffer:
+      {[polylineIndex: number]: THREE.BufferAttribute} = {};
+  private polylineColorBuffer:
+      {[polylineIndex: number]: THREE.BufferAttribute} = {};
+
+  private updateSequenceIndicesInDataSet(ds: DataSet) {
+    for (let i = 0; i < ds.sequences.length; i++) {
+      const sequence = ds.sequences[i];
+      for (let j = 0; j < sequence.pointIndices.length - 1; j++) {
+        ds.points[sequence.pointIndices[j]].sequenceIndex = i;
+        ds.points[sequence.pointIndices[j + 1]].sequenceIndex = i;
+      }
+    }
+  }
+
+  private createPolylines(scene: THREE.Scene) {
+    if (!this.dataSet || !this.dataSet.sequences) {
+      return;
+    }
+
+    this.updateSequenceIndicesInDataSet(this.dataSet);
+    this.polylines = [];
+
+    for (let i = 0; i < this.dataSet.sequences.length; i++) {
+      const geometry = new THREE.BufferGeometry();
+      geometry.addAttribute('position', this.polylinePositionBuffer[i]);
+      geometry.addAttribute('color', this.polylineColorBuffer[i]);
+
+      const material = new THREE.LineBasicMaterial({
+        linewidth: 1,  // unused default, overwritten by width array.
+        opacity: 1.0,  // unused default, overwritten by opacity array.
+        transparent: true,
+        vertexColors: THREE.VertexColors
+      });
+
+      const polyline = new THREE.LineSegments(geometry, material);
+      polyline.frustumCulled = false;
+      this.polylines.push(polyline);
+      scene.add(polyline);
+    }
+  }
+
+  dispose() {
+    if (this.polylines == null) {
+      return;
+    }
+    for (let i = 0; i < this.polylines.length; i++) {
+      this.scene.remove(this.polylines[i]);
+      this.polylines[i].geometry.dispose();
+    }
+    this.polylines = null;
+    this.polylinePositionBuffer = {};
+    this.polylineColorBuffer = {};
+  }
+
+  setScene(scene: THREE.Scene) {
+    this.scene = scene;
+  }
+
+  setDataSet(dataSet: DataSet) {
+    this.dataSet = dataSet;
+  }
+
+  onPointPositionsChanged(newPositions: Float32Array) {
+    if ((newPositions == null) || (this.polylines != null)) {
+      this.dispose();
+    }
+    if ((newPositions == null) || (this.dataSet == null)) {
+      return;
+    }
+    // Set up the position buffer arrays for each polyline.
+    for (let i = 0; i < this.dataSet.sequences.length; i++) {
+      let sequence = this.dataSet.sequences[i];
+      const vertexCount = 2 * (sequence.pointIndices.length - 1);
+
+      let polylines = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS);
+      this.polylinePositionBuffer[i] =
+          new THREE.BufferAttribute(polylines, XYZ_NUM_ELEMENTS);
+
+      let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS);
+      this.polylineColorBuffer[i] =
+          new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS);
+    }
+    for (let i = 0; i < this.dataSet.sequences.length; i++) {
+      const sequence = this.dataSet.sequences[i];
+      let src = 0;
+      for (let j = 0; j < sequence.pointIndices.length - 1; j++) {
+        const p1Index = sequence.pointIndices[j];
+        const p2Index = sequence.pointIndices[j + 1];
+        const p1 = util.vector3FromPackedArray(newPositions, p1Index);
+        const p2 = util.vector3FromPackedArray(newPositions, p2Index);
+        this.polylinePositionBuffer[i].setXYZ(src, p1.x, p1.y, p1.z);
+        this.polylinePositionBuffer[i].setXYZ(src + 1, p2.x, p2.y, p2.z);
+        src += 2;
+      }
+      this.polylinePositionBuffer[i].needsUpdate = true;
+    }
+
+    if (this.polylines == null) {
+      this.createPolylines(this.scene);
+    }
+  }
+
+  onRender(renderContext: RenderContext) {
+    if (this.polylines == null) {
+      return;
+    }
+    for (let i = 0; i < this.polylines.length; i++) {
+      this.polylines[i].material.opacity = renderContext.polylineOpacities[i];
+      (this.polylines[i].material as THREE.LineBasicMaterial).linewidth =
+          renderContext.polylineWidths[i];
+      this.polylineColorBuffer[i].array = renderContext.polylineColors[i];
+      this.polylineColorBuffer[i].needsUpdate = true;
+    }
+  }
+
+  onPickingRender(renderContext: RenderContext) {}
+  onResize(newWidth: number, newHeight: number) {}
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts
new file mode 100644
index 00000000000..8adc9a9bd23
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/scatterPlotVisualizerSprites.ts
@@ -0,0 +1,435 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {CameraType, RenderContext} from './renderContext';
+import {ScatterPlotVisualizer} from './scatterPlotVisualizer';
+import * as util from './util';
+
+const NUM_POINTS_FOG_THRESHOLD = 5000;
+const MIN_POINT_SIZE = 5.0;
+const IMAGE_SIZE = 30;
+
+// Constants relating to the indices of buffer arrays.
+const RGB_NUM_ELEMENTS = 3;
+const INDEX_NUM_ELEMENTS = 1;
+const XYZ_NUM_ELEMENTS = 3;
+
+const VERTEX_SHADER = `
+  // Index of the specific vertex (passed in as bufferAttribute), and the
+  // variable that will be used to pass it to the fragment shader.
+  attribute float spriteIndex;
+  attribute vec3 color;
+  attribute float scaleFactor;
+
+  varying vec2 xyIndex;
+  varying vec3 vColor;
+
+  uniform bool sizeAttenuation;
+  uniform float pointSize;
+  uniform float spritesPerRow;
+  uniform float spritesPerColumn;
+
+  void main() {
+    // Pass index and color values to fragment shader.
+    vColor = color;
+    xyIndex = vec2(mod(spriteIndex, spritesPerRow),
+              floor(spriteIndex / spritesPerColumn));
+
+    // Transform current vertex by modelViewMatrix (model world position and
+    // camera world position matrix).
+    vec4 cameraSpacePos = modelViewMatrix * vec4(position, 1.0);
+
+    // Project vertex in camera-space to screen coordinates using the camera's
+    // projection matrix.
+    gl_Position = projectionMatrix * cameraSpacePos;
+
+    // Create size attenuation (if we're in 3D mode) by making the size of
+    // each point inversly proportional to its distance to the camera.
+    float outputPointSize = pointSize;
+    if (sizeAttenuation) {
+      outputPointSize = -pointSize / cameraSpacePos.z;
+    }
+
+    gl_PointSize =
+      max(outputPointSize * scaleFactor, ${MIN_POINT_SIZE.toFixed(1)});
+  }`;
+
+const FRAGMENT_SHADER_POINT_TEST_CHUNK = `
+  bool point_in_unit_circle(vec2 spriteCoord) {
+    vec2 centerToP = spriteCoord - vec2(0.5, 0.5);
+    return dot(centerToP, centerToP) < (0.5 * 0.5);
+  }
+
+  bool point_in_unit_equilateral_triangle(vec2 spriteCoord) {
+    vec3 v0 = vec3(0, 1, 0);
+    vec3 v1 = vec3(0.5, 0, 0);
+    vec3 v2 = vec3(1, 1, 0);
+    vec3 p = vec3(spriteCoord, 0);
+    float p_in_v0_v1 = cross(v1 - v0, p - v0).z;
+    float p_in_v1_v2 = cross(v2 - v1, p - v1).z;
+    return (p_in_v0_v1 > 0.0) && (p_in_v1_v2 > 0.0);
+  }
+
+  bool point_in_unit_square(vec2 spriteCoord) {
+    return true;
+  }
+`;
+
+const FRAGMENT_SHADER = `
+  varying vec2 xyIndex;
+  varying vec3 vColor;
+
+  uniform sampler2D texture;
+  uniform float spritesPerRow;
+  uniform float spritesPerColumn;
+  uniform bool isImage;
+
+  ${THREE.ShaderChunk['common']}
+  ${THREE.ShaderChunk['fog_pars_fragment']}
+  ${FRAGMENT_SHADER_POINT_TEST_CHUNK}
+
+  void main() {
+    if (isImage) {
+      // Coordinates of the vertex within the entire sprite image.
+      vec2 coords =
+        (gl_PointCoord + xyIndex) / vec2(spritesPerRow, spritesPerColumn);
+      gl_FragColor = vec4(vColor, 1.0) * texture2D(texture, coords);
+    } else {
+      bool inside = point_in_unit_circle(gl_PointCoord);
+      if (!inside) {
+        discard;
+      }
+      gl_FragColor = vec4(vColor, 1);
+    }
+    ${THREE.ShaderChunk['fog_fragment']}
+  }`;
+
+const FRAGMENT_SHADER_PICKING = `
+  varying vec2 xyIndex;
+  varying vec3 vColor;
+  uniform bool isImage;
+
+  ${FRAGMENT_SHADER_POINT_TEST_CHUNK}
+
+  void main() {
+    xyIndex; // Silence 'unused variable' warning.
+    if (isImage) {
+      gl_FragColor = vec4(vColor, 1);
+    } else {
+      bool inside = point_in_unit_circle(gl_PointCoord);
+      if (!inside) {
+        discard;
+      }
+      gl_FragColor = vec4(vColor, 1);
+    }
+  }`;
+
+/**
+ * Uses GL point sprites to render the dataset.
+ */
+export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
+  private scene: THREE.Scene;
+  private fog: THREE.Fog;
+  private texture: THREE.Texture = null;
+  private standinTextureForPoints: THREE.Texture;
+  private spritesPerRow: number;
+  private spritesPerColumn: number;
+  private spriteDimensions: [number, number];
+  private spriteIndexBufferAttribute: THREE.BufferAttribute;
+  private renderMaterial: THREE.ShaderMaterial;
+  private pickingMaterial: THREE.ShaderMaterial;
+
+  private points: THREE.Points;
+  private worldSpacePointPositions: Float32Array;
+  private pickingColors: Float32Array;
+  private renderColors: Float32Array;
+
+  constructor() {
+    this.standinTextureForPoints =
+        util.createTexture(document.createElement('canvas'));
+    this.renderMaterial = this.createRenderMaterial(false);
+    this.pickingMaterial = this.createPickingMaterial(false);
+  }
+
+  private createTextureFromSpriteAtlas(
+      spriteAtlas: HTMLImageElement, spriteDimensions: [number, number],
+      spriteIndices: Float32Array) {
+    this.texture = util.createTexture(spriteAtlas);
+    this.spritesPerRow = spriteAtlas.width / spriteDimensions[0];
+    this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1];
+    this.spriteDimensions = spriteDimensions;
+    this.spriteIndexBufferAttribute =
+        new THREE.BufferAttribute(spriteIndices, INDEX_NUM_ELEMENTS);
+
+    if (this.points != null) {
+      (this.points.geometry as THREE.BufferGeometry)
+          .addAttribute('spriteIndex', this.spriteIndexBufferAttribute);
+    }
+  }
+
+  private createUniforms(): any {
+    return {
+      texture: {type: 't'},
+      spritesPerRow: {type: 'f'},
+      spritesPerColumn: {type: 'f'},
+      fogColor: {type: 'c'},
+      fogNear: {type: 'f'},
+      fogFar: {type: 'f'},
+      isImage: {type: 'bool'},
+      sizeAttenuation: {type: 'bool'},
+      pointSize: {type: 'f'}
+    };
+  }
+
+  private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial {
+    const uniforms = this.createUniforms();
+    return new THREE.ShaderMaterial({
+      uniforms: uniforms,
+      vertexShader: VERTEX_SHADER,
+      fragmentShader: FRAGMENT_SHADER,
+      transparent: !haveImage,
+      depthTest: haveImage,
+      depthWrite: haveImage,
+      fog: true,
+      blending: THREE.MultiplyBlending,
+    });
+  }
+
+  private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial {
+    const uniforms = this.createUniforms();
+    return new THREE.ShaderMaterial({
+      uniforms: uniforms,
+      vertexShader: VERTEX_SHADER,
+      fragmentShader: FRAGMENT_SHADER_PICKING,
+      transparent: true,
+      depthTest: true,
+      depthWrite: true,
+      fog: false,
+      blending: THREE.NormalBlending,
+    });
+  }
+
+  /**
+   * Create points, set their locations and actually instantiate the
+   * geometry.
+   */
+  private createPointSprites(scene: THREE.Scene, positions: Float32Array) {
+    const pointCount =
+        (positions != null) ? (positions.length / XYZ_NUM_ELEMENTS) : 0;
+    const geometry = this.createGeometry(pointCount);
+
+    this.fog = new THREE.Fog(0xFFFFFF);  // unused value, gets overwritten.
+
+    this.points = new THREE.Points(geometry, this.renderMaterial);
+    this.points.frustumCulled = false;
+    if (this.spriteIndexBufferAttribute != null) {
+      (this.points.geometry as THREE.BufferGeometry)
+          .addAttribute('spriteIndex', this.spriteIndexBufferAttribute);
+    }
+    scene.add(this.points);
+  }
+
+  private calculatePointSize(sceneIs3D: boolean): number {
+    if (this.texture != null) {
+      return sceneIs3D ? IMAGE_SIZE : this.spriteDimensions[0];
+    }
+    const n = (this.worldSpacePointPositions != null) ?
+        (this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS) :
+        1;
+    const SCALE = 200;
+    const LOG_BASE = 8;
+    const DIVISOR = 1.5;
+    // Scale point size inverse-logarithmically to the number of points.
+    const pointSize = SCALE / Math.log(n) / Math.log(LOG_BASE);
+    return sceneIs3D ? pointSize : (pointSize / DIVISOR);
+  }
+
+  /**
+   * Set up buffer attributes to be used for the points/images.
+   */
+  private createGeometry(pointCount: number): THREE.BufferGeometry {
+    const n = pointCount;
+
+    // Fill pickingColors with each point's unique id as its color.
+    this.pickingColors = new Float32Array(n * RGB_NUM_ELEMENTS);
+    {
+      let dst = 0;
+      for (let i = 0; i < n; i++) {
+        const c = new THREE.Color(i);
+        this.pickingColors[dst++] = c.r;
+        this.pickingColors[dst++] = c.g;
+        this.pickingColors[dst++] = c.b;
+      }
+    }
+
+    const geometry = new THREE.BufferGeometry();
+    geometry.addAttribute(
+        'position', new THREE.BufferAttribute(null, XYZ_NUM_ELEMENTS));
+    geometry.addAttribute(
+        'color', new THREE.BufferAttribute(null, RGB_NUM_ELEMENTS));
+    geometry.addAttribute(
+        'scaleFactor', new THREE.BufferAttribute(null, INDEX_NUM_ELEMENTS));
+    return geometry;
+  }
+
+  private setFogDistances(
+      sceneIs3D: boolean, nearestPointZ: number, farthestPointZ: number) {
+    if (sceneIs3D) {
+      const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS;
+      this.fog.near = nearestPointZ;
+      // If there are fewer points we want less fog. We do this
+      // by making the "far" value (that is, the distance from the camera to the
+      // far edge of the fog) proportional to the number of points.
+      let multiplier =
+          2 - Math.min(n, NUM_POINTS_FOG_THRESHOLD) / NUM_POINTS_FOG_THRESHOLD;
+      this.fog.far = farthestPointZ * multiplier;
+    } else {
+      this.fog.near = Infinity;
+      this.fog.far = Infinity;
+    }
+  }
+
+  dispose() {
+    this.disposeGeometry();
+    this.disposeTextureAtlas();
+  }
+
+  private disposeGeometry() {
+    if (this.points != null) {
+      this.scene.remove(this.points);
+      this.points.geometry.dispose();
+      this.points = null;
+      this.worldSpacePointPositions = null;
+    }
+  }
+
+  private disposeTextureAtlas() {
+    if (this.texture != null) {
+      this.texture.dispose();
+    }
+    this.texture = null;
+    this.renderMaterial = null;
+    this.pickingMaterial = null;
+  }
+
+  setScene(scene: THREE.Scene) {
+    this.scene = scene;
+  }
+
+  setSpriteAtlas(
+      spriteImage: HTMLImageElement, spriteDimensions: [number, number],
+      spriteIndices: Uint8Array) {
+    this.disposeTextureAtlas();
+    this.createTextureFromSpriteAtlas(
+        spriteImage, spriteDimensions, spriteIndices);
+    this.renderMaterial = this.createRenderMaterial(true);
+    this.pickingMaterial = this.createPickingMaterial(true);
+  }
+
+  clearSpriteAtlas() {
+    this.disposeTextureAtlas();
+    this.renderMaterial = this.createRenderMaterial(false);
+    this.pickingMaterial = this.createPickingMaterial(false);
+  }
+
+  onPointPositionsChanged(newPositions: Float32Array) {
+    if ((newPositions == null) || (newPositions.length === 0)) {
+      this.dispose();
+      return;
+    }
+    if (this.points != null) {
+      if (this.worldSpacePointPositions.length !== newPositions.length) {
+        this.disposeGeometry();
+      }
+    }
+
+    this.worldSpacePointPositions = newPositions;
+
+    if (this.points == null) {
+      this.createPointSprites(this.scene, newPositions);
+    }
+
+    const positions = (this.points.geometry as THREE.BufferGeometry)
+                          .getAttribute('position') as THREE.BufferAttribute;
+    positions.array = newPositions;
+    positions.needsUpdate = true;
+  }
+
+  onPickingRender(rc: RenderContext) {
+    if (this.points == null) {
+      return;
+    }
+
+    const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective);
+
+    this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow;
+    this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerColumn;
+    this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D;
+    this.pickingMaterial.uniforms.pointSize.value =
+        this.calculatePointSize(sceneIs3D);
+    this.points.material = this.pickingMaterial;
+
+    let colors = (this.points.geometry as THREE.BufferGeometry)
+                     .getAttribute('color') as THREE.BufferAttribute;
+    colors.array = this.pickingColors;
+    colors.needsUpdate = true;
+
+    let scaleFactors =
+        (this.points.geometry as THREE.BufferGeometry)
+            .getAttribute('scaleFactor') as THREE.BufferAttribute;
+    scaleFactors.array = rc.pointScaleFactors;
+    scaleFactors.needsUpdate = true;
+  }
+
+  onRender(rc: RenderContext) {
+    if (!this.points) {
+      return;
+    }
+    const sceneIs3D: boolean = (rc.camera instanceof THREE.PerspectiveCamera);
+
+    this.setFogDistances(
+        sceneIs3D, rc.nearestCameraSpacePointZ, rc.farthestCameraSpacePointZ);
+
+    this.scene.fog = this.fog;
+    this.scene.fog.color = new THREE.Color(rc.backgroundColor);
+
+    this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color;
+    this.renderMaterial.uniforms.fogNear.value = this.fog.near;
+    this.renderMaterial.uniforms.fogFar.value = this.fog.far;
+    this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow;
+    this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn;
+    this.renderMaterial.uniforms.isImage.value = (this.texture != null);
+    this.renderMaterial.uniforms.texture.value =
+        (this.texture != null) ? this.texture : this.standinTextureForPoints;
+    this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D;
+    this.renderMaterial.uniforms.pointSize.value =
+        this.calculatePointSize(sceneIs3D);
+    this.points.material = this.renderMaterial;
+
+    let colors = (this.points.geometry as THREE.BufferGeometry)
+                     .getAttribute('color') as THREE.BufferAttribute;
+    this.renderColors = rc.pointColors;
+    colors.array = this.renderColors;
+    colors.needsUpdate = true;
+
+    let scaleFactors =
+        (this.points.geometry as THREE.BufferGeometry)
+            .getAttribute('scaleFactor') as THREE.BufferAttribute;
+    scaleFactors.array = rc.pointScaleFactors;
+    scaleFactors.needsUpdate = true;
+  }
+
+  onResize(newWidth: number, newHeight: number) {}
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts
new file mode 100644
index 00000000000..991369a3352
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree.ts
@@ -0,0 +1,175 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/** N-dimensional point. Usually 2D or 3D. */
+export type Point = number[];
+
+export interface BBox {
+  center: Point;
+  halfDim: number;
+}
+
+/** A node in a space-partitioning tree. */
+export interface SPNode {
+  /** The children of this node. */
+  children?: SPNode[];
+  /** The bounding box of the region this node occupies. */
+  box: BBox;
+  /** One or more points this node has. */
+  point: Point;
+}
+
+/**
+ * A Space-partitioning tree (https://en.wikipedia.org/wiki/Space_partitioning)
+ * that recursively divides the space into regions of equal sizes. This data
+ * structure can act both as a Quad tree and an Octree when the data is 2 or
+ * 3 dimensional respectively. One usage is in t-SNE in order to do Barnes-Hut
+ * approximation.
+ */
+export class SPTree {
+  root: SPNode;
+
+  private masks: number[];
+  private dim: number;
+
+  /**
+   * Constructs a new tree with the provided data.
+   *
+   * @param data List of n-dimensional data points.
+   * @param capacity Number of data points to store in a single node.
+   */
+  constructor(data: Point[]) {
+    if (data.length < 1) {
+      throw new Error('There should be at least 1 data point');
+    }
+    // Make a bounding box based on the extent of the data.
+    this.dim = data[0].length;
+    // Each node has 2^d children, where d is the dimension of the space.
+    // Binary masks (e.g. 000, 001, ... 111 in 3D) are used to determine in
+    // which child (e.g. quadron in 2D) the new point is going to be assigned.
+    // For more details, see the insert() method and its comments.
+    this.masks = new Array(Math.pow(2, this.dim));
+    for (let d = 0; d < this.masks.length; ++d) {
+      this.masks[d] = (1 << d);
+    }
+    let min: Point = new Array(this.dim);
+    fillArray(min, Number.POSITIVE_INFINITY);
+    let max: Point = new Array(this.dim);
+    fillArray(max, Number.NEGATIVE_INFINITY);
+
+    for (let i = 0; i < data.length; ++i) {
+      // For each dim get the min and max.
+      // E.g. For 2-D, get the x_min, x_max, y_min, y_max.
+      for (let d = 0; d < this.dim; ++d) {
+        min[d] = Math.min(min[d], data[i][d]);
+        max[d] = Math.max(max[d], data[i][d]);
+      }
+    }
+    // Create a bounding box with the center of the largest span.
+    let center: Point = new Array(this.dim);
+    let halfDim = 0;
+    for (let d = 0; d < this.dim; ++d) {
+      let span = max[d] - min[d];
+      center[d] = min[d] + span / 2;
+      halfDim = Math.max(halfDim, span / 2);
+    }
+    this.root = {box: {center: center, halfDim: halfDim}, point: data[0]};
+    for (let i = 1; i < data.length; ++i) {
+      this.insert(this.root, data[i]);
+    }
+  }
+
+  /**
+   * Visits every node in the tree. Each node can store 1 or more points,
+   * depending on the node capacity provided in the constructor.
+   *
+   * @param accessor Method that takes the currently visited node, and the
+   * low and high point of the region that this node occupies. E.g. in 2D,
+   * the low and high points will be the lower-left corner and the upper-right
+   * corner.
+   */
+  visit(
+      accessor: (node: SPNode, lowPoint: Point, highPoint: Point) => boolean,
+      noBox = false) {
+    this.visitNode(this.root, accessor, noBox);
+  }
+
+  private visitNode(
+      node: SPNode,
+      accessor: (node: SPNode, lowPoint?: Point, highPoint?: Point) => boolean,
+      noBox: boolean) {
+    let skipChildren: boolean;
+    if (noBox) {
+      skipChildren = accessor(node);
+    } else {
+      let lowPoint = new Array(this.dim);
+      let highPoint = new Array(this.dim);
+      for (let d = 0; d < this.dim; ++d) {
+        lowPoint[d] = node.box.center[d] - node.box.halfDim;
+        highPoint[d] = node.box.center[d] + node.box.halfDim;
+      }
+      skipChildren = accessor(node, lowPoint, highPoint);
+    }
+    if (!node.children || skipChildren) {
+      return;
+    }
+    for (let i = 0; i < node.children.length; ++i) {
+      let child = node.children[i];
+      if (child) {
+        this.visitNode(child, accessor, noBox);
+      }
+    }
+  }
+
+  private insert(node: SPNode, p: Point) {
+    // Subdivide and then add the point to whichever node will accept it.
+    if (node.children == null) {
+      node.children = new Array(this.masks.length);
+    }
+
+    // Decide which child will get the new point by constructing a D-bits binary
+    // signature (D=3 for 3D) where the k-th bit is 1 if the point's k-th
+    // coordinate is greater than the node's k-th coordinate, 0 otherwise.
+    // Then the binary signature in decimal system gives us the index of the
+    // child where the new point should be.
+    let index = 0;
+    for (let d = 0; d < this.dim; ++d) {
+      if (p[d] > node.box.center[d]) {
+        index |= this.masks[d];
+      }
+    }
+    if (node.children[index] == null) {
+      this.makeChild(node, index, p);
+    } else {
+      this.insert(node.children[index], p);
+    }
+  }
+
+  private makeChild(node: SPNode, index: number, p: Point): void {
+    let oldC = node.box.center;
+    let h = node.box.halfDim / 2;
+    let newC: Point = new Array(this.dim);
+    for (let d = 0; d < this.dim; ++d) {
+      newC[d] = (index & (1 << d)) ? oldC[d] + h : oldC[d] - h;
+    }
+    node.children[index] = {box: {center: newC, halfDim: h}, point: p};
+  }
+}
+
+function fillArray<T>(arr: T[], value: T): void {
+  for (let i = 0; i < arr.length; ++i) {
+    arr[i] = value;
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts
new file mode 100644
index 00000000000..440680bdf1e
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/sptree_test.ts
@@ -0,0 +1,104 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {SPTree} from './sptree';
+
+const assert = chai.assert;
+
+it('simple 2D data', () => {
+  let data = [
+    [0, 1],
+    [1, 0],
+    [1, 1],
+    [0, 0],
+  ];
+  let tree = new SPTree(data);
+  // Check that each point is within the bound.
+  tree.visit((node, low, high) => {
+    assert.equal(low.length, 2);
+    assert.equal(high.length, 2);
+    let point = node.point;
+    assert.equal(point.length, 2);
+    // Each point should be in the node's bounding box.
+    assert.equal(
+        point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] &&
+            point[1] <= high[1],
+        true);
+    return false;
+  });
+});
+
+it('simple 3D data', () => {
+  let data = [
+    [0, 1, 0],
+    [1, 0.4, 2],
+    [1, 1, 3],
+    [0, 0, 5],
+  ];
+  let tree = new SPTree(data);
+  // Check that each point is within the bound.
+  tree.visit((node, low, high) => {
+    assert.equal(low.length, 3);
+    assert.equal(high.length, 3);
+    let point = node.point;
+    assert.equal(point.length, 3);
+    // Each point should be in the node's bounding box.
+    assert.equal(
+        point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] &&
+            point[1] <= high[1] && point[2] >= low[2] && point[2] <= high[2],
+        true);
+    return false;
+  });
+});
+
+it('Only visit root', () => {
+  let data = [
+    [0, 1, 0],
+    [1, 0.4, 2],
+    [1, 1, 3],
+    [0, 0, 5],
+  ];
+  let tree = new SPTree(data);
+  let numVisits = 0;
+  tree.visit((node, low, high) => {
+    numVisits++;
+    return true;
+  });
+  assert.equal(numVisits, 1);
+});
+
+it('Search in random data', () => {
+  let N = 10000;
+  let data = new Array(N);
+  for (let i = 0; i < N; i++) {
+    data[i] = [Math.random(), Math.random()];
+  }
+  let tree = new SPTree(data);
+  let numVisits = 0;
+  let query = data[Math.floor(Math.random() * N)];
+  let found = false;
+  tree.visit((node, low, high) => {
+    numVisits++;
+    if (node.point === query) {
+      found = true;
+      return true;
+    }
+    let outOfBounds = query[0] < low[0] || query[0] > high[0] ||
+        query[1] < low[1] || query[1] > high[1];
+    return outOfBounds;
+  });
+  assert.equal(found, true);
+  assert.isBelow(numVisits, N / 4);
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/styles.html b/tensorflow/tensorboard/components/vz_projector_d3v4/styles.html
new file mode 100644
index 00000000000..32dc984b5d6
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/styles.html
@@ -0,0 +1,185 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<dom-module id="vz-projector-styles">
+<template>
+<style>
+:host {
+  --paper-input-container-label: {
+    font-size: 14px;
+  };
+  --paper-input-container-input: {
+    font-size: 14px;
+  };
+  /* TODO: Figure out why this doesn't work */
+  --paper-dropdown-menu-input: {
+    font-size: 14px;
+  };
+}
+
+paper-button {
+  background: #e3e3e3;
+  margin-left: 0;
+  text-transform: none;
+}
+
+paper-dropdown-menu paper-item {
+  font-size: 13px;
+}
+
+paper-tooltip {
+  max-width: 200px;
+  --paper-tooltip: {
+    font-size: 12px;
+  };
+}
+
+paper-checkbox {
+  --paper-checkbox-checked-color: #880E4F;
+}
+
+paper-toggle-button {
+  --paper-toggle-button-checked-bar-color:  #880E4F;
+  --paper-toggle-button-checked-button-color:  #880E4F;
+  --paper-toggle-button-checked-ink-color: #880E4F;
+}
+
+paper-icon-button {
+  border-radius: 50%;
+}
+
+paper-icon-button[active] {
+  color: white;
+  background-color: #880E4F;
+}
+
+.slider {
+  display: flex;
+  align-items: center;
+  margin-bottom: 10px;
+  justify-content: space-between;
+}
+
+.slider span {
+  width: 35px;
+  text-align: right;
+}
+
+.slider label {
+  align-items: center;
+  display: flex;
+}
+
+.help-icon {
+  height: 15px;
+  left: 2px;
+  min-width: 15px;
+  min-height: 15px;
+  margin: 0;
+  padding: 0;
+  top: -2px;
+  width: 15px;
+}
+
+.ink-panel {
+  display: flex;
+  flex-direction: column;
+  font-size: 14px;
+}
+
+.ink-panel h4 {
+  border-bottom: 1px solid #ddd;
+  font-size: 14px;
+  font-weight: 500;
+  margin: 0;
+  margin-bottom: 10px;
+  padding-bottom: 5px;
+}
+
+.ink-panel-header {
+  border-bottom: 1px solid rgba(0, 0, 0, 0.1);
+  border-top: 1px solid rgba(0, 0, 0, 0.1);
+  height: 50px;
+}
+
+.ink-panel-content {
+  display: none;
+  height: 100%;
+}
+
+.ink-panel-content.active {
+  display: block;
+}
+
+.ink-panel-content h3 {
+  font-weight: 500;
+  font-size: 14px;
+  margin-top: 20px;
+  margin-bottom: 5px;
+  text-transform: uppercase;
+}
+
+.ink-panel-header h3 {
+  font-weight: 500;
+  font-size: 14px;
+  margin: 0;
+  padding: 0 24px;
+  text-transform: uppercase;
+}
+
+
+/* - Tabs */
+.ink-tab-group {
+  align-items: center;
+  box-sizing: border-box;
+  display: flex;
+  height: 100%;
+  justify-content: space-around;
+}
+
+.ink-tab-group .projection-tab {
+  color: rgba(0, 0, 0, 0.5);
+  cursor: pointer;
+  font-weight: 300;
+  line-height: 49px;
+  padding: 0 12px;
+  text-align: center;
+  text-transform: uppercase;
+}
+
+.ink-tab-group .projection-tab:hover {
+  color: black;
+}
+
+.ink-tab-group .projection-tab.active {
+  border-bottom: 2px solid black;
+  color: black;
+  font-weight: 500;
+}
+
+h4 {
+  margin: 30px 0 10px 0;
+}
+
+.dismiss-dialog-note {
+  margin-top: 25px;
+  font-size: 11px;
+  text-align: right;
+}
+</style>
+</template>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/util.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/util.ts
new file mode 100644
index 00000000000..bd6df68b1a5
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/util.ts
@@ -0,0 +1,252 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DataPoint} from './data';
+import * as logging from './logging';
+import {Point2D} from './vector';
+
+/**
+ * Delay for running expensive tasks, in milliseconds.
+ * The duration was empirically found so that it leaves enough time for the
+ * browser to update its UI state before starting an expensive UI-blocking task.
+ */
+const TASK_DELAY_MS = 200;
+
+/** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. */
+export function shuffle<T>(array: T[]): T[] {
+  let m = array.length;
+  let t: T;
+  let i: number;
+
+  // While there remain elements to shuffle.
+  while (m) {
+    // Pick a remaining element
+    i = Math.floor(Math.random() * m--);
+    // And swap it with the current element.
+    t = array[m];
+    array[m] = array[i];
+    array[i] = t;
+  }
+  return array;
+}
+
+export function range(count: number): number[] {
+  const rangeOutput: number[] = [];
+  for (let i = 0; i < count; i++) {
+    rangeOutput.push(i);
+  }
+  return rangeOutput;
+}
+
+export function classed(
+    element: HTMLElement, className: string, enabled: boolean) {
+  const classNames = element.className.split(' ');
+  if (enabled) {
+    if (className in classNames) {
+      return;
+    } else {
+      classNames.push(className);
+    }
+  } else {
+    const index = classNames.indexOf(className);
+    if (index === -1) {
+      return;
+    }
+    classNames.splice(index, 1);
+  }
+  element.className = classNames.join(' ');
+}
+
+/** Projects a 3d point into screen space */
+export function vector3DToScreenCoords(
+    cam: THREE.Camera, w: number, h: number, v: THREE.Vector3): Point2D {
+  let dpr = window.devicePixelRatio;
+  let pv = new THREE.Vector3().copy(v).project(cam);
+
+  // The screen-space origin is at the middle of the screen, with +y up.
+  let coords: Point2D =
+      [((pv.x + 1) / 2 * w) * dpr, -((pv.y - 1) / 2 * h) * dpr];
+  return coords;
+}
+
+/** Loads 3 contiguous elements from a packed xyz array into a Vector3. */
+export function vector3FromPackedArray(
+    a: Float32Array, pointIndex: number): THREE.Vector3 {
+  const offset = pointIndex * 3;
+  return new THREE.Vector3(a[offset], a[offset + 1], a[offset + 2]);
+}
+
+/**
+ * Gets the camera-space z coordinates of the nearest and farthest points.
+ * Ignores points that are behind the camera.
+ */
+export function getNearFarPoints(
+    worldSpacePoints: Float32Array, cameraPos: THREE.Vector3,
+    cameraTarget: THREE.Vector3): [number, number] {
+  let shortestDist: number = Infinity;
+  let furthestDist: number = 0;
+  const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos);
+  const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize();
+  const n = worldSpacePoints.length / 3;
+  let src = 0;
+  let p = new THREE.Vector3();
+  let camToPoint = new THREE.Vector3();
+  for (let i = 0; i < n; i++) {
+    p.x = worldSpacePoints[src];
+    p.y = worldSpacePoints[src + 1];
+    p.z = worldSpacePoints[src + 2];
+    src += 3;
+
+    camToPoint.copy(p).sub(cameraPos);
+    const dist = camPlaneNormal.dot(camToPoint);
+    if (dist < 0) {
+      continue;
+    }
+    furthestDist = (dist > furthestDist) ? dist : furthestDist;
+    shortestDist = (dist < shortestDist) ? dist : shortestDist;
+  }
+  return [shortestDist, furthestDist];
+}
+
+/**
+ * Generate a texture for the points/images and sets some initial params
+ */
+export function createTexture(image: HTMLImageElement|
+                              HTMLCanvasElement): THREE.Texture {
+  let tex = new THREE.Texture(image);
+  tex.needsUpdate = true;
+  // Used if the texture isn't a power of 2.
+  tex.minFilter = THREE.LinearFilter;
+  tex.generateMipmaps = false;
+  tex.flipY = false;
+  return tex;
+}
+
+/**
+ * Assert that the condition is satisfied; if not, log user-specified message
+ * to the console.
+ */
+export function assert(condition: boolean, message?: string) {
+  if (!condition) {
+    message = message || 'Assertion failed';
+    throw new Error(message);
+  }
+}
+
+export type SearchPredicate = (p: DataPoint) => boolean;
+
+export function getSearchPredicate(
+    query: string, inRegexMode: boolean, fieldName: string): SearchPredicate {
+  let predicate: SearchPredicate;
+  if (inRegexMode) {
+    let regExp = new RegExp(query, 'i');
+    predicate = p => regExp.test(p.metadata[fieldName].toString());
+  } else {
+    // Doing a case insensitive substring match.
+    query = query.toLowerCase();
+    predicate = p => {
+      let label = p.metadata[fieldName].toString().toLowerCase();
+      return label.indexOf(query) >= 0;
+    };
+  }
+  return predicate;
+}
+
+/**
+ * Runs an expensive task asynchronously with some delay
+ * so that it doesn't block the UI thread immediately.
+ *
+ * @param message The message to display to the user.
+ * @param task The expensive task to run.
+ * @param msgId Optional. ID of an existing message. If provided, will overwrite
+ *     an existing message and won't automatically clear the message when the
+ *     task is done.
+ * @return The value returned by the task.
+ */
+export function runAsyncTask<T>(
+    message: string, task: () => T, msgId: string = null): Promise<T> {
+  let autoClear = (msgId == null);
+  msgId = logging.setModalMessage(message, msgId);
+  return new Promise<T>((resolve, reject) => {
+    setTimeout(() => {
+      try {
+        let result = task();
+        // Clearing the old message.
+        if (autoClear) {
+          logging.setModalMessage(null, msgId);
+        }
+        resolve(result);
+      } catch (ex) {
+        reject(ex);
+      }
+      return true;
+    }, TASK_DELAY_MS);
+  });
+}
+
+
+/**
+ * Parses the URL for query parameters, e.g. ?foo=1&bar=2 will return
+ *   {'foo': '1', 'bar': '2'}.
+ * @param url The URL to parse.
+ * @return A map of queryParam key to its value.
+ */
+export function getURLParams(url: string): {[key: string]: string} {
+  if (!url) {
+    return {};
+  }
+
+  let queryString = url.indexOf('?') !== -1 ? url.split('?')[1] : url;
+  if (queryString.indexOf('#')) {
+    queryString = queryString.split('#')[0];
+  }
+
+  const queryEntries = queryString.split('&');
+  let queryParams: {[key: string]: string} = {};
+  for (let i = 0; i < queryEntries.length; i++) {
+    let queryEntryComponents = queryEntries[i].split('=');
+    queryParams[queryEntryComponents[0].toLowerCase()] =
+        decodeURIComponent(queryEntryComponents[1]);
+  }
+  return queryParams;
+}
+
+/** List of substrings that auto generated tensors have in their name. */
+const SUBSTR_GEN_TENSORS = ['/Adagrad'];
+
+/** Returns true if the tensor was automatically generated by TF API calls. */
+export function tensorIsGenerated(tensorName: string): boolean {
+  for (let i = 0; i < SUBSTR_GEN_TENSORS.length; i++) {
+    if (tensorName.indexOf(SUBSTR_GEN_TENSORS[i]) >= 0) {
+      return true;
+    }
+  }
+  return false;
+}
+
+export function xor(cond1: boolean, cond2: boolean): boolean {
+  return (cond1 || cond2) && !(cond1 && cond2);
+}
+
+/** Checks to see if the browser supports webgl. */
+export function hasWebGLSupport(): boolean {
+  try {
+    let c = document.createElement('canvas');
+    let gl = c.getContext('webgl') || c.getContext('experimental-webgl');
+    return gl != null && typeof weblas !== 'undefined';
+  } catch (e) {
+    return false;
+  }
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts
new file mode 100644
index 00000000000..f7c0027c81b
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/util_test.ts
@@ -0,0 +1,42 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import * as util from './util';
+
+describe('getURLParams', () => {
+  it('search query with valid param returns correct object', () => {
+    let urlParams = util.getURLParams('?config=http://google.com/');
+    expect(urlParams).toEqual({'config': 'http://google.com/'});
+  });
+
+  it('search query with multiple valid params returns correct object', () => {
+    let urlParams = util.getURLParams('?config=http://google.com/&foo=bar');
+    expect(urlParams).toEqual({'config': 'http://google.com/', 'foo': 'bar'});
+  });
+
+  it('search query with valid param with URL encoded characters', () => {
+    let urlParams = util.getURLParams('?config=http://google.com/%20search');
+    expect(urlParams).toEqual({'config': 'http://google.com/ search'});
+  });
+
+  it('search query with pound sign', () => {
+    let urlParams = util.getURLParams('?config=http://google.com/#foo');
+    expect(urlParams).toEqual({'config': 'http://google.com/'});
+  });
+
+  it('no search query returns empty object', () => {
+    let urlParams = util.getURLParams('');
+    expect(urlParams).toEqual({});
+  });
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts
new file mode 100644
index 00000000000..0de78ad85df
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vector.ts
@@ -0,0 +1,266 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import * as d3 from 'd3';  // from //third_party/javascript/typings/d3_v4
+import {assert} from './util';
+
+/**
+ * @fileoverview Useful vector utilities.
+ */
+
+export type Vector = Float32Array | number[];
+export type Point2D = [number, number];
+export type Point3D = [number, number, number];
+
+/** Returns the dot product of two vectors. */
+export function dot(a: Vector, b: Vector): number {
+  assert(a.length === b.length, 'Vectors a and b must be of same length');
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    result += a[i] * b[i];
+  }
+  return result;
+}
+
+/** Sums all the elements in the vector */
+export function sum(a: Vector): number {
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    result += a[i];
+  }
+  return result;
+}
+
+/** Returns the sum of two vectors, i.e. a + b */
+export function add(a: Vector, b: Vector): Float32Array {
+  assert(a.length === b.length, 'Vectors a and b must be of same length');
+  let result = new Float32Array(a.length);
+  for (let i = 0; i < a.length; ++i) {
+    result[i] = a[i] + b[i];
+  }
+  return result;
+}
+
+/** Subtracts vector b from vector a, i.e. returns a - b */
+export function sub(a: Vector, b: Vector): Float32Array {
+  assert(a.length === b.length, 'Vectors a and b must be of same length');
+  let result = new Float32Array(a.length);
+  for (let i = 0; i < a.length; ++i) {
+    result[i] = a[i] - b[i];
+  }
+  return result;
+}
+
+/** Returns the square norm of the vector */
+export function norm2(a: Vector): number {
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    result += a[i] * a[i];
+  }
+  return result;
+}
+
+/** Returns the euclidean distance between two vectors. */
+export function dist(a: Vector, b: Vector): number {
+  return Math.sqrt(dist2(a, b));
+}
+
+/** Returns the square euclidean distance between two vectors. */
+export function dist2(a: Vector, b: Vector): number {
+  assert(a.length === b.length, 'Vectors a and b must be of same length');
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    let diff = a[i] - b[i];
+    result += diff * diff;
+  }
+  return result;
+}
+
+/** Returns the square euclidean distance between two 2D points. */
+export function dist2_2D(a: Vector, b: Vector): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  return dX * dX + dY * dY;
+}
+
+/** Returns the square euclidean distance between two 3D points. */
+export function dist2_3D(a: Vector, b: Vector): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  let dZ = a[2] - b[2];
+  return dX * dX + dY * dY + dZ * dZ;
+}
+
+/** Returns the euclidean distance between 2 3D points. */
+export function dist_3D(a: Vector, b: Vector): number {
+  return Math.sqrt(dist2_3D(a, b));
+}
+
+/**
+ * Returns the square euclidean distance between two vectors, with an early
+ * exit (returns -1) if the distance is >= to the provided limit.
+ */
+export function dist2WithLimit(a: Vector, b: Vector, limit: number): number {
+  assert(a.length === b.length, 'Vectors a and b must be of same length');
+  let result = 0;
+  for (let i = 0; i < a.length; ++i) {
+    let diff = a[i] - b[i];
+    result += diff * diff;
+    if (result >= limit) {
+      return -1;
+    }
+  }
+  return result;
+}
+
+/** Returns the square euclidean distance between two 2D points. */
+export function dist22D(a: Point2D, b: Point2D): number {
+  let dX = a[0] - b[0];
+  let dY = a[1] - b[1];
+  return dX * dX + dY * dY;
+}
+
+/** Modifies the vector in-place to have unit norm. */
+export function unit(a: Vector): void {
+  let norm = Math.sqrt(norm2(a));
+  assert(norm >= 0, 'Norm of the vector must be > 0');
+  for (let i = 0; i < a.length; ++i) {
+    a[i] /= norm;
+  }
+}
+
+/**
+ *  Projects the vectors to a lower dimension
+ *
+ * @param vectors Array of vectors to be projected.
+ * @param newDim The resulting dimension of the vectors.
+ */
+export function projectRandom(vectors: Float32Array[], newDim: number):
+    Float32Array[] {
+  let dim = vectors[0].length;
+  let N = vectors.length;
+  let newVectors: Float32Array[] = new Array(N);
+  for (let i = 0; i < N; ++i) {
+    newVectors[i] = new Float32Array(newDim);
+  }
+  // Make nDim projections.
+  for (let k = 0; k < newDim; ++k) {
+    let randomVector = rn(dim);
+    for (let i = 0; i < N; ++i) {
+      newVectors[i][k] = dot(vectors[i], randomVector);
+    }
+  }
+  return newVectors;
+}
+
+/**
+ * Projects a vector onto a 2D plane specified by the two direction vectors.
+ */
+export function project2d(a: Vector, dir1: Vector, dir2: Vector): Point2D {
+  return [dot(a, dir1), dot(a, dir2)];
+}
+
+/**
+ * Computes the centroid of the data points. If the provided data points are not
+ * vectors, an accessor function needs to be provided.
+ */
+export function centroid<T>(dataPoints: T[], accessor?: (a: T) => Vector):
+    Vector {
+  if (dataPoints.length === 0) {
+    return null;
+  }
+  if (accessor == null) {
+    accessor = (a: T) => <any>a;
+  }
+  assert(dataPoints.length >= 0, '`vectors` must be of length >= 1');
+  let centroid = new Float32Array(accessor(dataPoints[0]).length);
+  for (let i = 0; i < dataPoints.length; ++i) {
+    let dataPoint = dataPoints[i];
+    let vector = accessor(dataPoint);
+    for (let j = 0; j < centroid.length; ++j) {
+      centroid[j] += vector[j];
+    }
+  }
+  for (let j = 0; j < centroid.length; ++j) {
+    centroid[j] /= dataPoints.length;
+  }
+  return centroid;
+}
+
+/**
+ * Generates a vector of the specified size where each component is drawn from
+ * a random (0, 1) gaussian distribution.
+ */
+export function rn(size: number): Float32Array {
+  const normal = d3.randomNormal();
+  let result = new Float32Array(size);
+  for (let i = 0; i < size; ++i) {
+    result[i] = normal();
+  }
+  return result;
+}
+
+/**
+ * Returns the cosine distance ([0, 2]) between two vectors
+ * that have been normalized to unit norm.
+ */
+export function cosDistNorm(a: Vector, b: Vector): number {
+  return 1 - dot(a, b);
+}
+
+/**
+ * Returns the cosine distance ([0, 2]) between two vectors.
+ */
+export function cosDist(a: Vector, b: Vector): number {
+  return 1 - cosSim(a, b);
+}
+
+/** Returns the cosine similarity ([-1, 1]) between two vectors. */
+export function cosSim(a: Vector, b: Vector): number {
+  return dot(a, b) / Math.sqrt(norm2(a) * norm2(b));
+}
+
+/**
+ * Converts list of vectors (matrix) into a 1-dimensional
+ * typed array with row-first order.
+ */
+export function toTypedArray<T>(
+    dataPoints: T[], accessor: (dataPoint: T) => Float32Array): Float32Array {
+  let N = dataPoints.length;
+  let dim = accessor(dataPoints[0]).length;
+  let result = new Float32Array(N * dim);
+  for (let i = 0; i < N; ++i) {
+    let vector = accessor(dataPoints[i]);
+    for (let d = 0; d < dim; ++d) {
+      result[i * dim + d] = vector[d];
+    }
+  }
+  return result;
+}
+
+/**
+ * Transposes an RxC matrix represented as a flat typed array
+ * into a CxR matrix, again represented as a flat typed array.
+ */
+export function transposeTypedArray(
+    r: number, c: number, typedArray: Float32Array) {
+  let result = new Float32Array(r * c);
+  for (let i = 0; i < r; ++i) {
+    for (let j = 0; j < c; ++j) {
+      result[j * r + i] = typedArray[i * c + j];
+    }
+  }
+  return result;
+}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html
new file mode 100644
index 00000000000..34aca77dde4
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-app.html
@@ -0,0 +1,105 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+
+<link rel="import" href="vz-projector.html">
+<link rel="import" href="styles.html">
+
+<dom-module id="vz-projector-app">
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+#appbar {
+  display: flex;
+  align-items: center;
+  justify-content: space-between;
+  padding: 0 24px;
+  height: 60px;
+  color: white;
+  background: #560731;
+}
+
+#appbar .logo {
+  font-size: 18px;
+  font-weight: 300;
+}
+
+.icons {
+  display: flex;
+}
+
+.icons a {
+  color: white;
+}
+
+vz-projector {
+  height: calc(100% - 60px);
+}
+
+#container {
+  height: 100%;
+}
+</style>
+
+<div id="container">
+  <div id="appbar">
+    <div>Embedding Projector</div>
+    <div class="icons">
+      <a title="Documentation" target="_blank" href="[[documentationLink]]">
+        <paper-icon-button icon="help-outline"></paper-icon-button>
+        <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+          Open documentation
+        </paper-tooltip>
+      </a>
+      <a title="Report bug" target="_blank" href="[[bugReportLink]]">
+        <paper-icon-button icon="bug-report"></paper-icon-button>
+        <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+          Report a bug
+        </paper-tooltip>
+      </a>
+    </div>
+  </div>
+  <vz-projector route-prefix="[[routePrefix]]"
+      serving-mode="[[servingMode]]"
+      projector-config-json-path="[[projectorConfigJsonPath]]"
+      page-view-logging="[[pageViewLogging]]"
+      event-logging="[[eventLogging]]">
+  </vz-projector>
+</div>
+<!-- Google analytics -->
+<script>
+  (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
+  (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
+  m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
+  })(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
+
+  ga('create', 'UA-46457317-5', 'auto');
+</script>
+</template>
+<script>
+  Polymer({
+    is: 'vz-projector-app',
+    properties: {
+      pageViewLogging: {type: Boolean, value: false},
+      eventLogging: {type: Boolean, value: false}
+    }
+  });
+</script>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html
new file mode 100644
index 00000000000..c37d8d9571f
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.html
@@ -0,0 +1,205 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../iron-collapse/iron-collapse.html">
+<link rel="import" href="../paper-input/paper-textarea.html">
+<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+<link rel="import" href="styles.html">
+
+<dom-module id="vz-projector-bookmark-panel">
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+#title {
+  background-color: #fafafa;
+  color: black;
+  font-weight: 500;
+  left: 0;
+  line-height: 60px;
+  padding-left: 24px;
+  position: absolute;
+  width: 276px;
+}
+#bookmark-container {
+  background-color: #fafafa;
+}
+#icon-container {
+  line-height: 60px;
+  position: absolute;
+  right: 0;
+}
+#header {
+  border-top: 1px solid rgba(0, 0, 0, 0.1);
+  position: relative;
+}
+#panel {
+  background-color: #fafafa;
+  position: relative;
+  overflow-y: scroll;
+  top: 60px;
+  max-height: 50vh;
+}
+
+#save-container {
+  text-align: center;
+}
+
+.state-radio {
+  display: table-cell;
+  vertical-align: middle;
+  padding-top: 16px;
+}
+
+.state-label {
+  display: table-cell;
+  vertical-align: middle;
+  top: 14px;
+}
+
+.state-label-input {
+  width: 194px;
+}
+
+.state-clear {
+  display: table-cell;
+  vertical-align: middle;
+  padding-top: 20px;
+}
+#state-file {
+  display: none;
+}
+#no-bookmarks {
+  padding: 0 24px;
+}
+#action-buttons-container .add-icon-button {
+  background-color: #03a9f4;
+  color: white;
+  margin: 0 4px 4px auto;
+  right: 7px;
+  top: -4px;
+}
+.upload-download-icon-button {
+  padding: 0;
+}
+#action-buttons-container {
+  display: flex;
+  margin-left: 34px;
+  margin-top: 6px;
+}
+.ink-fab {
+  border-radius: 50%;
+  background: white;
+  box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
+}
+paper-textarea {
+  --paper-input-container-input: {
+    font-size: 12px;
+  }
+  --paper-font-caption: {
+    display: none
+  }
+}
+</style>
+
+<!-- Bookmarking controls -->
+<div id="bookmark-container">
+  <div id="header">
+    <div id="title">
+      BOOKMARKS ([[savedStates.length]])
+      <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+      <paper-tooltip animation-delay="0" position="top" offset="0">
+        Open this drawer to save a set of views of the projection, including
+        selected points. A file containing the bookmarks can then be saved and
+        later loaded to view them.
+      </paper-tooltip>
+    </div>
+    <div id="icon-container">
+      <!-- Icons and event handlers are inverted because the tray expands upwards. -->
+      <paper-icon-button id="expand-more"
+          icon="expand-less"
+          on-tap="_expandMore"></paper-icon-button>
+      <paper-icon-button id="expand-less"
+          style="display: none"
+          icon="expand-more"
+          on-tap="_expandLess"></paper-icon-button>
+    </div>
+  </div>
+  <iron-collapse id="panel">
+    <!-- Saving state section -->
+    <div id="state-section">
+      <template is="dom-if" if="[[!savedStates.length]]">
+        <p id="no-bookmarks">
+            No bookmarks yet, upload a bookmarks file or add a new bookmark by clicking the "+" below.
+        </p>
+      </template>
+
+      <template is="dom-repeat" items="{{savedStates}}">
+        <div class="state-row">
+          <div class="state-radio">
+            <template is="dom-if" if="{{item.isSelected}}">
+              <paper-icon-button icon="radio-button-checked"></paper-icon-button>
+            </template>
+            <template is="dom-if" if="{{!item.isSelected}}">
+              <paper-icon-button
+                  icon="radio-button-unchecked"
+                  data-index$="{{index}}"
+                  on-tap="_radioButtonHandler"></paper-icon-button>
+            </template>
+          </div>
+          <div class="state-label">
+            <paper-textarea value="[[item.label]]"
+                class="state-label-input"
+                on-keyup="_labelChange"
+                data-index$="[[index]]"
+                autoresizing></paper-input>
+          </div>
+          <div class="state-clear">
+            <paper-icon-button
+                icon="clear"
+                data-index$="{{index}}"
+                on-tap="_clearButtonHandler"></paper-icon-button>
+          </div>
+        </div>
+      </template>
+
+      <div id="action-buttons-container">
+        <paper-icon-button
+            class="upload-download-icon-button"
+            icon="save"
+            title="Save bookmarks"
+            disabled="[[!hasStates]]"
+            on-tap="_downloadFile"></paper-icon-button>
+        <paper-icon-button
+            class="upload-download-icon-button"
+            icon="file-upload"
+            title="Load bookmarks"
+            on-tap="_uploadFile"></paper-icon-button>
+        <paper-icon-button
+            class="add-icon-button ink-fab"
+            icon="add"
+            title="Add bookmark"
+            on-tap="_addBookmark"></paper-icon-button>
+        <input type="file" id="state-file" name="state-file"/>
+      </div>
+    </div>
+  </iron-collapse>
+</div>
+
+</template>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts
new file mode 100644
index 00000000000..53195fa47c0
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-bookmark-panel.ts
@@ -0,0 +1,283 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import {State} from './data';
+import {DataProvider, EmbeddingInfo} from './data-provider';
+import * as logging from './logging';
+import {ProjectorEventContext} from './projectorEventContext';
+import {Projector} from './vz-projector';
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+// tslint:disable-next-line
+export let BookmarkPanelPolymer = PolymerElement({
+  is: 'vz-projector-bookmark-panel',
+  properties: {
+    savedStates: Object,
+    // Keep a separate polymer property because the savedStates doesn't change
+    // when adding and removing states.
+    hasStates: {type: Boolean, value: false},
+    selectedState: Number
+  }
+});
+
+export class BookmarkPanel extends BookmarkPanelPolymer {
+  private projector: Projector;
+
+  // A list containing all of the saved states.
+  private savedStates: State[];
+  private hasStates = false;
+  private selectedState: number;
+  private ignoreNextProjectionEvent: boolean;
+
+  private expandLessButton: HTMLButtonElement;
+  private expandMoreButton: HTMLButtonElement;
+
+  ready() {
+    this.savedStates = [];
+    this.setupUploadButton();
+    this.ignoreNextProjectionEvent = false;
+    this.expandLessButton =
+        this.querySelector('#expand-less') as HTMLButtonElement;
+    this.expandMoreButton =
+        this.querySelector('#expand-more') as HTMLButtonElement;
+  }
+
+  initialize(
+      projector: Projector, projectorEventContext: ProjectorEventContext) {
+    this.projector = projector;
+    projectorEventContext.registerProjectionChangedListener(() => {
+      if (this.ignoreNextProjectionEvent) {
+        this.ignoreNextProjectionEvent = false;
+      } else {
+        this.clearStateSelection();
+      }
+    });
+  }
+
+  setSelectedTensor(
+      run: string, tensorInfo: EmbeddingInfo, dataProvider: DataProvider) {
+    // Clear any existing bookmarks.
+    this.addStates(null);
+    if (tensorInfo && tensorInfo.bookmarksPath) {
+      // Get any bookmarks that may come when the projector starts up.
+      dataProvider.getBookmarks(run, tensorInfo.tensorName, bookmarks => {
+        this.addStates(bookmarks);
+        this._expandMore();
+      });
+    } else {
+      this._expandLess();
+    }
+  }
+
+  /** Handles a click on show bookmarks tray button. */
+  _expandMore() {
+    this.$.panel.show();
+    this.expandMoreButton.style.display = 'none';
+    this.expandLessButton.style.display = '';
+  }
+
+  /** Handles a click on hide bookmarks tray button. */
+  _expandLess() {
+    this.$.panel.hide();
+    this.expandMoreButton.style.display = '';
+    this.expandLessButton.style.display = 'none';
+  }
+
+  /** Handles a click on the add bookmark button. */
+  _addBookmark() {
+    let currentState = this.projector.getCurrentState();
+    currentState.label = 'State ' + this.savedStates.length;
+    currentState.isSelected = true;
+
+    this.selectedState = this.savedStates.length;
+
+    for (let i = 0; i < this.savedStates.length; i++) {
+      this.savedStates[i].isSelected = false;
+      // We have to call notifyPath so that polymer knows this element was
+      // updated.
+      this.notifyPath('savedStates.' + i + '.isSelected', false, false);
+    }
+
+    this.push('savedStates', currentState as any);
+    this.updateHasStates();
+  }
+
+  /** Handles a click on the download bookmarks button. */
+  _downloadFile() {
+    let serializedState = this.serializeAllSavedStates();
+    let blob = new Blob([serializedState], {type: 'text/plain'});
+    let textFile = window.URL.createObjectURL(blob);
+
+    // Force a download.
+    let a = document.createElement('a');
+    document.body.appendChild(a);
+    a.style.display = 'none';
+    a.href = textFile;
+    (a as any).download = 'state';
+    a.click();
+
+    document.body.removeChild(a);
+    window.URL.revokeObjectURL(textFile);
+  }
+
+  /** Handles a click on the upload bookmarks button. */
+  _uploadFile() {
+    let fileInput = this.dom.select('#state-file');
+    (fileInput.node() as HTMLInputElement).click();
+  }
+
+  private setupUploadButton() {
+    // Show and setup the load view button.
+    const fileInput = this.querySelector('#state-file') as HTMLInputElement;
+    fileInput.onchange = () => {
+      const file: File = fileInput.files[0];
+      // Clear out the value of the file chooser. This ensures that if the user
+      // selects the same file, we'll re-read it.
+      fileInput.value = '';
+      const fileReader = new FileReader();
+      fileReader.onload = (evt) => {
+        const str: string = fileReader.result;
+        const savedStates = JSON.parse(str);
+
+        // Verify the bookmarks match.
+        if (this.savedStatesValid(savedStates)) {
+          this.addStates(savedStates);
+          this.loadSavedState(0);
+        } else {
+          logging.setWarningMessage(
+              `Unable to load bookmarks: wrong dataset, expected dataset ` +
+              `with shape (${savedStates[0].dataSetDimensions}).`);
+        }
+      };
+      fileReader.readAsText(file);
+    };
+  }
+
+  addStates(savedStates?: State[]) {
+    if (savedStates == null) {
+      this.savedStates = [];
+    } else {
+      for (let i = 0; i < savedStates.length; i++) {
+        savedStates[i].isSelected = false;
+        this.push('savedStates', savedStates[i] as any);
+      }
+    }
+    this.updateHasStates();
+  }
+
+  /** Deselects any selected state selection. */
+  clearStateSelection() {
+    for (let i = 0; i < this.savedStates.length; i++) {
+      this.setSelectionState(i, false);
+    }
+  }
+
+  /** Handles a radio button click on a saved state. */
+  _radioButtonHandler(evt: Event) {
+    const index = this.getParentDataIndex(evt);
+    this.loadSavedState(index);
+    this.setSelectionState(index, true);
+  }
+
+  loadSavedState(index: number) {
+    for (let i = 0; i < this.savedStates.length; i++) {
+      if (this.savedStates[i].isSelected) {
+        this.setSelectionState(i, false);
+      } else if (index === i) {
+        this.setSelectionState(i, true);
+        this.ignoreNextProjectionEvent = true;
+        this.projector.loadState(this.savedStates[i]);
+      }
+    }
+  }
+
+  private setSelectionState(stateIndex: number, selected: boolean) {
+    this.savedStates[stateIndex].isSelected = selected;
+    const path = 'savedStates.' + stateIndex + '.isSelected';
+    this.notifyPath(path, selected, false);
+  }
+
+  /**
+   * Crawls up the DOM to find an ancestor with a data-index attribute. This is
+   * used to match events to their bookmark index.
+   */
+  private getParentDataIndex(evt: Event) {
+    for (let i = 0; i < (evt as any).path.length; i++) {
+      let dataIndex = (evt as any).path[i].getAttribute('data-index');
+      if (dataIndex != null) {
+        return +dataIndex;
+      }
+    }
+    return -1;
+  }
+
+  /** Handles a clear button click on a bookmark. */
+  _clearButtonHandler(evt: Event) {
+    let index = this.getParentDataIndex(evt);
+    this.splice('savedStates', index, 1);
+    this.updateHasStates();
+  }
+
+  /** Handles a label change event on a bookmark. */
+  _labelChange(evt: Event) {
+    let index = this.getParentDataIndex(evt);
+    this.savedStates[index].label = (evt.target as any).value;
+  }
+
+  /**
+   * Used to determine whether to select the radio button for a given bookmark.
+   */
+  _isSelectedState(index: number) {
+    return index === this.selectedState;
+  }
+  _isNotSelectedState(index: number) {
+    return index !== this.selectedState;
+  }
+
+  /**
+   * Gets all of the saved states as a serialized string.
+   */
+  serializeAllSavedStates(): string {
+    return JSON.stringify(this.savedStates);
+  }
+
+  /**
+   * Loads all of the serialized states and shows them in the list of
+   * viewable states.
+   */
+  loadSavedStates(serializedStates: string) {
+    this.savedStates = JSON.parse(serializedStates);
+    this.updateHasStates();
+  }
+
+  /**
+   * Updates the hasState polymer property.
+   */
+  private updateHasStates() {
+    this.hasStates = (this.savedStates.length !== 0);
+  }
+
+  /** Sanity checks a State array to ensure it matches the current dataset. */
+  private savedStatesValid(states: State[]): boolean {
+    for (let i = 0; i < states.length; i++) {
+      if (states[i].dataSetDimensions[0] !== this.projector.dataSet.dim[0] ||
+          states[i].dataSetDimensions[1] !== this.projector.dataSet.dim[1]) {
+        return false;
+      }
+    }
+    return true;
+  }
+}
+document.registerElement(BookmarkPanel.prototype.is, BookmarkPanel);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html
new file mode 100644
index 00000000000..2acb570b3c1
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-colab.html
@@ -0,0 +1,32 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="vz-projector.html">
+<dom-module id="vz-projector-colab">
+<template>
+<vz-projector serving-mode="proto" data-proto="[[dataProto]]"></vz-projector>
+</template>
+<script>
+Polymer({
+  is: 'vz-projector-colab',
+  properties: {
+    dataProto: Object
+  }
+});
+</script>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html
new file mode 100644
index 00000000000..3857113ac04
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-dashboard.html
@@ -0,0 +1,79 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../tf-dashboard-common/tf-dashboard.html">
+<link rel="import" href="../tf-dashboard-common/tf-no-data-warning.html">
+<link rel="import" href="vz-projector.html">
+
+<dom-module id="vz-projector-dashboard">
+<template>
+  <tf-no-data-warning
+    data-type="projector"
+    show-warning="[[dataNotFound]]"
+  ></tf-no-data-warning>
+  <template is="dom-if" if="[[!dataNotFound]]">
+    <vz-projector
+      id="projector"
+      route-prefix="[[routePrefix]]"
+      serving-mode="server"
+      page-view-logging
+      event-logging
+    ></vz-projector>
+  </template>
+</template>
+<script>
+(function() {
+TF.Dashboard.VzProjectorDashboard = Polymer({
+  is: 'vz-projector-dashboard',
+  factoryImpl: function(routePrefix) {
+    this.routePrefix = routePrefix;
+  },
+  properties: {
+    dataNotFound: Boolean,
+    routePrefix: String,
+    // Whether this dashboard is initialized. This dashboard should only be initialized once.
+    _initialized: Boolean,
+  },
+  behaviors: [
+    TF.Dashboard.DashboardBehavior("embeddings"),
+  ],
+  reload: function() {
+    // Do not reload the embedding projector. Reloading could take a long time.
+  },
+  attached: function() {
+    if (this._initialized) {
+      return;
+    }
+    let xhr = new XMLHttpRequest();
+    xhr.open('GET', this.routePrefix + '/runs');
+    xhr.onload = () => {
+      // Set this to true so we only initialize once.
+      this._initialized = true;
+
+      let runs = JSON.parse(xhr.responseText);
+      this.set('dataNotFound', runs.length === 0);
+    };
+    xhr.onerror = () => {
+      this.set('dataNotFound', false);
+    };
+    xhr.send();
+  },
+});
+})();
+</script>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html
new file mode 100644
index 00000000000..607d4467892
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.html
@@ -0,0 +1,399 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../paper-button/paper-button.html">
+<link rel="import" href="../paper-dropdown-menu/paper-dropdown-menu.html">
+<link rel="import" href="../paper-input/paper-input.html">
+<link rel="import" href="../paper-input/paper-textarea.html">
+<link rel="import" href="../paper-item/paper-item.html">
+<link rel="import" href="../paper-listbox/paper-listbox.html">
+<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+<link rel="import" href="../paper-checkbox/paper-checkbox.html">
+<link rel="import" href="../paper-dialog/paper-dialog.html">
+<link rel="import" href="../paper-dialog-scrollable/paper-dialog-scrollable.html">
+<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+<link rel="import" href="vz-projector-legend.html">
+<link rel="import" href="styles.html">
+
+<dom-module id="vz-projector-data-panel">
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+.container {
+  padding: 10px 20px 20px 20px;
+}
+
+input[type=file] {
+  display: none;
+}
+
+.file-name {
+  margin-right: 10px;
+}
+
+.dirs {
+  color: rgba(0, 0, 0, 0.7);
+  font-size: 12px;
+}
+
+.dirs table tr {
+  vertical-align: top;
+}
+
+.dirs table tr td {
+  padding-bottom: 10px;
+}
+
+paper-item {
+  --paper-item-disabled: {
+    border-bottom: 1px solid black;
+    justify-content: center;
+    font-size: 12px;
+    line-height: normal;
+    min-height: 0px;
+  };
+}
+
+.item-details {
+  margin-left: 5px;
+  color: gray;
+  font-size: 12px;
+}
+
+paper-dropdown-menu {
+  width: 100%;
+}
+
+paper-dropdown-menu paper-item {
+  justify-content: space-between;
+}
+
+.title {
+  align-items: center;
+  border-bottom: 1px solid rgba(0, 0, 0, 0.1);
+  color: black;
+  display: flex;
+  font-weight: 500;
+  height: 59px;
+  padding-left: 20px;
+}
+
+#normalize-data-checkbox {
+  margin: 10px 0;
+}
+
+#projector-config-template {
+  --paper-input-container-input: {
+    line-height: 13px;
+    font-family: monospace;
+    font-size: 12px;
+  };
+}
+
+#generate-share-url {
+  padding: 16px;
+  margin-left: 24px;
+}
+
+#projector-share-button-container {
+  margin: 10px 0;
+}
+
+.config-checkbox {
+  display: inline-block;
+  font-size: 11px;
+  margin-left: 10px;
+}
+
+.projector-config-options {
+  margin-top: 12px;
+}
+
+.projector-config-dialog-container {
+  padding: 24px;
+}
+
+.code {
+  background-color: #f7f7f7;
+  display: table;
+  font-family: monospace;
+  margin-top: 7px;
+  padding: 15px;
+}
+
+.delimiter {
+  color: #B71C1C;
+}
+
+.upload-step {
+  display: flex;
+  justify-content: space-between;
+  margin-bottom: 6px;
+}
+
+.upload-step paper-button {
+  margin-left: 30px;
+}
+
+.step-label {
+  color: rgb(38, 180, 226);
+}
+
+.scrollable-container {
+  margin-top: 0;
+  min-width: 400px;
+}
+
+#projectorConfigDialog p {
+  margin: 8px 0 8px;
+}
+
+.data-step {
+  margin-top: 40px;
+}
+
+.data-step-contents {
+  display: table;
+  width: 100%;
+}
+
+.data-step-contents-contents {
+  display: table-cell;
+  margin-top: 6px;
+}
+
+.data-step-contents-upload {
+  display: table-cell;
+  text-align: right;
+  vertical-align: bottom;
+}
+
+#demo-data-buttons-container {
+  display: none;
+}
+
+.colorby-container {
+  margin-bottom: 10px;
+}
+</style>
+<div class="title">DATA</div>
+<div class="container">
+  <!-- List of runs -->
+  <template is="dom-if" if="[[_hasChoices(runNames)]]">
+    <paper-dropdown-menu no-animations label="[[_getNumRunsLabel(runNames)]] found">
+      <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedRun}}">
+        <template is="dom-repeat" items="[[runNames]]">
+          <paper-item value="[[item]]" label="[[item]]">
+            [[item]]
+          </paper-item>
+        </template>
+      </paper-listbox>
+    </paper-dropdown-menu>
+  </template>
+
+  <template is="dom-if" if="[[tensorNames]]">
+    <!-- List of tensors in checkpoint -->
+    <paper-dropdown-menu no-animations label="[[_getNumTensorsLabel(tensorNames)]] found">
+      <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedTensor}}">
+        <template is="dom-repeat" items="[[tensorNames]]">
+          <paper-item value="[[item.name]]" label="[[item.name]]">
+            [[item.name]]
+            <span class="item-details">
+              [[item.shape.0]]x[[item.shape.1]]
+            </span>
+          </paper-item>
+        </template>
+      </paper-listbox>
+    </paper-dropdown-menu>
+  </template>
+  <!-- Label by -->
+  <template is="dom-if" if="[[_hasChoices(labelOptions)]]">
+    <paper-dropdown-menu no-animations label="Label by">
+      <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedLabelOption}}">
+        <template is="dom-repeat" items="[[labelOptions]]">
+          <paper-item value="[[item]]" label="[[item]]">
+            [[item]]
+          </paper-item>
+        </template>
+      </paper-listbox>
+    </paper-dropdown-menu>
+  </template>
+
+  <!-- Color by -->
+  <div hidden$="[[!_hasChoices(colorOptions)]]" class="colorby-container">
+    <paper-dropdown-menu id="colorby" no-animations label="Color by">
+      <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedColorOptionName}}">
+        <template is="dom-repeat" items="[[colorOptions]]">
+          <paper-item class$="[[getSeparatorClass(item.isSeparator)]]" value="[[item.name]]" label="[[item.name]]" disabled="[[item.isSeparator]]">
+            [[item.name]]
+            <span class="item-details">[[item.desc]]</span>
+          </paper-item>
+        </template>
+      </paper-listbox>
+    </paper-dropdown-menu>
+    <div hidden$="[[!showForceCategoricalColorsCheckbox]]">
+      <paper-checkbox id="force-categorical-checkbox"></paper-checkbox>
+      Use categorical coloring
+      <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+      <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+        For metadata fields that have many unique values we use a gradient color map
+        by default. This checkbox allows you to force categorical coloring by a given
+        metadata field.
+      </paper-tooltip>
+    </div>
+    <template dom-if="[[colorLegendRenderInfo]]">
+      <vz-projector-legend render-info="[[colorLegendRenderInfo]]"></vz-projector-legend>
+    </template>
+  </div>
+  <paper-checkbox id="normalize-data-checkbox" checked="{{normalizeData}}">
+    Sphereize data
+    <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+    <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+      The data is normalized by shifting each point by the centroid and making
+      it unit norm.
+    </paper-tooltip>
+  </paper-checkbox>
+  <p id="demo-data-buttons-container">
+    <span>
+      <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+        Load data from your computer
+      </paper-tooltip>
+      <paper-button id="upload" class="ink-button" onclick="dataDialog.open()">Load data</paper-button>
+    </span>
+    <span id="publish-container">
+      <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+        Publish your embedding visualization and data
+      </paper-tooltip>
+      <paper-button id="host-embedding" class="ink-button" onclick="projectorConfigDialog.open()">Publish</paper-button>
+    </span>
+  </p>
+  <div>
+    <paper-dialog id="dataDialog" with-backdrop>
+      <h2>Load data from your computer</h2>
+      <paper-dialog-scrollable class="scrollable-container">
+        <div class="data-step" id="upload-tensors-step-container">
+          <div class="upload-step">
+            <div>
+                <b><span class="step-label">Step 1:</span> Load a TSV file of vectors.</b>
+            </div>
+          </div>
+          <div class="data-step-contents">
+            <div class="data-step-contents-contents">
+              Example of 3 vectors with dimension 4:
+              <div class="code">
+                0.1<span class="delimiter">\t</span>0.2<span class="delimiter">\t</span>0.5<span class="delimiter">\t</span>0.9<br/>
+                0.2<span class="delimiter">\t</span>0.1<span class="delimiter">\t</span>5.0<span class="delimiter">\t</span>0.2<br/>
+                0.4<span class="delimiter">\t</span>0.1<span class="delimiter">\t</span>7.0<span class="delimiter">\t</span>0.8
+              </div>
+            </div>
+            <div class="data-step-contents-upload">
+              <paper-button id="upload-tensors" title="Choose a TSV tensor file">Choose file</paper-button>
+              <input type="file" id="file" name="file"/>
+            </div>
+          </div>
+        </div>
+        <div class="data-step">
+          <div class="upload-step">
+            <div>
+                <span class="step-label" id="upload-metadata-label"><b>Step 2</b> (optional):</span> <b>Load a TSV file of metadata.</b>
+            </div>
+          </div>
+          <div class="data-step-contents">
+            <div class="data-step-contents-contents">
+              Example of 3 data points and 2 columns.<br/>
+              <i>Note: If there is more than one column, the first row will be parsed as column labels.</i>
+              <div class="code">
+                <b>Pokémon<span class="delimiter">\t</span>Species</b><br/>
+                Wartortle<span class="delimiter">\t</span>Turtle<br/>
+                Venusaur<span class="delimiter">\t</span>Seed<br/>
+                Charmeleon<span class="delimiter">\t</span>Flame
+              </div>
+            </div>
+            <div class="data-step-contents-upload">
+              <paper-button id="upload-metadata" title="Choose a TSV metadata file" class="ink-button">Choose file</paper-button>
+              <input type="file" id="file-metadata" name="file-metadata"/>
+            </div>
+          </div>
+        </div>
+      </paper-dialog-scrollable>
+      <div class="dismiss-dialog-note">Click outside to dismiss.</div>
+    </paper-dialog>
+    <paper-dialog id="projectorConfigDialog" with-backdrop>
+      <h2>Publish your embedding visualization and data</h2>
+      <paper-dialog-scrollable class="scrollable-container">
+        <div>
+          <p>
+            If you'd like to share your visualization with the world, follow these simple steps.
+            See <a target=_blank href="https://www.tensorflow.org/get_started/embedding_viz">this tutorial</a> for more.
+          </p>
+          <h4><span class="step-label">Step 1:</span> Make data public</h4>
+          <p>
+            Host tensors, metadata, sprite image, and bookmarks TSV files <i>publicly</i> on the web.
+          </p>
+          <p>
+            One option is using a <a target=_blank href="https://gist.github.com/">github gist</a>.
+            If you choose this approach, make sure to link directly to the raw file.
+          </p>
+        </div>
+        <div>
+          <h4><span class="step-label">Step 2:</span> Projector config</h4>
+          <div class="projector-config-options">
+            <i>Optional:</i>
+            <div class="config-checkbox">
+              <paper-checkbox id="config-metadata-checkbox" checked>Metadata</paper-checkbox>
+            </div>
+            <div class="config-checkbox">
+              <paper-checkbox id="config-sprite-checkbox">Sprite</paper-checkbox>
+            </div>
+            <div class="config-checkbox">
+              <paper-checkbox id="config-bookmarks-checkbox">Bookmarks</paper-checkbox>
+            </div>
+          </div>
+        </div>
+        <paper-textarea id="projector-config-template" label="template_projector_config.json"></paper-textarea>
+        <div>
+          <h4><span class="step-label">Step 3:</span> Host projector config</h4>
+          After you have hosted the projector config JSON file you built above, paste the URL to the config below.
+        </div>
+        <paper-input id="projector-config-url" label="Path to projector config"></paper-input>
+        <paper-input id="projector-share-url" label="Your shareable URL" readonly></paper-input>
+        <div id="projector-share-button-container">
+          <a target=_blank id="projector-share-url-link">
+            <paper-button title="Test your shareable URL" class="ink-button">Test your shareable URL</paper-button>
+          </a>
+        </div>
+      </paper-dialog-scrollable>
+      <div class="dismiss-dialog-note">Click outside to dismiss.</div>
+    </paper-dialog>
+  </div>
+  <div class="dirs">
+    <table>
+      <tr>
+        <td>Checkpoint:</td>
+        <td><span id="checkpoint-file"></span></td>
+      </tr>
+      <tr>
+        <td>Metadata:</td>
+        <td><span id="metadata-file"></span></td>
+      </tr>
+    </table>
+  </div>
+</div>
+<!-- Closing global template -->
+</template>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts
new file mode 100644
index 00000000000..a6847ed3c87
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-data-panel.ts
@@ -0,0 +1,497 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import * as d3 from 'd3';  // from //third_party/javascript/typings/d3_v4
+import {ColorOption, ColumnStats, SpriteAndMetadataInfo} from './data';
+import {DataProvider, EmbeddingInfo, parseRawMetadata, parseRawTensors, ProjectorConfig} from './data-provider';
+import * as util from './util';
+import {Projector} from './vz-projector';
+import {ColorLegendRenderInfo, ColorLegendThreshold} from './vz-projector-legend';
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+export let DataPanelPolymer = PolymerElement({
+  is: 'vz-projector-data-panel',
+  properties: {
+    selectedTensor: {type: String, observer: '_selectedTensorChanged'},
+    selectedRun: {type: String, observer: '_selectedRunChanged'},
+    selectedColorOptionName: {
+      type: String,
+      notify: true,
+      observer: '_selectedColorOptionNameChanged'
+    },
+    selectedLabelOption:
+        {type: String, notify: true, observer: '_selectedLabelOptionChanged'},
+    normalizeData: Boolean,
+    showForceCategoricalColorsCheckbox: Boolean
+  }
+});
+
+export class DataPanel extends DataPanelPolymer {
+  selectedLabelOption: string;
+  selectedColorOptionName: string;
+  showForceCategoricalColorsCheckbox: boolean;
+
+  private normalizeData: boolean;
+  private labelOptions: string[];
+  private colorOptions: ColorOption[];
+  forceCategoricalColoring: boolean = false;
+
+  private selectedTensor: string;
+  private selectedRun: string;
+  private dataProvider: DataProvider;
+  private tensorNames: {name: string, shape: number[]}[];
+  private runNames: string[];
+  private projector: Projector;
+  private projectorConfig: ProjectorConfig;
+  private colorLegendRenderInfo: ColorLegendRenderInfo;
+  private spriteAndMetadata: SpriteAndMetadataInfo;
+  private metadataFile: string;
+
+  ready() {
+    this.normalizeData = true;
+  }
+
+  initialize(projector: Projector, dp: DataProvider) {
+    this.projector = projector;
+    this.dataProvider = dp;
+    this.setupUploadButtons();
+
+    // Tell the projector whenever the data normalization changes.
+    // Unknown why, but the polymer checkbox button stops working as soon as
+    // you do d3.select() on it.
+    this.querySelector('#normalize-data-checkbox')
+        .addEventListener('change', () => {
+          this.projector.setNormalizeData(this.normalizeData);
+        });
+
+    let forceCategoricalColoringCheckbox =
+        this.querySelector('#force-categorical-checkbox');
+    forceCategoricalColoringCheckbox.addEventListener('change', () => {
+      this.setForceCategoricalColoring(
+          (forceCategoricalColoringCheckbox as HTMLInputElement).checked);
+    });
+
+    // Get all the runs.
+    this.dataProvider.retrieveRuns(runs => {
+      this.runNames = runs;
+      // Choose the first run by default.
+      if (this.runNames.length > 0) {
+        this.selectedRun = runs[0];
+      }
+    });
+  }
+
+  setForceCategoricalColoring(forceCategoricalColoring: boolean) {
+    this.forceCategoricalColoring = forceCategoricalColoring;
+    (this.querySelector('#force-categorical-checkbox') as HTMLInputElement)
+        .checked = this.forceCategoricalColoring;
+
+    this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile);
+
+    // The selected color option name doesn't change when we switch to using
+    // categorical coloring for stats with too many unique values, so we
+    // manually call this polymer observer so that we update the UI.
+    this._selectedColorOptionNameChanged();
+  }
+
+  getSeparatorClass(isSeparator: boolean): string {
+    return isSeparator ? 'separator' : null;
+  }
+
+  metadataChanged(
+      spriteAndMetadata: SpriteAndMetadataInfo, metadataFile: string) {
+    this.spriteAndMetadata = spriteAndMetadata;
+    this.metadataFile = metadataFile;
+
+    this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile);
+    this.selectedColorOptionName = this.colorOptions[0].name;
+  }
+
+  private addWordBreaks(longString: string): string {
+    if (longString == null) {
+      return '';
+    }
+    return longString.replace(/([\/=-_,])/g, '$1<wbr>');
+  }
+
+  private updateMetadataUI(columnStats: ColumnStats[], metadataFile: string) {
+    const metadataFileElement =
+        this.querySelector('#metadata-file') as HTMLSpanElement;
+    metadataFileElement.innerHTML = this.addWordBreaks(metadataFile);
+    metadataFileElement.title = metadataFile;
+
+    // Label by options.
+    let labelIndex = -1;
+    this.labelOptions = columnStats.map((stats, i) => {
+      // Make the default label by the first non-numeric column.
+      if (!stats.isNumeric && labelIndex === -1) {
+        labelIndex = i;
+      }
+      return stats.name;
+    });
+    this.selectedLabelOption = this.labelOptions[Math.max(0, labelIndex)];
+
+    // Color by options.
+    const standardColorOption: ColorOption[] = [
+      {name: 'No color map'},
+      // TODO(smilkov): Implement this.
+      // {name: 'Distance of neighbors',
+      //    desc: 'How far is each point from its neighbors'}
+    ];
+    const metadataColorOption: ColorOption[] =
+        columnStats
+            .filter(stats => {
+              return !stats.tooManyUniqueValues || stats.isNumeric;
+            })
+            .map(stats => {
+              let map;
+              let items: {label: string, count: number}[];
+              let thresholds: ColorLegendThreshold[];
+              let isCategorical =
+                  this.forceCategoricalColoring || !stats.tooManyUniqueValues;
+              if (isCategorical) {
+                const scale = d3.scaleOrdinal(d3.schemeCategory20);
+                let range = scale.range();
+                // Re-order the range.
+                let newRange = range.map((color, i) => {
+                  let index = (i * 3) % range.length;
+                  return range[index];
+                });
+                items = stats.uniqueEntries;
+                scale.range(newRange).domain(items.map(x => x.label));
+                map = scale;
+              } else {
+                thresholds = [
+                  {color: '#ffffdd', value: stats.min},
+                  {color: '#1f2d86', value: stats.max}
+                ];
+                map = d3.scaleLinear<string, string>()
+                          .domain(thresholds.map(t => t.value))
+                          .range(thresholds.map(t => t.color));
+              }
+              let desc = !isCategorical ? 'gradient' :
+                                          stats.uniqueEntries.length +
+                      ((stats.uniqueEntries.length > 20) ? ' non-unique' : '') +
+                      ' colors';
+              return {
+                name: stats.name,
+                desc: desc,
+                map: map,
+                items: items,
+                thresholds: thresholds,
+                tooManyUniqueValues: stats.tooManyUniqueValues
+              };
+            });
+
+    if (metadataColorOption.length > 0) {
+      // Add a separator line between built-in color maps
+      // and those based on metadata columns.
+      standardColorOption.push({name: 'Metadata', isSeparator: true});
+    }
+    this.colorOptions = standardColorOption.concat(metadataColorOption);
+  }
+
+  setNormalizeData(normalizeData: boolean) {
+    this.normalizeData = normalizeData;
+  }
+
+  _selectedTensorChanged() {
+    this.projector.updateDataSet(null, null, null);
+    if (this.selectedTensor == null) {
+      return;
+    }
+    this.dataProvider.retrieveTensor(
+        this.selectedRun, this.selectedTensor, ds => {
+          let metadataFile =
+              this.getEmbeddingInfoByName(this.selectedTensor).metadataPath;
+          this.dataProvider.retrieveSpriteAndMetadata(
+              this.selectedRun, this.selectedTensor, metadata => {
+                this.projector.updateDataSet(ds, metadata, metadataFile);
+              });
+        });
+    this.projector.setSelectedTensor(
+        this.selectedRun, this.getEmbeddingInfoByName(this.selectedTensor));
+  }
+
+  _selectedRunChanged() {
+    this.dataProvider.retrieveProjectorConfig(this.selectedRun, info => {
+      this.projectorConfig = info;
+      let names =
+          this.projectorConfig.embeddings.map(e => e.tensorName)
+              .filter(name => {
+                let shape = this.getEmbeddingInfoByName(name).tensorShape;
+                return shape.length === 2 && shape[0] > 1 && shape[1] > 1;
+              })
+              .sort((a, b) => {
+                let embA = this.getEmbeddingInfoByName(a);
+                let embB = this.getEmbeddingInfoByName(b);
+
+                // Prefer tensors with metadata.
+                if (util.xor(!!embA.metadataPath, !!embB.metadataPath)) {
+                  return embA.metadataPath ? -1 : 1;
+                }
+
+                // Prefer non-generated tensors.
+                let isGenA = util.tensorIsGenerated(a);
+                let isGenB = util.tensorIsGenerated(b);
+                if (util.xor(isGenA, isGenB)) {
+                  return isGenB ? -1 : 1;
+                }
+
+                // Prefer bigger tensors.
+                let sizeA = embA.tensorShape[0];
+                let sizeB = embB.tensorShape[0];
+                if (sizeA !== sizeB) {
+                  return sizeB - sizeA;
+                }
+
+                // Sort alphabetically by tensor name.
+                return a <= b ? -1 : 1;
+              });
+      this.tensorNames = names.map(name => {
+        return {name, shape: this.getEmbeddingInfoByName(name).tensorShape};
+      });
+      const wordBreakablePath =
+          this.addWordBreaks(this.projectorConfig.modelCheckpointPath);
+      const checkpointFile =
+          this.querySelector('#checkpoint-file') as HTMLSpanElement;
+      checkpointFile.innerHTML = wordBreakablePath;
+      checkpointFile.title = this.projectorConfig.modelCheckpointPath;
+
+      // If in demo mode, let the order decide which tensor to load by default.
+      const defaultTensor = this.projector.servingMode === 'demo' ?
+          this.projectorConfig.embeddings[0].tensorName :
+          names[0];
+      if (this.selectedTensor === defaultTensor) {
+        // Explicitly call the observer. Polymer won't call it if the previous
+        // string matches the current string.
+        this._selectedTensorChanged();
+      } else {
+        this.selectedTensor = defaultTensor;
+      }
+    });
+  }
+
+  _selectedLabelOptionChanged() {
+    this.projector.setSelectedLabelOption(this.selectedLabelOption);
+  }
+
+  _selectedColorOptionNameChanged() {
+    let colorOption: ColorOption;
+    for (let i = 0; i < this.colorOptions.length; i++) {
+      if (this.colorOptions[i].name === this.selectedColorOptionName) {
+        colorOption = this.colorOptions[i];
+        break;
+      }
+    }
+    if (!colorOption) {
+      return;
+    }
+
+    this.showForceCategoricalColorsCheckbox = !!colorOption.tooManyUniqueValues;
+
+    if (colorOption.map == null) {
+      this.colorLegendRenderInfo = null;
+    } else if (colorOption.items) {
+      let items = colorOption.items.map(item => {
+        return {
+          color: colorOption.map(item.label),
+          label: item.label,
+          count: item.count
+        };
+      });
+      this.colorLegendRenderInfo = {items, thresholds: null};
+    } else {
+      this.colorLegendRenderInfo = {
+        items: null,
+        thresholds: colorOption.thresholds
+      };
+    }
+    this.projector.setSelectedColorOption(colorOption);
+  }
+
+  private tensorWasReadFromFile(rawContents: ArrayBuffer, fileName: string) {
+    parseRawTensors(rawContents, ds => {
+      const checkpointFile =
+          this.querySelector('#checkpoint-file') as HTMLSpanElement;
+      checkpointFile.innerText = fileName;
+      checkpointFile.title = fileName;
+      this.projector.updateDataSet(ds);
+    });
+  }
+
+  private metadataWasReadFromFile(rawContents: ArrayBuffer, fileName: string) {
+    parseRawMetadata(rawContents, metadata => {
+      this.projector.updateDataSet(this.projector.dataSet, metadata, fileName);
+    });
+  }
+
+  private getEmbeddingInfoByName(tensorName: string): EmbeddingInfo {
+    for (let i = 0; i < this.projectorConfig.embeddings.length; i++) {
+      const e = this.projectorConfig.embeddings[i];
+      if (e.tensorName === tensorName) {
+        return e;
+      }
+    }
+  }
+
+  private setupUploadButtons() {
+    // Show and setup the upload button.
+    const fileInput = this.querySelector('#file') as HTMLInputElement;
+    fileInput.onchange = () => {
+      const file: File = fileInput.files[0];
+      // Clear out the value of the file chooser. This ensures that if the user
+      // selects the same file, we'll re-read it.
+      fileInput.value = '';
+      const fileReader = new FileReader();
+      fileReader.onload = evt => {
+        const content: ArrayBuffer = fileReader.result;
+        this.tensorWasReadFromFile(content, file.name);
+      };
+      fileReader.readAsArrayBuffer(file);
+    };
+
+    const uploadButton =
+        this.querySelector('#upload-tensors') as HTMLButtonElement;
+    uploadButton.onclick = () => {
+      fileInput.click();
+    };
+
+    // Show and setup the upload metadata button.
+    const fileMetadataInput =
+        this.querySelector('#file-metadata') as HTMLInputElement;
+    fileMetadataInput.onchange = () => {
+      const file: File = fileMetadataInput.files[0];
+      // Clear out the value of the file chooser. This ensures that if the user
+      // selects the same file, we'll re-read it.
+      fileMetadataInput.value = '';
+      const fileReader = new FileReader();
+      fileReader.onload = evt => {
+        const contents: ArrayBuffer = fileReader.result;
+        this.metadataWasReadFromFile(contents, file.name);
+      };
+      fileReader.readAsArrayBuffer(file);
+    };
+
+    const uploadMetadataButton =
+        this.querySelector('#upload-metadata') as HTMLButtonElement;
+    uploadMetadataButton.onclick = () => {
+      fileMetadataInput.click();
+    };
+
+    if (this.projector.servingMode !== 'demo') {
+      (this.$$('#publish-container') as HTMLElement).style.display = 'none';
+      (this.$$('#upload-tensors-step-container') as HTMLElement).style.display =
+          'none';
+      (this.$$('#upload-metadata-label') as HTMLElement).style.display = 'none';
+    }
+
+    (this.$$('#demo-data-buttons-container') as HTMLElement).style.display =
+        'block';
+
+    // Fill out the projector config.
+    const projectorConfigTemplate =
+        this.$$('#projector-config-template') as HTMLTextAreaElement;
+    const projectorConfigTemplateJson: ProjectorConfig = {
+      embeddings: [{
+        tensorName: 'My tensor',
+        tensorShape: [1000, 50],
+        tensorPath: 'https://raw.githubusercontent.com/.../tensors.tsv',
+        metadataPath:
+            'https://raw.githubusercontent.com/.../optional.metadata.tsv',
+      }],
+    };
+    this.setProjectorConfigTemplateJson(
+        projectorConfigTemplate, projectorConfigTemplateJson);
+
+    // Set up optional field checkboxes.
+    const spriteFieldCheckbox =
+        this.$$('#config-sprite-checkbox') as HTMLInputElement;
+    spriteFieldCheckbox.onchange = () => {
+      if ((spriteFieldCheckbox as any).checked) {
+        projectorConfigTemplateJson.embeddings[0].sprite = {
+          imagePath: 'https://github.com/.../optional.sprite.png',
+          singleImageDim: [32, 32]
+        };
+      } else {
+        delete projectorConfigTemplateJson.embeddings[0].sprite;
+      }
+      this.setProjectorConfigTemplateJson(
+          projectorConfigTemplate, projectorConfigTemplateJson);
+    };
+    const bookmarksFieldCheckbox =
+        this.$$('#config-bookmarks-checkbox') as HTMLInputElement;
+    bookmarksFieldCheckbox.onchange = () => {
+      if ((bookmarksFieldCheckbox as any).checked) {
+        projectorConfigTemplateJson.embeddings[0].bookmarksPath =
+            'https://raw.githubusercontent.com/.../bookmarks.txt';
+      } else {
+        delete projectorConfigTemplateJson.embeddings[0].bookmarksPath;
+      }
+      this.setProjectorConfigTemplateJson(
+          projectorConfigTemplate, projectorConfigTemplateJson);
+    };
+    const metadataFieldCheckbox =
+        this.$$('#config-metadata-checkbox') as HTMLInputElement;
+    metadataFieldCheckbox.onchange = () => {
+      if ((metadataFieldCheckbox as HTMLInputElement).checked) {
+        projectorConfigTemplateJson.embeddings[0].metadataPath =
+            'https://raw.githubusercontent.com/.../optional.metadata.tsv';
+      } else {
+        delete projectorConfigTemplateJson.embeddings[0].metadataPath;
+      }
+      this.setProjectorConfigTemplateJson(
+          projectorConfigTemplate, projectorConfigTemplateJson);
+    };
+
+    // Update the link and the readonly shareable URL.
+    const projectorConfigUrlInput =
+        this.$$('#projector-config-url') as HTMLInputElement;
+    const projectorConfigDemoUrlInput = this.$$('#projector-share-url');
+    const projectorConfigDemoUrlLink = this.$$('#projector-share-url-link');
+    projectorConfigUrlInput.onchange = () => {
+      let projectorDemoUrl = location.protocol + '//' + location.host +
+          location.pathname +
+          '?config=' + (projectorConfigUrlInput as HTMLInputElement).value;
+
+      (projectorConfigDemoUrlInput as HTMLInputElement).value =
+          projectorDemoUrl;
+      (projectorConfigDemoUrlLink as HTMLLinkElement).href = projectorDemoUrl;
+    };
+  }
+
+  private setProjectorConfigTemplateJson(
+      projectorConfigTemplate: HTMLTextAreaElement, config: ProjectorConfig) {
+    projectorConfigTemplate.value =
+        JSON.stringify(config, null, /** replacer */ 2 /** white space */);
+  }
+
+  _getNumTensorsLabel(): string {
+    return this.tensorNames.length === 1 ? '1 tensor' :
+                                           this.tensorNames.length + ' tensors';
+  }
+
+  _getNumRunsLabel(): string {
+    return this.runNames.length === 1 ? '1 run' :
+                                        this.runNames.length + ' runs';
+  }
+
+  _hasChoices(choices: any[]): boolean {
+    return choices.length > 1;
+  }
+}
+
+document.registerElement(DataPanel.prototype.is, DataPanel);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html
new file mode 100644
index 00000000000..e77694426eb
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.html
@@ -0,0 +1,64 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../paper-input/paper-input.html">
+<link rel="import" href="../paper-button/paper-button.html">
+<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+<link rel="import" href="styles.html">
+
+<dom-module id="vz-projector-input">
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+.info {
+  color: rgba(0, 0, 0, 0.5);
+  display: block;
+  font-size: 11px;
+}
+
+.toggle {
+  font-size: 12px;
+  height: 21px;
+  margin: 0px;
+  min-width: 0px;
+  min-height: 0px;
+  padding: 0;
+  width: 17px;
+}
+
+.toggle[active] {
+  background-color: #880E4F;
+  color: white;
+}
+</style>
+
+<paper-input label="[[label]]">
+  <div class="slash" prefix>/</div>
+  <div class="slash" suffix>/</div>
+  <div suffix>
+    <paper-button id="regex" toggles class="toggle">.*</paper-button>
+  </div>
+</paper-input>
+<paper-tooltip for="regex" position="bottom" animation-delay="0" fit-to-visible-bounds>
+  Enable/disable regex mode.
+</paper-tooltip>
+<span class="info">[[message]]</span>
+
+<!-- Closing global template -->
+</template>
+</dom-module>
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts
new file mode 100644
index 00000000000..e11346d327f
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-input.ts
@@ -0,0 +1,113 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+// tslint:disable-next-line
+export let PolymerClass = PolymerElement(
+    {is: 'vz-projector-input', properties: {label: String, message: String}});
+
+export interface InputChangedListener {
+  (value: string, inRegexMode: boolean): void;
+}
+
+/** Input control with custom capabilities (e.g. regex). */
+export class ProjectorInput extends PolymerClass {
+  private textChangedListeners: InputChangedListener[];
+  private paperInput: HTMLInputElement;
+  private inRegexModeButton: HTMLButtonElement;
+  private inRegexMode: boolean;
+
+  /** Message that will be displayed at the bottom of the input control. */
+  message: string;
+
+  /** Subscribe to be called everytime the input changes. */
+  registerInputChangedListener(listener: InputChangedListener) {
+    this.textChangedListeners.push(listener);
+  }
+
+  ready() {
+    this.inRegexMode = false;
+    this.textChangedListeners = [];
+    this.paperInput = this.querySelector('paper-input') as HTMLInputElement;
+    this.inRegexModeButton =
+        this.querySelector('paper-button') as HTMLButtonElement;
+    this.paperInput.setAttribute('error-message', 'Invalid regex');
+
+    this.paperInput.addEventListener('input', () => {
+      this.onTextChanged();
+    });
+
+    this.paperInput.addEventListener('keydown', event => {
+      event.stopPropagation();
+    });
+
+    this.inRegexModeButton.addEventListener(
+        'click', () => this.onClickRegexModeButton());
+    this.updateRegexModeDisplaySlashes();
+    this.onTextChanged();
+  }
+
+  private onClickRegexModeButton() {
+    this.inRegexMode = (this.inRegexModeButton as any).active;
+    this.updateRegexModeDisplaySlashes();
+    this.onTextChanged();
+  }
+
+  private notifyInputChanged(value: string, inRegexMode: boolean) {
+    this.textChangedListeners.forEach(l => l(value, inRegexMode));
+  }
+
+  private onTextChanged() {
+    try {
+      if (this.inRegexMode) {
+        new RegExp(this.paperInput.value);
+      }
+    } catch (invalidRegexException) {
+      this.paperInput.setAttribute('invalid', 'true');
+      this.message = '';
+      this.notifyInputChanged(null, true);
+      return;
+    }
+    this.paperInput.removeAttribute('invalid');
+    this.notifyInputChanged(this.paperInput.value, this.inRegexMode);
+  }
+
+  private updateRegexModeDisplaySlashes() {
+    const slashes = this.paperInput.querySelectorAll('.slash');
+    const display = this.inRegexMode ? '' : 'none';
+
+    for (let i = 0; i < slashes.length; i++) {
+      (slashes[i] as HTMLDivElement).style.display = display;
+    }
+  }
+
+  getValue(): string {
+    return this.paperInput.value;
+  }
+
+  getInRegexMode(): boolean {
+    return this.inRegexMode;
+  }
+
+  set(value: string, inRegexMode: boolean) {
+    (this.inRegexModeButton as any).active = inRegexMode;
+    this.paperInput.value = value;
+    this.onClickRegexModeButton();
+  }
+}
+
+document.registerElement(ProjectorInput.prototype.is, ProjectorInput);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html
new file mode 100644
index 00000000000..7554c322cef
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.html
@@ -0,0 +1,240 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../paper-slider/paper-slider.html">
+
+<link rel="import" href="styles.html">
+<link rel="import" href="vz-projector-input.html">
+
+<dom-module id="vz-projector-inspector-panel">
+<style include="vz-projector-styles"></style>
+<style>
+:host {
+   display: flex;
+   flex-direction: column;
+   /* Account for the bookmark pane at the bottom */
+   height: calc(100% - 55px);
+}
+
+.container {
+  display: block;
+  padding: 10px 20px 0 20px;
+}
+
+.buttons {
+  display: flex;
+  height: 60px;
+}
+
+.button {
+  margin-right: 10px;
+  border: none;
+  border-radius: 7px;
+  font-size: 13px;
+  padding: 10px;
+  background: #e3e3e3;
+}
+
+.button:last-child {
+  margin-right: 0;
+}
+
+.nn {
+  display: flex;
+  flex-direction: column;
+}
+
+.nn > * {
+  padding: 0 20px;
+}
+
+.nn-list {
+  overflow-y: auto;
+}
+
+.nn-list .neighbor {
+  font-size: 12px;
+  margin-bottom: 8px;
+}
+
+.nn-list .label-and-value {
+  display: flex;
+  justify-content: space-between;
+}
+
+.label {
+  overflow: hidden;
+  text-overflow: ellipsis;
+  white-space: nowrap;
+}
+
+.nn-list .value {
+  color: #666;
+  float: right;
+  font-weight: 300;
+  margin-left: 8px;
+}
+
+.nn-list .bar {
+  position: relative;
+  border-top: 1px solid rgba(0, 0, 0, 0.15);
+  margin: 2px 0;
+}
+
+.nn-list .bar .fill {
+  position: absolute;
+  top: -1px;
+  border-top: 1px solid white;
+}
+
+.nn-list .tick {
+  position: absolute;
+  top: 0px;
+  height: 3px;
+  border-left: 1px solid rgba(0, 0, 0, 0.15);
+}
+
+.nn-list .neighbor-link:hover {
+  cursor: pointer;
+}
+
+.search-by {
+  display: flex;
+}
+
+.search-by vz-projector-input {
+  width: 100%;
+}
+
+.search-by paper-dropdown-menu {
+  margin-left: 10px;
+  width: 100px;
+}
+
+.distance .options {
+  float: right;
+}
+
+.options a {
+  color: #727272;
+  font-size: 13px;
+  margin-left: 12px;
+  text-decoration: none;
+}
+
+.options a.selected {
+  color: #009EFE;
+}
+
+.neighbors {
+  margin-bottom: 30px;
+}
+
+.neighbors-options {
+  margin-top: 6px;
+}
+
+.neighbors-options .option-label, .distance .option-label {
+  color: #727272;
+  margin-right: 2px;
+  width: auto;
+}
+
+.num-neighbors-container {
+  display: inline-block;
+}
+
+#nn-slider {
+  margin: 0 -12px 0 10px;
+}
+
+.euclidian {
+  margin-right: 10px;
+}
+
+.matches-list {
+  padding: 0 20px;
+}
+
+.matches-list .row {
+  border-bottom: 1px solid #ddd;
+  cursor: pointer;
+  display: flex;
+  font-size: 12px;
+  margin: 5px 0;
+  padding: 4px 0;
+}
+
+.results {
+  display: flex;
+  flex-direction: column;
+}
+</style>
+<template>
+<div class="container">
+  <div class="buttons">
+    <button class="button reset-filter">Show All Data</button>
+    <button class="button set-filter">Isolate selection</button>
+    <button class="button clear-selection">Clear selection</button>
+  </div>
+  <div class="search-by">
+    <vz-projector-input id="search-box" label="Search"></vz-projector-input>
+    <paper-dropdown-menu no-animations label="by">
+      <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{selectedMetadataField}}">
+        <template is="dom-repeat" items="[[metadataFields]]">
+          <paper-item value="[[item]]" label="[[item]]">
+            [[item]]
+          </paper-item>
+        </template>
+      </paper-listbox>
+    </paper-dropdown-menu>
+  </div>
+</div>
+<div class="results">
+  <div class="nn" style="display: none">
+    <div class="neighbors">
+      <div class="neighbors-options">
+        <div class="slider num-nn">
+          <span class="option-label">neighbors</span>
+          <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+          <paper-tooltip position="bottom" animation-delay="0" fit-to-visible-bounds>
+            The number of neighbors (in the original space) to show when clicking on a point.
+          </paper-tooltip>
+          <paper-slider id="nn-slider" pin min="5" max="1000" value="100"></paper-slider>
+          <span class="nn-count"></span>
+        </div>
+      </div>
+      <div class="distance">
+        <span class="option-label">distance</span>
+        <div class="options">
+          <a class="selected cosine" href="javascript:void(0);">COSINE</a>
+          <a class="euclidean" href="javascript:void(0);">EUCLIDIAN</a>
+        </div>
+      </div>
+    </div>
+    <p>Nearest points in the original space:
+    <div class="nn-list"></div>
+  </div>
+  <div class="matches-list" style="display: none">
+    <div class="list"></div>
+    <div class="limit-msg">Showing only the first 100 results...</div>
+  </div>
+</div>
+<!-- Closing global template -->
+</template>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts
new file mode 100644
index 00000000000..3ee2c2165f2
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-inspector-panel.ts
@@ -0,0 +1,337 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {DistanceFunction, SpriteAndMetadataInfo, State} from './data';
+import * as knn from './knn';
+import {ProjectorEventContext} from './projectorEventContext';
+import * as adapter from './projectorScatterPlotAdapter';
+import * as util from './util';
+import * as vector from './vector';
+import {Projector} from './vz-projector';
+import {ProjectorInput} from './vz-projector-input';
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+/** Limit the number of search results we show to the user. */
+const LIMIT_RESULTS = 100;
+
+// tslint:disable-next-line
+export let PolymerClass = PolymerElement({
+  is: 'vz-projector-inspector-panel',
+  properties: {selectedMetadataField: String, metadataFields: Array}
+});
+
+export class InspectorPanel extends PolymerClass {
+  distFunc: DistanceFunction;
+  numNN: number;
+
+  private projectorEventContext: ProjectorEventContext;
+
+  private selectedMetadataField: string;
+  private metadataFields: string[];
+  private projector: Projector;
+  private selectedPointIndices: number[];
+  private neighborsOfFirstPoint: knn.NearestEntry[];
+  private searchBox: ProjectorInput;
+
+  private resetFilterButton: HTMLButtonElement;
+  private setFilterButton: HTMLButtonElement;
+  private clearSelectionButton: HTMLButtonElement;
+  private limitMessage: HTMLDivElement;
+
+  ready() {
+    this.resetFilterButton =
+        this.querySelector('.reset-filter') as HTMLButtonElement;
+    this.setFilterButton =
+        this.querySelector('.set-filter') as HTMLButtonElement;
+    this.clearSelectionButton =
+        this.querySelector('.clear-selection') as HTMLButtonElement;
+    this.limitMessage = this.querySelector('.limit-msg') as HTMLDivElement;
+    this.searchBox = this.querySelector('#search-box') as ProjectorInput;
+    // https://www.polymer-project.org/1.0/docs/devguide/styling#scope-subtree
+    this.scopeSubtree(this, true);
+  }
+
+  initialize(
+      projector: Projector, projectorEventContext: ProjectorEventContext) {
+    this.projector = projector;
+    this.projectorEventContext = projectorEventContext;
+    this.setupUI(projector);
+    projectorEventContext.registerSelectionChangedListener(
+        (selection, neighbors) =>
+            this.updateInspectorPane(selection, neighbors));
+  }
+
+  /** Updates the nearest neighbors list in the inspector. */
+  private updateInspectorPane(
+      indices: number[], neighbors: knn.NearestEntry[]) {
+    this.neighborsOfFirstPoint = neighbors;
+    this.selectedPointIndices = indices;
+
+    this.updateFilterButtons(indices.length + neighbors.length);
+    this.updateNeighborsList(neighbors);
+    if (neighbors.length === 0) {
+      this.updateSearchResults(indices);
+    } else {
+      this.updateSearchResults([]);
+    }
+  }
+
+  private enableResetFilterButton(enabled: boolean) {
+    this.resetFilterButton.disabled = !enabled;
+  }
+
+  restoreUIFromBookmark(bookmark: State) {
+    this.enableResetFilterButton(bookmark.filteredPoints != null);
+  }
+
+  metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) {
+    let labelIndex = -1;
+    this.metadataFields = spriteAndMetadata.stats.map((stats, i) => {
+      if (!stats.isNumeric && labelIndex === -1) {
+        labelIndex = i;
+      }
+      return stats.name;
+    });
+    labelIndex = Math.max(0, labelIndex);
+    // Make the default label the first non-numeric column.
+    this.selectedMetadataField = spriteAndMetadata.stats[labelIndex].name;
+  }
+
+  datasetChanged() {
+    this.enableResetFilterButton(false);
+  }
+
+  private updateSearchResults(indices: number[]) {
+    const container = this.querySelector('.matches-list') as HTMLDivElement;
+    container.style.display = indices.length ? null : 'none';
+    const list = container.querySelector('.list') as HTMLDivElement;
+    list.innerHTML = '';
+    if (indices.length === 0) {
+      return;
+    }
+
+    this.limitMessage.style.display =
+        indices.length <= LIMIT_RESULTS ? 'none' : null;
+    indices = indices.slice(0, LIMIT_RESULTS);
+
+    for (let i = 0; i < indices.length; i++) {
+      const index = indices[i];
+
+      const row = document.createElement('div');
+      row.className = 'row';
+
+      const label = this.getLabelFromIndex(index);
+      const rowLink = document.createElement('a');
+      rowLink.className = 'label';
+      rowLink.title = label;
+      rowLink.innerText = label;
+
+      rowLink.onmouseenter = () => {
+        this.projectorEventContext.notifyHoverOverPoint(index);
+      };
+      rowLink.onmouseleave = () => {
+        this.projectorEventContext.notifyHoverOverPoint(null);
+      };
+      rowLink.onclick = () => {
+        this.projectorEventContext.notifySelectionChanged([index]);
+      };
+
+      row.appendChild(rowLink);
+      list.appendChild(row);
+    }
+  }
+
+  private getLabelFromIndex(pointIndex: number): string {
+    const point = this.projector.dataSet.points[pointIndex];
+    return point.metadata[this.selectedMetadataField].toString();
+  }
+
+  private updateNeighborsList(neighbors: knn.NearestEntry[]) {
+    const nnlist = this.querySelector('.nn-list') as HTMLDivElement;
+    nnlist.innerHTML = '';
+
+    (this.querySelector('.nn') as HTMLDivElement).style.display =
+        neighbors.length ? null : 'none';
+
+    if (neighbors.length === 0) {
+      return;
+    }
+
+    this.searchBox.message = '';
+    const minDist = neighbors.length > 0 ? neighbors[0].dist : 0;
+
+    for (let i = 0; i < neighbors.length; i++) {
+      const neighbor = neighbors[i];
+
+      const neighborElement = document.createElement('div');
+      neighborElement.className = 'neighbor';
+
+      const neighborElementLink = document.createElement('a');
+      neighborElementLink.className = 'neighbor-link';
+      neighborElementLink.title = this.getLabelFromIndex(neighbor.index);
+
+      const labelValueElement = document.createElement('div');
+      labelValueElement.className = 'label-and-value';
+
+      const labelElement = document.createElement('div');
+      labelElement.className = 'label';
+      labelElement.style.color =
+          adapter.dist2color(this.distFunc, neighbor.dist, minDist);
+      labelElement.innerText = this.getLabelFromIndex(neighbor.index);
+
+      const valueElement = document.createElement('div');
+      valueElement.className = 'value';
+      valueElement.innerText = neighbor.dist.toFixed(3);
+
+      labelValueElement.appendChild(labelElement);
+      labelValueElement.appendChild(valueElement);
+
+      const barElement = document.createElement('div');
+      barElement.className = 'bar';
+
+      const barFillElement = document.createElement('div');
+      barFillElement.className = 'fill';
+      barFillElement.style.borderTopColor =
+          adapter.dist2color(this.distFunc, neighbor.dist, minDist);
+      barFillElement.style.width =
+          adapter.normalizeDist(this.distFunc, neighbor.dist, minDist) * 100 +
+          '%';
+      barElement.appendChild(barFillElement);
+
+      for (let j = 1; j < 4; j++) {
+        const tickElement = document.createElement('div');
+        tickElement.className = 'tick';
+        tickElement.style.left = j * 100 / 4 + '%';
+        barElement.appendChild(tickElement);
+      }
+
+      neighborElementLink.appendChild(labelValueElement);
+      neighborElementLink.appendChild(barElement);
+      neighborElement.appendChild(neighborElementLink);
+      nnlist.appendChild(neighborElement);
+
+      neighborElementLink.onmouseenter = () => {
+        this.projectorEventContext.notifyHoverOverPoint(neighbor.index);
+      };
+      neighborElementLink.onmouseleave = () => {
+        this.projectorEventContext.notifyHoverOverPoint(null);
+      };
+      neighborElementLink.onclick = () => {
+        this.projectorEventContext.notifySelectionChanged([neighbor.index]);
+      };
+    }
+  }
+
+  private updateFilterButtons(numPoints: number) {
+    if (numPoints > 1) {
+      this.setFilterButton.innerText = `Isolate ${numPoints} points`;
+      this.setFilterButton.disabled = null;
+      this.clearSelectionButton.disabled = null;
+    } else {
+      this.setFilterButton.disabled = true;
+      this.clearSelectionButton.disabled = true;
+    }
+  }
+
+  private setupUI(projector: Projector) {
+    this.distFunc = vector.cosDist;
+    const eucDist =
+        this.querySelector('.distance a.euclidean') as HTMLLinkElement;
+    eucDist.onclick = () => {
+      const links = this.querySelectorAll('.distance a');
+      for (let i = 0; i < links.length; i++) {
+        util.classed(links[i] as HTMLElement, 'selected', false);
+      }
+      util.classed(eucDist as HTMLElement, 'selected', true);
+
+      this.distFunc = vector.dist;
+      this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc);
+      const neighbors = projector.dataSet.findNeighbors(
+          this.selectedPointIndices[0], this.distFunc, this.numNN);
+      this.updateNeighborsList(neighbors);
+    };
+
+    const cosDist = this.querySelector('.distance a.cosine') as HTMLLinkElement;
+    cosDist.onclick = () => {
+      const links = this.querySelectorAll('.distance a');
+      for (let i = 0; i < links.length; i++) {
+        util.classed(links[i] as HTMLElement, 'selected', false);
+      }
+      util.classed(cosDist, 'selected', true);
+
+      this.distFunc = vector.cosDist;
+      this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc);
+      const neighbors = projector.dataSet.findNeighbors(
+          this.selectedPointIndices[0], this.distFunc, this.numNN);
+      this.updateNeighborsList(neighbors);
+    };
+
+    // Called whenever the search text input changes.
+    const updateInput = (value: string, inRegexMode: boolean) => {
+      if (value == null || value.trim() === '') {
+        this.searchBox.message = '';
+        this.projectorEventContext.notifySelectionChanged([]);
+        return;
+      }
+      const indices = projector.dataSet.query(
+          value, inRegexMode, this.selectedMetadataField);
+      if (indices.length === 0) {
+        this.searchBox.message = '0 matches.';
+      } else {
+        this.searchBox.message = `${indices.length} matches.`;
+      }
+      this.projectorEventContext.notifySelectionChanged(indices);
+    };
+    this.searchBox.registerInputChangedListener((value, inRegexMode) => {
+      updateInput(value, inRegexMode);
+    });
+
+    // Nearest neighbors controls.
+    const numNNInput = this.$$('#nn-slider') as HTMLInputElement;
+    const updateNumNN = () => {
+      this.numNN = +numNNInput.value;
+      (this.querySelector('.num-nn .nn-count') as HTMLSpanElement).innerText =
+          '' + this.numNN;
+      if (this.selectedPointIndices != null) {
+        this.projectorEventContext.notifySelectionChanged(
+            [this.selectedPointIndices[0]]);
+      }
+    };
+    numNNInput.addEventListener('change', updateNumNN);
+    updateNumNN();
+
+    // Filtering dataset.
+    this.setFilterButton.onclick = () => {
+      const indices = this.selectedPointIndices.concat(
+          this.neighborsOfFirstPoint.map(n => n.index));
+      projector.filterDataset(indices);
+      this.enableResetFilterButton(true);
+      this.updateFilterButtons(0);
+    };
+
+    this.resetFilterButton.onclick = () => {
+      projector.resetFilterDataset();
+      this.enableResetFilterButton(false);
+    };
+
+    this.clearSelectionButton.onclick = () => {
+      projector.adjustSelectionAndHover([]);
+    };
+    this.enableResetFilterButton(false);
+  }
+}
+
+document.registerElement(InspectorPanel.prototype.is, InspectorPanel);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html
new file mode 100644
index 00000000000..3fc5f4db158
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.html
@@ -0,0 +1,76 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="styles.html">
+
+<dom-module id='vz-projector-legend'>
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+.item {
+  display: flex;
+  align-items: flex-start;
+  margin-bottom: 10px;
+}
+
+.shape {
+  width: 10px;
+  height: 10px;
+  margin-right: 10px;
+  margin-top: 5px;
+  border-radius: 50%;
+}
+
+.label {
+  flex-grow: 1;
+}
+
+.gradient {
+  width: 100%;
+  height: 10px;
+}
+
+.gradient-boundaries {
+  display: flex;
+  justify-content: space-between;
+}
+</style>
+
+<template is="dom-repeat" items="[[renderInfo.items]]">
+  <div class="item">
+    <div class="shape" style="background-color: [[item.color]];"></div>
+    <div class="label">[[item.label]]</div>
+    <div class="info" style="color: [[item.color]];">[[item.count]]</div>
+  </div>
+</template>
+
+<template is="dom-if" if="[[renderInfo.thresholds]]">
+  <svg class="gradient">
+    <defs>
+      <linearGradient id="gradient" x1="0%" y1="100%" x2="100%" y2="100%"></linearGradient>
+    </defs>
+    <rect height="10" style="fill: url('#gradient');"></rect>
+  </svg>
+  <div class="gradient-boundaries">
+    <div>[[renderInfo.thresholds.0.value]]</div>
+    <div>[[_getLastThreshold(renderInfo.thresholds)]]</div>
+  </div>
+</template>
+<!-- Closing global template -->
+</template>
+</dom-module>
\ No newline at end of file
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts
new file mode 100644
index 00000000000..1c4ddf940dc
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-legend.ts
@@ -0,0 +1,98 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+// tslint:disable-next-line
+export let LegendPolymer = PolymerElement({
+  is: 'vz-projector-legend',
+  properties: {renderInfo: {type: Object, observer: '_renderInfoChanged'}}
+});
+
+export interface ColorLegendRenderInfo {
+  // To be used for categorical map.
+  items: ColorLegendItem[];
+  // To be used for gradient map.
+  thresholds: ColorLegendThreshold[];
+}
+
+/** An item in the categorical color legend. */
+export interface ColorLegendItem {
+  color: string;
+  label: string;
+  count: number;
+}
+
+/** An item in the gradient color legend. */
+export interface ColorLegendThreshold {
+  color: string;
+  value: number;
+}
+
+export class Legend extends LegendPolymer {
+  renderInfo: ColorLegendRenderInfo;
+
+  _renderInfoChanged() {
+    if (this.renderInfo == null) {
+      return;
+    }
+    if (this.renderInfo.thresholds) {
+      // <linearGradient> is under dom-if so we should wait for it to be
+      // inserted in the dom tree using async().
+      this.async(() => this.setupLinearGradient());
+    }
+  }
+
+  _getLastThreshold(): number {
+    if (this.renderInfo == null || this.renderInfo.thresholds == null) {
+      return;
+    }
+    return this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1]
+        .value;
+  }
+
+  private getOffset(value: number): string {
+    const min = this.renderInfo.thresholds[0].value;
+    const max =
+        this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1].value;
+    return (100 * (value - min) / (max - min)).toFixed(2) + '%';
+  }
+
+  private setupLinearGradient() {
+    const linearGradient =
+        this.querySelector('#gradient') as SVGLinearGradientElement;
+
+    const width =
+        (this.querySelector('svg.gradient') as SVGElement).clientWidth;
+
+    // Set the svg <rect> to be the width of its <svg> parent.
+    (this.querySelector('svg.gradient rect') as SVGRectElement).style.width =
+        width + 'px';
+
+    // Remove all <stop> children from before.
+    linearGradient.innerHTML = '';
+
+    // Add a <stop> child in <linearGradient> for each gradient threshold.
+    this.renderInfo.thresholds.forEach(t => {
+      const stopElement =
+          document.createElementNS('http://www.w3.org/2000/svg', 'stop');
+      stopElement.setAttribute('offset', this.getOffset(t.value));
+      stopElement.setAttribute('stop-color', t.color);
+    });
+  }
+}
+
+document.registerElement(Legend.prototype.is, Legend);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html
new file mode 100644
index 00000000000..ebdcd72c77d
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.html
@@ -0,0 +1,97 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../iron-collapse/iron-collapse.html">
+<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+
+<dom-module id="vz-projector-metadata-card">
+<template>
+<style>
+#metadata-card {
+  background-color: rgba(255,255,255,0.9);
+  box-shadow: 0 2px 2px 0 rgba(0, 0, 0, 0.14),
+      0 1px 5px 0 rgba(0, 0, 0, 0.12), 0 3px 1px -2px rgba(0, 0, 0, 0.2);
+  width: 280px;
+}
+
+#header {
+  background: #e9e9e9;
+}
+
+#icon-container {
+  position: absolute;
+  right: 0;
+  top: 4px;
+}
+
+#metadata-label {
+  font-weight: 400;
+  font-size: 14px;
+  line-height: 24px;
+  padding: 12px 12px 8px;
+  width: 230px;
+}
+
+#metadata-table {
+  display: table;
+  padding: 8px 12px 4px;
+}
+
+.metadata-row {
+  display: table-row;
+}
+
+.metadata-key {
+  font-weight: bold;
+}
+
+.metadata-key, .metadata-value {
+  display: table-cell;
+  font-size: 12px;
+  padding: 3px 3px;
+}
+</style>
+
+<template is="dom-if" if="[[hasMetadata]]">
+  <div id="metadata-card">
+    <div id="icon-container">
+      <paper-icon-button id="expand-more"
+          style="display: none"
+          icon="expand-more"
+          on-tap="_expandMore"></paper-icon-button>
+      <paper-icon-button id="expand-less"
+          on-tap="_expandLess"
+          icon="expand-less"></paper-icon-button>
+    </div>
+    <div id="header">
+      <div id="metadata-label">[[label]]</div>
+    </div>
+    <iron-collapse id="metadata-container" opened>
+      <div id="metadata-table">
+        <template is="dom-repeat" items="[[metadata]]">
+          <div class="metadata-row">
+            <div class="metadata-key">[[item.key]]</div>
+            <div class="metadata-value">[[item.value]]</div>
+          </div>
+        </template>
+      </div>
+    </iron-collapse>
+  </div>
+</template>
+</template>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts
new file mode 100644
index 00000000000..939300f3878
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-metadata-card.ts
@@ -0,0 +1,88 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {PointMetadata} from './data';
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+// tslint:disable-next-line
+export let MetadataCardPolymer = PolymerElement({
+  is: 'vz-projector-metadata-card',
+  properties: {
+    hasMetadata: {type: Boolean, value: false},
+    metadata: {type: Array},
+    label: String
+  }
+});
+
+export class MetadataCard extends MetadataCardPolymer {
+  hasMetadata: boolean;
+  metadata: Array<{key: string, value: string}>;
+  label: string;
+
+  private labelOption: string;
+  private pointMetadata: PointMetadata;
+
+  private expandLessButton: HTMLButtonElement;
+  private expandMoreButton: HTMLButtonElement;
+
+  ready() {
+    this.expandLessButton =
+        this.querySelector('#expand-less') as HTMLButtonElement;
+    this.expandMoreButton =
+        this.querySelector('#expand-more') as HTMLButtonElement;
+  }
+  /** Handles a click on the expand more icon. */
+  _expandMore() {
+    (this.$$('#metadata-container') as any).toggle();
+
+    this.expandMoreButton.style.display = 'none';
+    this.expandLessButton.style.display = '';
+  }
+
+  /** Handles a click on the expand less icon. */
+  _expandLess() {
+    (this.$$('#metadata-container') as any).toggle();
+    this.expandMoreButton.style.display = '';
+    this.expandLessButton.style.display = 'none';
+  }
+
+  updateMetadata(pointMetadata?: PointMetadata) {
+    this.pointMetadata = pointMetadata;
+    this.hasMetadata = (pointMetadata != null);
+
+    if (pointMetadata) {
+      let metadata = [];
+      for (let metadataKey in pointMetadata) {
+        if (!pointMetadata.hasOwnProperty(metadataKey)) {
+          continue;
+        }
+        metadata.push({key: metadataKey, value: pointMetadata[metadataKey]});
+      }
+
+      this.metadata = metadata;
+      this.label = '' + this.pointMetadata[this.labelOption];
+    }
+  }
+
+  setLabelOption(labelOption: string) {
+    this.labelOption = labelOption;
+    if (this.pointMetadata) {
+      this.label = '' + this.pointMetadata[this.labelOption];
+    }
+  }
+}
+
+document.registerElement(MetadataCard.prototype.is, MetadataCard);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html
new file mode 100644
index 00000000000..cddcb2b7d08
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.html
@@ -0,0 +1,314 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../iron-collapse/iron-collapse.html">
+<link rel="import" href="../paper-dropdown-menu/paper-dropdown-menu.html">
+<link rel="import" href="../paper-toggle-button/paper-toggle-button.html">
+<link rel="import" href="../paper-listbox/paper-listbox.html">
+<link rel="import" href="../paper-item/paper-item.html">
+<link rel="import" href="../paper-checkbox/paper-checkbox.html">
+<link rel="import" href="../iron-icons/iron-icons.html">
+<link rel="import" href="../iron-icons/image-icons.html">
+<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+<link rel="import" href="../paper-input/paper-input.html">
+<link rel="import" href="../paper-button/paper-button.html">
+<link rel="import" href="../paper-slider/paper-slider.html">
+<link rel="import" href="styles.html">
+
+<dom-module id="vz-projector-projections-panel">
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+:host {
+  transition: height 0.2s;
+}
+
+.ink-button, ::shadow .ink-button {
+  border: none;
+  border-radius: 2px;
+  font-size: 13px;
+  padding: 10px;
+  min-width: 100px;
+  flex-shrink: 0;
+  background: #e3e3e3;
+}
+
+.ink-panel-buttons {
+  margin-bottom: 10px;
+}
+
+.two-way-toggle {
+  display: flex;
+  flex-direction: row;
+}
+
+.two-way-toggle span {
+  padding-right: 7px;
+}
+
+.has-border {
+  border: 1px solid rgba(0, 0, 0, 0.1);
+}
+
+.toggle {
+  min-width: 0px;
+  font-size: 12px;
+  width: 17px;
+  min-height: 0px;
+  height: 21px;
+  padding: 0;
+  margin: 0px;
+}
+
+.toggle[active] {
+  background-color: #880E4F;
+  color: white;
+}
+
+.two-columns {
+  display:flex;
+  justify-content: space-between;
+}
+
+.two-columns > :first-child {
+  margin-right: 15px;
+}
+
+.two-columns > div {
+  width: 50%;
+}
+
+.dropdown-item {
+  justify-content: space-between;
+  min-height: 35px;
+}
+
+#z-container {
+  display: flex;
+  align-items: center;
+  width: 50%;
+}
+
+#z-checkbox {
+  margin: 27px 0 0 5px;
+  width: 18px;
+}
+
+#z-dropdown {
+  flex-grow: 1;
+}
+
+.notice {
+  color: #880E4F;
+}
+
+.container {
+  padding: 20px;
+}
+
+.book-icon {
+  height: 20px;
+  color: rgba(0, 0, 0, 0.7);
+}
+
+.item-details {
+  color: gray;
+  font-size: 12px;
+  margin-left: 5px;
+}
+
+.pca-dropdown {
+  width: 100%;
+}
+
+.pca-dropdown paper-listbox {
+  width: 135px;
+}
+
+.dropdown-item.header {
+  border-bottom: 1px solid #aaa;
+  color: #333;
+  font-weight: bold;
+}
+
+#total-variance {
+  color: rgba(0, 0, 0, 0.7);
+}
+</style>
+<div id="main">
+  <div class="ink-panel-header">
+    <div class="ink-tab-group">
+
+      <div data-tab="tsne" id="tsne-tab" class="ink-tab projection-tab">t-SNE</div>
+      <paper-tooltip for="tsne-tab" position="bottom" animation-delay="0" fit-to-visible-bounds>
+        t-distributed stochastic neighbor embedding
+      </paper-tooltip>
+
+      <div data-tab="pca" id="pca-tab" class="ink-tab projection-tab">PCA</div>
+      <paper-tooltip for="pca-tab" position="bottom" animation-delay="0" fit-to-visible-bounds>
+        Principal component analysis
+      </paper-tooltip>
+
+      <div data-tab="custom" id="custom-tab" class="ink-tab projection-tab" title="Linear projection of two custom vectors">Custom</div>
+      <paper-tooltip for="custom-tab" position="bottom" animation-delay="0" fit-to-visible-bounds>
+        Search for two vectors upon which to project all points.
+      </paper-tooltip>
+
+    </div>
+  </div>
+  <div class="container">
+    <!-- TSNE Controls -->
+    <div data-panel="tsne" class="ink-panel-content">
+      <div class="slider">
+        <label>Dimension</label>
+        <div class="two-way-toggle">
+          <span>2D</span>
+          <paper-toggle-button id="tsne-toggle" checked="{{tSNEis3d}}">3D</paper-toggle-button>
+        </div>
+      </div>
+      <div class="slider tsne-perplexity">
+        <label>
+          Perplexity
+          <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+          <paper-tooltip position="right" animation-delay="0" fit-to-visible-bounds>
+            The most appropriate perplexity value depends on the density of the
+            data. Loosely speaking, a larger / denser dataset
+            requires a larger perplexity. Typical values for perplexity range
+            between 5 and 50.
+          </paper-tooltip>
+        </label>
+        <paper-slider id="perplexity-slider" pin min="2" max="100" value="30"></paper-slider>
+        <span></span>
+      </div>
+      <div class="slider tsne-learning-rate">
+        <label>
+          Learning rate
+          <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+          <paper-tooltip position="right" animation-delay="0" fit-to-visible-bounds>
+            The ideal learning rate often depends on the size of the data,
+            with smaller datasets requiring smaller learning rates.
+          </paper-tooltip>
+        </label>
+        <paper-slider id="learning-rate-slider" snaps min="-3" max="2" step="1"
+            value="1" max-markers="6">
+        </paper-slider>
+        <span></span>
+      </div>
+      <p>
+        <button class="run-tsne ink-button" title="Re-run t-SNE">Re-run</button>
+        <button class="stop-tsne ink-button" title="Stop t-SNE">Stop</button>
+      </p>
+      <p>Iteration: <span class="run-tsne-iter">0</span></p>
+      <p id="tsne-sampling" class="notice">
+        For fast results, the data will be sampled down to [[getTsneSampleSizeText()]] points.
+      </p>
+      <p>
+        <iron-icon icon="book" class="book-icon"></iron-icon>
+        <a target="_blank" href="http://distill.pub/2016/misread-tsne/">
+          How to use t-SNE effectively.
+        </a>
+      </p>
+    </div>
+    <!-- PCA Controls -->
+    <div data-panel="pca" class="ink-panel-content">
+      <div class="two-columns">
+        <div> <!-- Left column -->
+          <paper-dropdown-menu class="pca-dropdown" vertical-align="bottom" no-animations label="X">
+            <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{pcaX}}">
+              <paper-item disabled class="dropdown-item header">
+                  <div>#</div>
+                  <div>Variance (%)</div>
+              </paper-item>
+              <template is="dom-repeat" items="[[pcaComponents]]">
+                <paper-item class="dropdown-item" value="[[item.id]]"
+                            label="Component #[[item.componentNumber]]">
+                  <div>[[item.componentNumber]]</div>
+                  <div class="item-details">[[item.percVariance]]</div>
+                </paper-item>
+              </template>
+            </paper-listbox>
+          </paper-dropdown-menu>
+          <paper-dropdown-menu class="pca-dropdown" no-animations vertical-align="bottom" label="Z" disabled="[[!hasPcaZ]]" id="z-dropdown">
+            <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{pcaZ}}">
+              <paper-item disabled class="dropdown-item header">
+                  <div>#</div>
+                  <div>Variance (%)</div>
+              </paper-item>
+              <template is="dom-repeat" items="[[pcaComponents]]">
+                <paper-item class="dropdown-item" value="[[item.id]]"
+                            label="Component #[[item.componentNumber]]">
+                  <div>[[item.componentNumber]]</div>
+                  <div class="item-details">[[item.percVariance]]</div>
+                </paper-item>
+              </template>
+            </paper-listbox>
+          </paper-dropdown-menu>
+        </div>
+        <div> <!-- Right column -->
+          <paper-dropdown-menu class="pca-dropdown" vertical-align="bottom" no-animations label="Y">
+            <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{pcaY}}">
+              <paper-item disabled class="dropdown-item header">
+                  <div>#</div>
+                  <div>Variance (%)</div>
+              </paper-item>
+              <template is="dom-repeat" items="[[pcaComponents]]">
+                <paper-item class="dropdown-item" value="[[item.id]]"
+                            label="Component #[[item.componentNumber]]">
+                  <div>[[item.componentNumber]]</div>
+                  <div class="item-details">[[item.percVariance]]</div>
+                </paper-item>
+              </template>
+            </paper-listbox>
+          </paper-dropdown-menu>
+          <paper-checkbox id="z-checkbox" checked="{{pcaIs3d}}"></paper-checkbox>
+        </div>
+      </div>
+      <p id="pca-sampling" class="notice">
+        PCA is approximate.
+        <paper-icon-button icon="help" class="help-icon"></paper-icon-button>
+      </p>
+      <div id="total-variance">Total variance</div>
+      <paper-tooltip for="pca-sampling" position="top" animation-delay="0" fit-to-visible-bounds>
+        For fast results, the data was sampled to [[getPcaSampleSizeText()]] points and randomly projected down to [[getPcaSampledDimText()]] dimensions.
+      </paper-tooltip>
+    </div>
+    <!-- Custom Controls -->
+    <div data-panel="custom" class="ink-panel-content">
+      <paper-dropdown-menu style="width: 100%" no-animations label="Search by">
+        <paper-listbox attr-for-selected="value" class="dropdown-content" selected="{{customSelectedSearchByMetadataOption}}">
+          <template is="dom-repeat" items="[[searchByMetadataOptions]]">
+            <paper-item class="dropdown-item" value="[[item]]" label="[[item]]">
+              [[item]]
+            </paper-item>
+          </template>
+        </paper-listbox>
+      </paper-dropdown-menu>
+      <div class="two-columns">
+        <vz-projector-input id="xLeft" label="Left"></vz-projector-input>
+        <vz-projector-input id="xRight" label="Right"></vz-projector-input>
+      </div>
+      <div class="two-columns">
+        <vz-projector-input id="yUp" label="Up"></vz-projector-input>
+        <vz-projector-input id="yDown" label="Down"></vz-projector-input>
+      </div>
+    </div>
+  </div>
+</div>
+</template>
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts
new file mode 100644
index 00000000000..377c6c11ad5
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel.ts
@@ -0,0 +1,589 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import * as data from './data';
+import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data';
+import * as util from './util';
+import * as vector from './vector';
+import {Vector} from './vector';
+import {Projector} from './vz-projector';
+import {ProjectorInput} from './vz-projector-input';
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+const NUM_PCA_COMPONENTS = 10;
+
+// tslint:disable-next-line
+export let ProjectionsPanelPolymer = PolymerElement({
+  is: 'vz-projector-projections-panel',
+  properties: {
+    pcaIs3d:
+        {type: Boolean, value: true, observer: '_pcaDimensionToggleObserver'},
+    tSNEis3d:
+        {type: Boolean, value: true, observer: '_tsneDimensionToggleObserver'},
+    // PCA projection.
+    pcaComponents: Array,
+    pcaX: {type: Number, value: 0, observer: 'showPCAIfEnabled'},
+    pcaY: {type: Number, value: 1, observer: 'showPCAIfEnabled'},
+    pcaZ: {type: Number, value: 2, observer: 'showPCAIfEnabled'},
+    // Custom projection.
+    customSelectedSearchByMetadataOption: {
+      type: String,
+      observer: '_customSelectedSearchByMetadataOptionChanged'
+    },
+  }
+});
+
+type InputControlName = 'xLeft'|'xRight'|'yUp'|'yDown';
+
+type CentroidResult = {
+  centroid?: Vector; numMatches?: number;
+};
+
+type Centroids = {
+  [key: string]: Vector; xLeft: Vector; xRight: Vector; yUp: Vector;
+  yDown: Vector;
+};
+
+/**
+ * A polymer component which handles the projection tabs in the projector.
+ */
+export class ProjectionsPanel extends ProjectionsPanelPolymer {
+  private projector: Projector;
+  private pcaComponents:
+      Array<{id: number, componentNumber: number, percVariance: string}>;
+  private currentProjection: ProjectionType;
+  private polymerChangesTriggerReprojection: boolean;
+  private dataSet: DataSet;
+  private originalDataSet: DataSet;
+  private dim: number;
+
+  /** T-SNE perplexity. Roughly how many neighbors each point influences. */
+  private perplexity: number;
+  /** T-SNE learning rate. */
+  private learningRate: number;
+
+  private searchByMetadataOptions: string[];
+
+  /** Centroids for custom projections. */
+  private centroidValues: any;
+  private centroids: Centroids;
+  /** The centroid across all points. */
+  private allCentroid: number[];
+
+  /** Polymer properties. */
+  // TODO(nsthorat): Move these to a separate view controller.
+  public tSNEis3d: boolean;
+  public pcaIs3d: boolean;
+  public pcaX: number;
+  public pcaY: number;
+  public pcaZ: number;
+  public customSelectedSearchByMetadataOption: string;
+
+  /** Polymer elements. */
+  private runTsneButton: HTMLButtonElement;
+  private stopTsneButton: HTMLButtonElement;
+  private perplexitySlider: HTMLInputElement;
+  private learningRateInput: HTMLInputElement;
+  private zDropdown: HTMLElement;
+  private iterationLabel: HTMLElement;
+
+  private customProjectionXLeftInput: ProjectorInput;
+  private customProjectionXRightInput: ProjectorInput;
+  private customProjectionYUpInput: ProjectorInput;
+  private customProjectionYDownInput: ProjectorInput;
+
+  initialize(projector: Projector) {
+    this.polymerChangesTriggerReprojection = true;
+    this.projector = projector;
+
+    // Set up TSNE projections.
+    this.perplexity = 30;
+    this.learningRate = 10;
+
+    // Setup Custom projections.
+    this.centroidValues = {xLeft: null, xRight: null, yUp: null, yDown: null};
+    this.clearCentroids();
+
+    this.setupUIControls();
+  }
+
+  ready() {
+    this.zDropdown = this.querySelector('#z-dropdown') as HTMLElement;
+    this.runTsneButton = this.querySelector('.run-tsne') as HTMLButtonElement;
+    this.stopTsneButton = this.querySelector('.stop-tsne') as HTMLButtonElement;
+    this.perplexitySlider =
+        this.querySelector('#perplexity-slider') as HTMLInputElement;
+    this.learningRateInput =
+        this.querySelector('#learning-rate-slider') as HTMLInputElement;
+    this.iterationLabel = this.querySelector('.run-tsne-iter') as HTMLElement;
+  }
+
+  disablePolymerChangesTriggerReprojection() {
+    this.polymerChangesTriggerReprojection = false;
+  }
+
+  enablePolymerChangesTriggerReprojection() {
+    this.polymerChangesTriggerReprojection = true;
+  }
+
+  private updateTSNEPerplexityFromSliderChange() {
+    if (this.perplexitySlider) {
+      this.perplexity = +this.perplexitySlider.value;
+    }
+    (this.querySelector('.tsne-perplexity span') as HTMLSpanElement).innerText =
+        '' + this.perplexity;
+  }
+
+  private updateTSNELearningRateFromUIChange() {
+    if (this.learningRateInput) {
+      this.learningRate = Math.pow(10, +this.learningRateInput.value);
+    }
+    (this.querySelector('.tsne-learning-rate span') as HTMLSpanElement)
+        .innerText = '' + this.learningRate;
+  }
+
+  private setupUIControls() {
+    {
+      const self = this;
+      const inkTabs = this.querySelectorAll('.ink-tab');
+      for (let i = 0; i < inkTabs.length; i++) {
+        inkTabs[i].addEventListener('click', function() {
+          let id = this.getAttribute('data-tab');
+          self.showTab(id);
+        });
+      }
+    }
+
+    this.runTsneButton.addEventListener('click', () => this.runTSNE());
+    this.stopTsneButton.addEventListener(
+        'click', () => this.dataSet.stopTSNE());
+
+    this.perplexitySlider.value = this.perplexity.toString();
+    this.perplexitySlider.addEventListener(
+        'change', () => this.updateTSNEPerplexityFromSliderChange());
+    this.updateTSNEPerplexityFromSliderChange();
+
+    this.learningRateInput.addEventListener(
+        'change', () => this.updateTSNELearningRateFromUIChange());
+    this.updateTSNELearningRateFromUIChange();
+
+    this.setupCustomProjectionInputFields();
+    // TODO: figure out why `--paper-input-container-input` css mixin didn't
+    // work.
+    const inputs =
+        this.querySelectorAll('paper-dropdown-menu paper-input input');
+    for (let i = 0; i < inputs.length; i++) {
+      (inputs[i] as HTMLElement).style.fontSize = '14px';
+    }
+  }
+
+  restoreUIFromBookmark(bookmark: State) {
+    this.disablePolymerChangesTriggerReprojection();
+
+    // PCA
+    this.pcaX = bookmark.pcaComponentDimensions[0];
+    this.pcaY = bookmark.pcaComponentDimensions[1];
+    if (bookmark.pcaComponentDimensions.length === 3) {
+      this.pcaZ = bookmark.pcaComponentDimensions[2];
+    }
+    this.pcaIs3d = (bookmark.pcaComponentDimensions.length === 3);
+
+    // t-SNE
+    if (this.perplexitySlider) {
+      this.perplexitySlider.value = bookmark.tSNEPerplexity.toString();
+    }
+    if (this.learningRateInput) {
+      this.learningRateInput.value = bookmark.tSNELearningRate.toString();
+    }
+    this.tSNEis3d = bookmark.tSNEis3d;
+
+    // custom
+    this.customSelectedSearchByMetadataOption =
+        bookmark.customSelectedSearchByMetadataOption;
+    if (this.customProjectionXLeftInput) {
+      this.customProjectionXLeftInput.set(
+          bookmark.customXLeftText, bookmark.customXLeftRegex);
+    }
+    if (this.customProjectionXRightInput) {
+      this.customProjectionXRightInput.set(
+          bookmark.customXRightText, bookmark.customXRightRegex);
+    }
+    if (this.customProjectionYUpInput) {
+      this.customProjectionYUpInput.set(
+          bookmark.customYUpText, bookmark.customYUpRegex);
+    }
+    if (this.customProjectionYDownInput) {
+      this.customProjectionYDownInput.set(
+          bookmark.customYDownText, bookmark.customYDownRegex);
+    }
+    this.computeAllCentroids();
+
+    this.setZDropdownEnabled(this.pcaIs3d);
+    this.updateTSNEPerplexityFromSliderChange();
+    this.updateTSNELearningRateFromUIChange();
+    if (this.iterationLabel) {
+      this.iterationLabel.innerText = bookmark.tSNEIteration.toString();
+    }
+    if (bookmark.selectedProjection != null) {
+      this.showTab(bookmark.selectedProjection);
+    }
+    this.enablePolymerChangesTriggerReprojection();
+  }
+
+  populateBookmarkFromUI(bookmark: State) {
+    this.disablePolymerChangesTriggerReprojection();
+
+    // PCA
+    bookmark.pcaComponentDimensions = [this.pcaX, this.pcaY];
+    if (this.pcaIs3d) {
+      bookmark.pcaComponentDimensions.push(this.pcaZ);
+    }
+
+    // t-SNE
+    if (this.perplexitySlider != null) {
+      bookmark.tSNEPerplexity = +this.perplexitySlider.value;
+    }
+    if (this.learningRateInput != null) {
+      bookmark.tSNELearningRate = +this.learningRateInput.value;
+    }
+    bookmark.tSNEis3d = this.tSNEis3d;
+
+    // custom
+    bookmark.customSelectedSearchByMetadataOption =
+        this.customSelectedSearchByMetadataOption;
+    if (this.customProjectionXLeftInput != null) {
+      bookmark.customXLeftText = this.customProjectionXLeftInput.getValue();
+      bookmark.customXLeftRegex =
+          this.customProjectionXLeftInput.getInRegexMode();
+    }
+    if (this.customProjectionXRightInput != null) {
+      bookmark.customXRightText = this.customProjectionXRightInput.getValue();
+      bookmark.customXRightRegex =
+          this.customProjectionXRightInput.getInRegexMode();
+    }
+    if (this.customProjectionYUpInput != null) {
+      bookmark.customYUpText = this.customProjectionYUpInput.getValue();
+      bookmark.customYUpRegex = this.customProjectionYUpInput.getInRegexMode();
+    }
+    if (this.customProjectionYDownInput != null) {
+      bookmark.customYDownText = this.customProjectionYDownInput.getValue();
+      bookmark.customYDownRegex =
+          this.customProjectionYDownInput.getInRegexMode();
+    }
+
+    this.enablePolymerChangesTriggerReprojection();
+  }
+
+  // This method is marked as public as it is used as the view method that
+  // abstracts DOM manipulation so we can stub it in a test.
+  // TODO(nsthorat): Move this to its own class as the glue between this class
+  // and the DOM.
+  setZDropdownEnabled(enabled: boolean) {
+    if (this.zDropdown) {
+      if (enabled) {
+        this.zDropdown.removeAttribute('disabled');
+      } else {
+        this.zDropdown.setAttribute('disabled', 'true');
+      }
+    }
+  }
+
+  dataSetUpdated(dataSet: DataSet, originalDataSet: DataSet, dim: number) {
+    this.dataSet = dataSet;
+    this.originalDataSet = originalDataSet;
+    this.dim = dim;
+    const pointCount = (dataSet == null) ? 0 : dataSet.points.length;
+    const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4));
+    this.perplexitySlider.value = perplexity.toString();
+    this.updateTSNEPerplexityFromSliderChange();
+    this.clearCentroids();
+
+    (this.querySelector('#tsne-sampling') as HTMLElement).style.display =
+        pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none';
+    const wasSampled =
+        (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM ||
+                                     dataSet.dim[1] > data.PCA_SAMPLE_DIM);
+    (this.querySelector('#pca-sampling') as HTMLElement).style.display =
+        wasSampled ? null : 'none';
+    this.showTab('pca');
+  }
+
+  _pcaDimensionToggleObserver() {
+    this.setZDropdownEnabled(this.pcaIs3d);
+    this.beginProjection(this.currentProjection);
+  }
+
+  _tsneDimensionToggleObserver() {
+    this.beginProjection(this.currentProjection);
+  }
+
+  metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) {
+    // Project by options for custom projections.
+    let searchByMetadataIndex = -1;
+    this.searchByMetadataOptions = spriteAndMetadata.stats.map((stats, i) => {
+      // Make the default label by the first non-numeric column.
+      if (!stats.isNumeric && searchByMetadataIndex === -1) {
+        searchByMetadataIndex = i;
+      }
+      return stats.name;
+    });
+    this.customSelectedSearchByMetadataOption =
+        this.searchByMetadataOptions[Math.max(0, searchByMetadataIndex)];
+  }
+
+  public showTab(id: ProjectionType) {
+    this.currentProjection = id;
+
+    const tab =
+        this.querySelector('.ink-tab[data-tab="' + id + '"]') as HTMLElement;
+    const allTabs = this.querySelectorAll('.ink-tab');
+    for (let i = 0; i < allTabs.length; i++) {
+      util.classed(allTabs[i] as HTMLElement, 'active', false);
+    }
+
+    util.classed(tab, 'active', true);
+
+    const allTabContent = this.querySelectorAll('.ink-panel-content');
+    for (let i = 0; i < allTabContent.length; i++) {
+      util.classed(allTabContent[i] as HTMLElement, 'active', false);
+    }
+
+    util.classed(
+        this.querySelector('.ink-panel-content[data-panel="' + id + '"]') as
+            HTMLElement,
+        'active', true);
+
+    // guard for unit tests, where polymer isn't attached and $ doesn't exist.
+    if (this.$ != null) {
+      const main = this.$['main'];
+      // In order for the projections panel to animate its height, we need to
+      // set it explicitly.
+      requestAnimationFrame(() => {
+        this.style.height = main.clientHeight + 'px';
+      });
+    }
+
+    this.beginProjection(id);
+  }
+
+  private beginProjection(projection: ProjectionType) {
+    if (this.polymerChangesTriggerReprojection === false) {
+      return;
+    }
+    if (projection === 'pca') {
+      if (this.dataSet != null) {
+        this.dataSet.stopTSNE();
+      }
+      this.showPCA();
+    } else if (projection === 'tsne') {
+      this.showTSNE();
+    } else if (projection === 'custom') {
+      if (this.dataSet != null) {
+        this.dataSet.stopTSNE();
+      }
+      this.computeAllCentroids();
+      this.reprojectCustom();
+    }
+  }
+
+  private showTSNE() {
+    const dataSet = this.dataSet;
+    if (dataSet == null) {
+      return;
+    }
+    const accessors =
+        data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 2 : null]);
+    const dimensionality = this.tSNEis3d ? 3 : 2;
+    const projection =
+        new Projection('tsne', accessors, dimensionality, dataSet);
+    this.projector.setProjection(projection);
+
+    if (!this.dataSet.hasTSNERun) {
+      this.runTSNE();
+    } else {
+      this.projector.notifyProjectionPositionsUpdated();
+    }
+  }
+
+  private runTSNE() {
+    this.runTsneButton.disabled = true;
+    this.stopTsneButton.disabled = null;
+    this.dataSet.projectTSNE(
+        this.perplexity, this.learningRate, this.tSNEis3d ? 3 : 2,
+        (iteration: number) => {
+          if (iteration != null) {
+            this.iterationLabel.innerText = '' + iteration;
+            this.projector.notifyProjectionPositionsUpdated();
+          } else {
+            this.runTsneButton.disabled = null;
+            this.stopTsneButton.disabled = true;
+          }
+        });
+  }
+
+  // tslint:disable-next-line:no-unused-variable
+  private showPCAIfEnabled() {
+    if (this.polymerChangesTriggerReprojection) {
+      this.showPCA();
+    }
+  }
+
+  private updateTotalVarianceMessage() {
+    let variances = this.dataSet.fracVariancesExplained;
+    let totalVariance = variances[this.pcaX] + variances[this.pcaY];
+    let msg = 'Total variance described: ';
+    if (this.pcaIs3d) {
+      totalVariance += variances[this.pcaZ];
+    }
+    msg += (totalVariance * 100).toFixed(1) + '%.';
+    (this.querySelector('#total-variance') as HTMLElement).innerHTML = msg;
+  }
+
+  private showPCA() {
+    if (this.dataSet == null) {
+      return;
+    }
+    this.dataSet.projectPCA().then(() => {
+      // Polymer properties are 1-based.
+      const accessors = data.getProjectionComponents(
+          'pca', [this.pcaX, this.pcaY, this.pcaZ]);
+
+      const dimensionality = this.pcaIs3d ? 3 : 2;
+      const projection =
+          new Projection('pca', accessors, dimensionality, this.dataSet);
+      this.projector.setProjection(projection);
+      let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]);
+      this.updateTotalVarianceMessage();
+      this.pcaComponents = util.range(numComponents).map(i => {
+        let fracVariance = this.dataSet.fracVariancesExplained[i];
+        return {
+          id: i,
+          componentNumber: i + 1,
+          percVariance: (fracVariance * 100).toFixed(1)
+        };
+      });
+    });
+  }
+
+  private reprojectCustom() {
+    if (this.centroids == null || this.centroids.xLeft == null ||
+        this.centroids.xRight == null || this.centroids.yUp == null ||
+        this.centroids.yDown == null) {
+      return;
+    }
+    const xDir = vector.sub(this.centroids.xRight, this.centroids.xLeft);
+    this.dataSet.projectLinear(xDir, 'linear-x');
+
+    const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown);
+    this.dataSet.projectLinear(yDir, 'linear-y');
+
+    const accessors = data.getProjectionComponents('custom', ['x', 'y']);
+    const projection = new Projection('custom', accessors, 2, this.dataSet);
+    this.projector.setProjection(projection);
+  }
+
+  clearCentroids(): void {
+    this.centroids = {xLeft: null, xRight: null, yUp: null, yDown: null};
+    this.allCentroid = null;
+  }
+
+  _customSelectedSearchByMetadataOptionChanged(newVal: string, oldVal: string) {
+    if (this.polymerChangesTriggerReprojection === false) {
+      return;
+    }
+    if (this.currentProjection === 'custom') {
+      this.computeAllCentroids();
+      this.reprojectCustom();
+    }
+  }
+
+  private setupCustomProjectionInputFields() {
+    this.customProjectionXLeftInput =
+        this.setupCustomProjectionInputField('xLeft');
+    this.customProjectionXRightInput =
+        this.setupCustomProjectionInputField('xRight');
+    this.customProjectionYUpInput = this.setupCustomProjectionInputField('yUp');
+    this.customProjectionYDownInput =
+        this.setupCustomProjectionInputField('yDown');
+  }
+
+  private computeAllCentroids() {
+    this.computeCentroid('xLeft');
+    this.computeCentroid('xRight');
+    this.computeCentroid('yUp');
+    this.computeCentroid('yDown');
+  }
+
+  private computeCentroid(name: InputControlName) {
+    const input = this.querySelector('#' + name) as ProjectorInput;
+    if (input == null) {
+      return;
+    }
+    const value = input.getValue();
+    if (value == null) {
+      return;
+    }
+    let inRegexMode = input.getInRegexMode();
+    let result = this.getCentroid(value, inRegexMode);
+    if (result.numMatches === 0) {
+      input.message = '0 matches. Using a random vector.';
+      result.centroid = vector.rn(this.dim);
+    } else {
+      input.message = `${result.numMatches} matches.`;
+    }
+    this.centroids[name] = result.centroid;
+    this.centroidValues[name] = value;
+  }
+
+  private setupCustomProjectionInputField(name: InputControlName):
+      ProjectorInput {
+    let input = this.querySelector('#' + name) as ProjectorInput;
+    input.registerInputChangedListener((input, inRegexMode) => {
+      if (this.polymerChangesTriggerReprojection) {
+        this.computeCentroid(name);
+        this.reprojectCustom();
+      }
+    });
+    return input;
+  }
+
+  private getCentroid(pattern: string, inRegexMode: boolean): CentroidResult {
+    if (pattern == null || pattern === '') {
+      return {numMatches: 0};
+    }
+    // Search by the original dataset since we often want to filter and project
+    // only the nearest neighbors of A onto B-C where B and C are not nearest
+    // neighbors of A.
+    let accessor = (i: number) => this.originalDataSet.points[i].vector;
+    let r = this.originalDataSet.query(
+        pattern, inRegexMode, this.customSelectedSearchByMetadataOption);
+    return {centroid: vector.centroid(r, accessor), numMatches: r.length};
+  }
+
+  getPcaSampledDimText() {
+    return data.PCA_SAMPLE_DIM.toLocaleString();
+  }
+
+  getPcaSampleSizeText() {
+    return data.PCA_SAMPLE_SIZE.toLocaleString();
+  }
+
+  getTsneSampleSizeText() {
+    return data.TSNE_SAMPLE_SIZE.toLocaleString();
+  }
+}
+
+document.registerElement(ProjectionsPanel.prototype.is, ProjectionsPanel);
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts
new file mode 100644
index 00000000000..fd1acf6f085
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-projections-panel_test.ts
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+import {State} from './data';
+import {ProjectionsPanel} from './vz-projector-projections-panel';
+
+const assert = chai.assert;
+
+describe('restoreUIFromBookmark', () => {
+  let projectionsPanel: ProjectionsPanel;
+  beforeEach(() => {
+    projectionsPanel = document.createElement(ProjectionsPanel.prototype.is) as
+        ProjectionsPanel;
+
+    // Set up some of the UI so the elements are found in the production code.
+    const tsnePerplexityContainer = document.createElement('div');
+    tsnePerplexityContainer.className = 'tsne-perplexity';
+    const tsnePerplexity = document.createElement('span');
+    tsnePerplexityContainer.appendChild(tsnePerplexity);
+    projectionsPanel.appendChild(tsnePerplexityContainer);
+
+    const tsneLearningRateContainer = document.createElement('div');
+    tsneLearningRateContainer.className = 'tsne-learning-rate';
+    const tsneLearningRate = document.createElement('span');
+    tsneLearningRateContainer.appendChild(tsneLearningRate);
+    projectionsPanel.appendChild(tsneLearningRateContainer);
+  });
+
+  it('sets the pcaX/Y properties when setting 2D component values', () => {
+    spyOn(projectionsPanel, 'setZDropdownEnabled');
+
+    const s = new State();
+    s.pcaComponentDimensions = [0, 1];
+    projectionsPanel.restoreUIFromBookmark(s);
+
+    assert.equal(0, projectionsPanel.pcaX);
+    assert.equal(1, projectionsPanel.pcaY);
+
+    expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(false);
+  });
+
+  it('sets the pcaX/Y properties when setting 3D component values', () => {
+    spyOn(projectionsPanel, 'setZDropdownEnabled');
+
+    const s = new State();
+    s.pcaComponentDimensions = [0, 1, 2];
+    projectionsPanel.restoreUIFromBookmark(s);
+
+    assert.equal(0, projectionsPanel.pcaX);
+    assert.equal(1, projectionsPanel.pcaY);
+    assert.equal(2, projectionsPanel.pcaZ);
+
+    expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(true);
+  });
+});
+
+describe('populateBookmarkFromUI', () => {
+  let projectionsPanel: ProjectionsPanel;
+
+  beforeEach(() => {
+    projectionsPanel = document.createElement(ProjectionsPanel.prototype.is) as
+        ProjectionsPanel;
+
+    // Set up some of the UI so the elements are found in the production code.
+    const tsnePerplexityContainer = document.createElement('div');
+    tsnePerplexityContainer.className = 'tsne-perplexity';
+    const tsnePerplexity = document.createElement('span');
+    tsnePerplexityContainer.appendChild(tsnePerplexity);
+    projectionsPanel.appendChild(tsnePerplexityContainer);
+
+    const tsneLearningRateContainer = document.createElement('div');
+    tsneLearningRateContainer.className = 'tsne-learning-rate';
+    const tsneLearningRate = document.createElement('span');
+    tsneLearningRateContainer.appendChild(tsneLearningRate);
+    projectionsPanel.appendChild(tsneLearningRateContainer);
+  });
+
+  it('gets the PCA component UI values from a 2D PCA projection', () => {
+    projectionsPanel.pcaX = 0;
+    projectionsPanel.pcaY = 1;
+    projectionsPanel.pcaIs3d = false;
+
+    const s = new State();
+    projectionsPanel.populateBookmarkFromUI(s);
+    assert.deepEqual([0, 1], s.pcaComponentDimensions);
+  });
+
+  it('gets the PCA component UI values from a 3D PCA projection', () => {
+    projectionsPanel.pcaX = 0;
+    projectionsPanel.pcaY = 1;
+    projectionsPanel.pcaZ = 2;
+    projectionsPanel.pcaIs3d = true;
+
+    const s = new State();
+    projectionsPanel.populateBookmarkFromUI(s);
+    assert.deepEqual([0, 1, 2], s.pcaComponentDimensions);
+  });
+});
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts
new file mode 100644
index 00000000000..44062062a36
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector-util.ts
@@ -0,0 +1,34 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+export type Spec = {
+  is: string; properties?: {
+    [key: string]:
+        (Function |
+         {
+           type: Function, value?: any;
+           readonly?: boolean;
+           notify?: boolean;
+           observer?: string;
+         })
+  };
+  observers?: string[];
+};
+
+export function PolymerElement(spec: Spec) {
+  return Polymer.Class(spec as any) as{new (): PolymerHTMLElement};
+}
+
+export interface PolymerHTMLElement extends HTMLElement, polymer.Base {}
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html
new file mode 100644
index 00000000000..d4be2f26a5d
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.html
@@ -0,0 +1,343 @@
+<!--
+@license
+Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+
+<link rel="import" href="../polymer/polymer.html">
+<link rel="import" href="../iron-collapse/iron-collapse.html">
+<link rel="import" href="../paper-toggle-button/paper-toggle-button.html">
+<link rel="import" href="../paper-listbox/paper-listbox.html">
+<link rel="import" href="../paper-item/paper-item.html">
+<link rel="import" href="../paper-checkbox/paper-checkbox.html">
+<link rel="import" href="../iron-icons/iron-icons.html">
+<link rel="import" href="../iron-icons/image-icons.html">
+<link rel="import" href="../paper-icon-button/paper-icon-button.html">
+<link rel="import" href="../paper-tooltip/paper-tooltip.html">
+<link rel="import" href="../paper-input/paper-input.html">
+<link rel="import" href="../paper-button/paper-button.html">
+<link rel="import" href="../paper-dialog/paper-dialog.html">
+<link rel="import" href="../paper-toast/paper-toast.html">
+<link rel="import" href="../paper-styles/typography.html">
+<link rel="import" href="../paper-spinner/paper-spinner-lite.html">
+<link rel="import" href="../paper-dialog-scrollable/paper-dialog-scrollable.html">
+
+<link rel="import" href="vz-projector-bookmark-panel.html">
+<link rel="import" href="vz-projector-data-panel.html">
+<link rel="import" href="vz-projector-inspector-panel.html">
+<link rel="import" href="vz-projector-input.html">
+<link rel="import" href="vz-projector-metadata-card.html">
+<link rel="import" href="vz-projector-projections-panel.html">
+<link rel="import" href="styles.html">
+
+<dom-module id="vz-projector">
+<template>
+<style include="vz-projector-styles"></style>
+<style>
+:host {
+  display: flex;
+  width: 100%;
+  height: 100%;
+}
+
+#container {
+  display: flex;
+  width: 100%;
+  height: 100%;
+  overflow: hidden;
+}
+
+.hidden {
+  display: none !important;
+}
+
+/* Main */
+
+#main {
+  position: relative;
+  flex-grow: 2;
+}
+
+#main .stage {
+  position: relative;
+  flex-grow: 2;
+}
+
+#scatter {
+  position: absolute;
+  top: 0;
+  left: 0;
+  right: 0;
+  bottom: 0;
+}
+
+#selector {
+  display: none;
+  height: 100%;
+  position: absolute;
+  width: 100%;
+}
+
+#left-pane {
+  display: flex;
+  flex-direction: column;
+  justify-content: space-between;
+  min-width: 312px;
+  width: 312px;
+  border-right: 1px solid rgba(0, 0, 0, 0.1);
+  background: #fafafa;
+}
+
+#right-pane {
+  border-left: 1px solid rgba(0, 0, 0, 0.1);
+  background: #fafafa;
+  display: flex;
+  height: 100%;
+  min-width: 300px;
+  width: 300px;
+}
+
+.file-name {
+  margin-right: 5px;
+}
+
+.control input[type=text]:focus {
+  outline: none;
+  border-bottom: 1px solid rgba(0, 0, 0, 1);
+}
+
+.control {
+  display: inline-block;
+  width: 45%;
+  vertical-align: top;
+  margin-right: 10px;
+  overflow-x: hidden;
+}
+
+.control.last {
+  margin-right: 0;
+}
+
+#notification-dialog {
+  width: 400px;
+  padding-bottom: 20px;
+}
+
+#notification-dialog paper-button {
+  background: none;
+  text-transform: uppercase;
+}
+
+#notification-dialog .progress {
+  --paper-spinner-color: #880E4F;
+  --paper-spinner-stroke-width: 2px;
+}
+
+#notify-msgs {
+  text-align: center;
+  display: block;
+}
+
+.notify-msg {
+  font-weight: 500;
+  margin: 0;
+  padding: 0;
+}
+
+.notify-msg.error {
+  text-align: left;
+}
+
+.brush .extent {
+  stroke: #fff;
+  fill-opacity: .125;
+  shape-rendering: crispEdges;
+}
+
+.origin text {
+  font-size: 12px;
+  font-weight: 500;
+}
+
+.origin line {
+  stroke: black;
+  stroke-opacity: 0.2;
+}
+
+/* Ink Framework */
+
+/* - Buttons */
+.ink-button, ::shadow .ink-button {
+  border: none;
+  border-radius: 2px;
+  font-size: 13px;
+  padding: 10px;
+  min-width: 100px;
+  flex-shrink: 0;
+  background: #e3e3e3;
+}
+
+.status-bar-panel {
+  display: flex;
+  align-items: center;
+}
+
+.status-bar-entry {
+  border-left: 1px solid rgba(0, 0, 0, 0.5);
+  margin-left: 5px;
+  padding-left: 5px;
+}
+
+/* - Menubar */
+
+.ink-panel-menubar {
+  align-items: center;
+  position: relative;
+  height: 60px;
+  border-bottom: solid 1px #eee;
+  padding: 0 24px;
+  display: flex;
+}
+
+.ink-panel-menubar .ink-fabs {
+  position: absolute;
+  right: 12px;
+  top: 40px;
+  z-index: 1;
+}
+
+#bookmark-panel {
+  bottom: 0;
+  position: absolute;
+  width: 300px;
+}
+#bookmark-panel-container {
+  bottom: 60px;
+  position: absolute;
+}
+
+.ink-fab {
+  margin-left: 8px;
+  border: 1px solid rgba(0, 0, 0, 0.02);
+  background: white;
+  box-shadow: 0 1px 3px rgba(0, 0, 0, 0.3);
+}
+
+#metadata-card {
+  position: absolute;
+  right: 5px;
+  top: 25px;
+}
+
+#help-3d-icon {
+  position: absolute;
+  top: 20px;
+  left: 20px;
+}
+
+#help3dDialog .main {
+  margin: 0;
+  padding: 20px;
+}
+
+#help3dDialog h3 {
+  margin-top: 20px;
+  margin-bottom: 5px;
+}
+
+#help3dDialog h3:first-child {
+  margin-top: 0;
+}
+
+#data-panel {
+  border-top: 1px solid rgba(0, 0, 0, 0.1);
+  overflow-y: auto;
+}
+
+#toast {
+  display: flex;
+  align-items: center;
+  --paper-toast-color: #eeff41;
+}
+</style>
+<paper-dialog id="notification-dialog" modal>
+  <h2 id="notification-title"></h2>
+  <paper-dialog-scrollable>
+    <div id="notify-msgs"></div>
+  </paper-dialog-scrollable>
+  <div style="text-align: center;"><paper-spinner-lite active class="progress"></paper-spinner-lite></div>
+  <div class="buttons">
+    <paper-button class="close-button" dialog-confirm autofocus>Close</paper-button>
+  </div>
+</paper-dialog>
+<div id="container">
+  <div id="left-pane" class="ink-panel">
+    <vz-projector-data-panel id="data-panel"></vz-projector-data-panel>
+    <vz-projector-projections-panel id="projections-panel"></vz-projector-projections-panel>
+  </div>
+  <div id="main" class="ink-panel">
+    <div class="ink-panel-menubar">
+      <paper-icon-button id="selectMode" alt="Bounding box selection" toggles icon="image:photo-size-select-small"></paper-icon-button>
+      <paper-tooltip for="selectMode" position="bottom" animation-delay="0" fit-to-visible-bounds>Bounding box selection</paper-tooltip>
+
+      <paper-icon-button id="nightDayMode" alt="Enable/disable night mode" toggles icon="image:brightness-2"></paper-icon-button>
+      <paper-tooltip for="nightDayMode" position="bottom" animation-delay="0" fit-to-visible-bounds>Enable/disable night mode</paper-tooltip>
+
+      <paper-icon-button id="labels3DMode" alt="Enable/disable 3D labels mode" toggles icon="font-download"></paper-icon-button>
+      <paper-tooltip for="labels3DMode" position="bottom" animation-delay="0" fit-to-visible-bounds>Enable/disable 3D labels mode</paper-tooltip>
+      <div class="status-bar-panel">
+        <div class="status-bar-entry">Points: <span class="numDataPoints">Loading...</span></div>
+        <div class="status-bar-entry">Dimension: <span class="dim">Loading...</span></div>
+        <div id="status-bar" class="status-bar-entry" style="display: none;"></div>
+      </div>
+      <div class="ink-fabs">
+        <paper-icon-button id="reset-zoom" class="ink-fab" alt="Reset zoom to fit all points" icon="home"></paper-icon-button>
+        <paper-tooltip for="reset-zoom" position="left" animation-delay="0">Reset zoom to fit all points</paper-tooltip>
+      </div>
+    </div>
+    <div class="stage">
+      <div id="scatter">
+        <svg id="selector"></svg>
+      </div>
+      <vz-projector-metadata-card id="metadata-card"></vz-projector-metadata-card>
+      <paper-icon-button raised onclick="help3dDialog.open()" icon="help-outline" id="help-3d-icon"></paper-icon-button>
+      <paper-tooltip animation-delay="0" for="help-3d-icon">Help with interaction controls.</paper-tooltip>
+      <paper-dialog id="help3dDialog" with-backdrop>
+        <div class="main" dialog-confirm autofocus>
+          <h3>3D controls</h3>
+            <b>Rotate</b> Mouse left click.<br/>
+            <b>Pan</b> Mouse right click.<br/>
+            <b>Zoom</b> Mouse wheel.<br/>
+            Holding <b>ctrl</b> reverses the mouse clicks.
+          <h3>2D controls</h3>
+            <b>Pan</b> Mouse left click.<br/>
+            <b>Zoom</b> Mouse wheel.
+          <div class="dismiss-dialog-note"> Click anywhere to dismiss.</div>
+        </div>
+      </paper-dialog>
+    </div>
+  </div>
+  <div id="right-pane" class="ink-panel">
+    <div class="ink-panel-content active">
+      <vz-projector-inspector-panel id="inspector-panel"></vz-projector-inspector-panel>
+    </div>
+    <div id="bookmark-panel-container">
+      <vz-projector-bookmark-panel id="bookmark-panel"></vz-projector-bookmark-panel>
+    </div>
+  </div>
+</div>
+<paper-toast id="toast" always-on-top></paper-toast>
+
+</template> <!-- global template -->
+</dom-module>
diff --git a/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts
new file mode 100644
index 00000000000..bf98a4d4785
--- /dev/null
+++ b/tensorflow/tensorboard/components/vz_projector_d3v4/vz-projector.ts
@@ -0,0 +1,570 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+import {AnalyticsLogger} from './analyticsLogger';
+import * as data from './data';
+import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data';
+import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider';
+import {DemoDataProvider} from './data-provider-demo';
+import {ProtoDataProvider} from './data-provider-proto';
+import {ServerDataProvider} from './data-provider-server';
+import * as knn from './knn';
+import * as logging from './logging';
+import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext';
+import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter';
+import {MouseMode} from './scatterPlot';
+import * as util from './util';
+import {BookmarkPanel} from './vz-projector-bookmark-panel';
+import {DataPanel} from './vz-projector-data-panel';
+import {InspectorPanel} from './vz-projector-inspector-panel';
+import {MetadataCard} from './vz-projector-metadata-card';
+import {ProjectionsPanel} from './vz-projector-projections-panel';
+// tslint:disable-next-line:no-unused-variable
+import {PolymerElement, PolymerHTMLElement} from './vz-projector-util';
+
+/**
+ * The minimum number of dimensions the data should have to automatically
+ * decide to normalize the data.
+ */
+const THRESHOLD_DIM_NORMALIZE = 50;
+const POINT_COLOR_MISSING = 'black';
+
+export let ProjectorPolymer = PolymerElement({
+  is: 'vz-projector',
+  properties: {
+    routePrefix: String,
+    dataProto: {type: String, observer: '_dataProtoChanged'},
+    servingMode: String,
+    projectorConfigJsonPath: String,
+    pageViewLogging: Boolean,
+    eventLogging: Boolean
+  }
+});
+
+const INDEX_METADATA_FIELD = '__index__';
+
+export class Projector extends ProjectorPolymer implements
+    ProjectorEventContext {
+  // The working subset of the data source's original data set.
+  dataSet: DataSet;
+  servingMode: ServingMode;
+  // The path to the projector config JSON file for demo mode.
+  projectorConfigJsonPath: string;
+
+  private selectionChangedListeners: SelectionChangedListener[];
+  private hoverListeners: HoverListener[];
+  private projectionChangedListeners: ProjectionChangedListener[];
+  private distanceMetricChangedListeners: DistanceMetricChangedListener[];
+
+  private originalDataSet: DataSet;
+  private dataSetBeforeFilter: DataSet;
+  private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter;
+  private dim: number;
+
+  private dataSetFilterIndices: number[];
+  private selectedPointIndices: number[];
+  private neighborsOfFirstPoint: knn.NearestEntry[];
+  private hoverPointIndex: number;
+
+  private dataProvider: DataProvider;
+  private inspectorPanel: InspectorPanel;
+
+  private selectedColorOption: ColorOption;
+  private selectedLabelOption: string;
+  private routePrefix: string;
+  private normalizeData: boolean;
+  private projection: Projection;
+
+  /** Polymer component panels */
+  private dataPanel: DataPanel;
+  private bookmarkPanel: BookmarkPanel;
+  private projectionsPanel: ProjectionsPanel;
+  private metadataCard: MetadataCard;
+
+  private statusBar: HTMLDivElement;
+  private analyticsLogger: AnalyticsLogger;
+  private eventLogging: boolean;
+  private pageViewLogging: boolean;
+
+  ready() {
+    logging.setDomContainer(this);
+
+    this.analyticsLogger =
+        new AnalyticsLogger(this.pageViewLogging, this.eventLogging);
+    this.analyticsLogger.logPageView('embeddings');
+
+    if (!util.hasWebGLSupport()) {
+      this.analyticsLogger.logWebGLDisabled();
+      logging.setErrorMessage(
+          'Your browser or device does not have WebGL enabled. Please enable ' +
+          'hardware acceleration, or use a browser that supports WebGL.');
+      return;
+    }
+
+    this.selectionChangedListeners = [];
+    this.hoverListeners = [];
+    this.projectionChangedListeners = [];
+    this.distanceMetricChangedListeners = [];
+    this.selectedPointIndices = [];
+    this.neighborsOfFirstPoint = [];
+
+    this.dataPanel = this.$['data-panel'] as DataPanel;
+    this.inspectorPanel = this.$['inspector-panel'] as InspectorPanel;
+    this.inspectorPanel.initialize(this, this as ProjectorEventContext);
+    this.projectionsPanel = this.$['projections-panel'] as ProjectionsPanel;
+    this.projectionsPanel.initialize(this);
+    this.bookmarkPanel = this.$['bookmark-panel'] as BookmarkPanel;
+    this.bookmarkPanel.initialize(this, this as ProjectorEventContext);
+    this.metadataCard = this.$['metadata-card'] as MetadataCard;
+    this.statusBar = this.querySelector('#status-bar') as HTMLDivElement;
+    this.scopeSubtree(this.$$('#notification-dialog'), true);
+    this.setupUIControls();
+    this.initializeDataProvider();
+  }
+
+  setSelectedLabelOption(labelOption: string) {
+    this.selectedLabelOption = labelOption;
+    this.metadataCard.setLabelOption(this.selectedLabelOption);
+    this.projectorScatterPlotAdapter.setLabelPointAccessor(labelOption);
+    this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+    this.projectorScatterPlotAdapter.render();
+  }
+
+  setSelectedColorOption(colorOption: ColorOption) {
+    this.selectedColorOption = colorOption;
+    this.projectorScatterPlotAdapter.setLegendPointColorer(
+        this.getLegendPointColorer(colorOption));
+    this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+    this.projectorScatterPlotAdapter.render();
+  }
+
+  setNormalizeData(normalizeData: boolean) {
+    this.normalizeData = normalizeData;
+    this.setCurrentDataSet(this.originalDataSet.getSubset());
+  }
+
+  updateDataSet(
+      ds: DataSet, spriteAndMetadata?: SpriteAndMetadataInfo,
+      metadataFile?: string) {
+    this.dataSetFilterIndices = null;
+    this.originalDataSet = ds;
+    if (ds != null) {
+      this.normalizeData =
+          this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE;
+      spriteAndMetadata = spriteAndMetadata || {};
+      if (spriteAndMetadata.pointsInfo == null) {
+        let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points);
+        spriteAndMetadata.pointsInfo = pointsInfo;
+        spriteAndMetadata.stats = stats;
+      }
+      let metadataMergeSucceeded = ds.mergeMetadata(spriteAndMetadata);
+      if (!metadataMergeSucceeded) {
+        return;
+      }
+    }
+    if (this.projectorScatterPlotAdapter != null) {
+      if (ds == null) {
+        this.projectorScatterPlotAdapter.setLabelPointAccessor(null);
+        this.setProjection(null);
+      } else {
+        this.projectorScatterPlotAdapter.updateScatterPlotPositions();
+        this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+        this.projectorScatterPlotAdapter.resize();
+        this.projectorScatterPlotAdapter.render();
+      }
+    }
+    if (ds != null) {
+      this.dataPanel.setNormalizeData(this.normalizeData);
+      this.setCurrentDataSet(ds.getSubset());
+      this.projectorScatterPlotAdapter.setLabelPointAccessor(
+          this.selectedLabelOption);
+      this.inspectorPanel.datasetChanged();
+
+      this.inspectorPanel.metadataChanged(spriteAndMetadata);
+      this.projectionsPanel.metadataChanged(spriteAndMetadata);
+      this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile);
+      // Set the container to a fixed height, otherwise in Colab the
+      // height can grow indefinitely.
+      const container = this.querySelector('#container') as HTMLDivElement;
+      container.style.height = container.clientHeight + 'px';
+    } else {
+      this.setCurrentDataSet(null);
+    }
+  }
+
+  setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) {
+    this.bookmarkPanel.setSelectedTensor(run, tensorInfo, this.dataProvider);
+  }
+
+  /**
+   * Registers a listener to be called any time the selected point set changes.
+   */
+  registerSelectionChangedListener(listener: SelectionChangedListener) {
+    this.selectionChangedListeners.push(listener);
+  }
+
+  filterDataset(pointIndices: number[]) {
+    const selectionSize = this.selectedPointIndices.length;
+    if (this.dataSetBeforeFilter == null) {
+      this.dataSetBeforeFilter = this.dataSet;
+    }
+    this.setCurrentDataSet(this.dataSet.getSubset(pointIndices));
+    this.dataSetFilterIndices = pointIndices;
+    this.projectorScatterPlotAdapter.updateScatterPlotPositions();
+    this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+    this.adjustSelectionAndHover(util.range(selectionSize));
+  }
+
+  resetFilterDataset() {
+    const originalPointIndices = this.selectedPointIndices.map(
+        filteredIndex => this.dataSet.points[filteredIndex].index);
+    this.setCurrentDataSet(this.dataSetBeforeFilter);
+    if (this.projection != null) {
+      this.projection.dataSet = this.dataSetBeforeFilter;
+    }
+    this.dataSetBeforeFilter = null;
+    this.projectorScatterPlotAdapter.updateScatterPlotPositions();
+    this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
+    this.dataSetFilterIndices = [];
+    this.adjustSelectionAndHover(originalPointIndices);
+  }
+
+  /**
+   * Used by clients to indicate that a selection has occurred.
+   */
+  notifySelectionChanged(newSelectedPointIndices: number[]) {
+    this.selectedPointIndices = newSelectedPointIndices;
+    let neighbors: knn.NearestEntry[] = [];
+
+    if (newSelectedPointIndices.length === 1) {
+      neighbors = this.dataSet.findNeighbors(
+          newSelectedPointIndices[0], this.inspectorPanel.distFunc,
+          this.inspectorPanel.numNN);
+      this.metadataCard.updateMetadata(
+          this.dataSet.points[newSelectedPointIndices[0]].metadata);
+    } else {
+      this.metadataCard.updateMetadata(null);
+    }
+
+    this.selectionChangedListeners.forEach(
+        l => l(this.selectedPointIndices, neighbors));
+  }
+
+  /**
+   * Registers a listener to be called any time the mouse hovers over a point.
+   */
+  registerHoverListener(listener: HoverListener) {
+    this.hoverListeners.push(listener);
+  }
+
+  /**
+   * Used by clients to indicate that a hover is occurring.
+   */
+  notifyHoverOverPoint(pointIndex: number) {
+    this.hoverListeners.forEach(l => l(pointIndex));
+  }
+
+  registerProjectionChangedListener(listener: ProjectionChangedListener) {
+    this.projectionChangedListeners.push(listener);
+  }
+
+  notifyProjectionChanged(projection: Projection) {
+    this.projectionChangedListeners.forEach(l => l(projection));
+  }
+
+  registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) {
+    this.distanceMetricChangedListeners.push(l);
+  }
+
+  notifyDistanceMetricChanged(distMetric: DistanceFunction) {
+    this.distanceMetricChangedListeners.forEach(l => l(distMetric));
+  }
+
+  _dataProtoChanged(dataProtoString: string) {
+    let dataProto =
+        dataProtoString ? JSON.parse(dataProtoString) as DataProto : null;
+    this.initializeDataProvider(dataProto);
+  }
+
+  private makeDefaultPointsInfoAndStats(points: DataPoint[]):
+      [PointMetadata[], ColumnStats[]] {
+    let pointsInfo: PointMetadata[] = [];
+    points.forEach(p => {
+      let pointInfo: PointMetadata = {};
+      pointInfo[INDEX_METADATA_FIELD] = p.index;
+      pointsInfo.push(pointInfo);
+    });
+    let stats: ColumnStats[] = [{
+      name: INDEX_METADATA_FIELD,
+      isNumeric: false,
+      tooManyUniqueValues: true,
+      min: 0,
+      max: pointsInfo.length - 1
+    }];
+    return [pointsInfo, stats];
+  }
+
+  private initializeDataProvider(dataProto?: DataProto) {
+    if (this.servingMode === 'demo') {
+      let projectorConfigUrl: string;
+
+      // Only in demo mode do we allow the config being passed via URL.
+      let urlParams = util.getURLParams(window.location.search);
+      if ('config' in urlParams) {
+        projectorConfigUrl = urlParams['config'];
+      } else {
+        projectorConfigUrl = this.projectorConfigJsonPath;
+      }
+      this.dataProvider = new DemoDataProvider(projectorConfigUrl);
+    } else if (this.servingMode === 'server') {
+      if (!this.routePrefix) {
+        throw 'route-prefix is a required parameter';
+      }
+      this.dataProvider = new ServerDataProvider(this.routePrefix);
+    } else if (this.servingMode === 'proto' && dataProto != null) {
+      this.dataProvider = new ProtoDataProvider(dataProto);
+    }
+
+    this.dataPanel.initialize(this, this.dataProvider);
+  }
+
+  private getLegendPointColorer(colorOption: ColorOption):
+      (ds: DataSet, index: number) => string {
+    if ((colorOption == null) || (colorOption.map == null)) {
+      return null;
+    }
+    const colorer = (ds: DataSet, i: number) => {
+      let value = ds.points[i].metadata[this.selectedColorOption.name];
+      if (value == null) {
+        return POINT_COLOR_MISSING;
+      }
+      return colorOption.map(value);
+    };
+    return colorer;
+  }
+
+  private get3DLabelModeButton(): any {
+    return this.querySelector('#labels3DMode');
+  }
+
+  private get3DLabelMode(): boolean {
+    const label3DModeButton = this.get3DLabelModeButton();
+    return (label3DModeButton as any).active;
+  }
+
+  adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) {
+    this.notifySelectionChanged(selectedPointIndices);
+    this.notifyHoverOverPoint(hoverIndex);
+    this.setMouseMode(MouseMode.CAMERA_AND_CLICK_SELECT);
+  }
+
+  private setMouseMode(mouseMode: MouseMode) {
+    let selectModeButton = this.querySelector('#selectMode');
+    (selectModeButton as any).active = (mouseMode === MouseMode.AREA_SELECT);
+    this.projectorScatterPlotAdapter.scatterPlot.setMouseMode(mouseMode);
+  }
+
+  private setCurrentDataSet(ds: DataSet) {
+    this.adjustSelectionAndHover([]);
+    if (this.dataSet != null) {
+      this.dataSet.stopTSNE();
+    }
+    if ((ds != null) && this.normalizeData) {
+      ds.normalize();
+    }
+    this.dim = (ds == null) ? 0 : ds.dim[1];
+    (this.querySelector('span.numDataPoints') as HTMLSpanElement).innerText =
+        (ds == null) ? '0' : '' + ds.dim[0];
+    (this.querySelector('span.dim') as HTMLSpanElement).innerText =
+        (ds == null) ? '0' : '' + ds.dim[1];
+
+    this.dataSet = ds;
+
+    this.projectionsPanel.dataSetUpdated(
+        this.dataSet, this.originalDataSet, this.dim);
+
+    this.projectorScatterPlotAdapter.setDataSet(this.dataSet);
+    this.projectorScatterPlotAdapter.scatterPlot
+        .setCameraParametersForNextCameraCreation(null, true);
+  }
+
+  private setupUIControls() {
+    // View controls
+    this.querySelector('#reset-zoom').addEventListener('click', () => {
+      this.projectorScatterPlotAdapter.scatterPlot.resetZoom();
+      this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation();
+    });
+
+    let selectModeButton = this.querySelector('#selectMode');
+    selectModeButton.addEventListener('click', (event) => {
+      this.setMouseMode(
+          (selectModeButton as any).active ? MouseMode.AREA_SELECT :
+                                             MouseMode.CAMERA_AND_CLICK_SELECT);
+    });
+    let nightModeButton = this.querySelector('#nightDayMode');
+    nightModeButton.addEventListener('click', () => {
+      this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode(
+          (nightModeButton as any).active);
+    });
+
+    const labels3DModeButton = this.get3DLabelModeButton();
+    labels3DModeButton.addEventListener('click', () => {
+      this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode());
+    });
+
+    window.addEventListener('resize', () => {
+      const container = this.querySelector('#container') as HTMLDivElement;
+      const parentHeight = (container.parentNode as HTMLElement).clientHeight;
+      container.style.height = parentHeight + 'px';
+      this.projectorScatterPlotAdapter.resize();
+    });
+
+    {
+      this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter(
+          this.getScatterContainer(), this as ProjectorEventContext);
+      this.projectorScatterPlotAdapter.setLabelPointAccessor(
+          this.selectedLabelOption);
+    }
+
+    this.projectorScatterPlotAdapter.scatterPlot.onCameraMove(
+        (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) =>
+            this.bookmarkPanel.clearStateSelection());
+
+    this.registerHoverListener(
+        (hoverIndex: number) => this.onHover(hoverIndex));
+
+    this.registerSelectionChangedListener(
+        (selectedPointIndices: number[],
+         neighborsOfFirstPoint: knn.NearestEntry[]) =>
+            this.onSelectionChanged(
+                selectedPointIndices, neighborsOfFirstPoint));
+  }
+
+  private onHover(hoverIndex: number) {
+    this.hoverPointIndex = hoverIndex;
+    let hoverText = null;
+    if (hoverIndex != null) {
+      const point = this.dataSet.points[hoverIndex];
+      if (point.metadata[this.selectedLabelOption]) {
+        hoverText = point.metadata[this.selectedLabelOption].toString();
+      }
+    }
+    if (this.selectedPointIndices.length === 0) {
+      this.statusBar.style.display = hoverText ? null : 'none';
+      this.statusBar.innerText = hoverText;
+    }
+  }
+
+  private getScatterContainer(): HTMLDivElement {
+    return this.querySelector('#scatter') as HTMLDivElement;
+  }
+
+  private onSelectionChanged(
+      selectedPointIndices: number[],
+      neighborsOfFirstPoint: knn.NearestEntry[]) {
+    this.selectedPointIndices = selectedPointIndices;
+    this.neighborsOfFirstPoint = neighborsOfFirstPoint;
+    let totalNumPoints =
+        this.selectedPointIndices.length + neighborsOfFirstPoint.length;
+    this.statusBar.innerText = `Selected ${totalNumPoints} points`;
+    this.statusBar.style.display = totalNumPoints > 0 ? null : 'none';
+  }
+
+  setProjection(projection: Projection) {
+    this.projection = projection;
+    if (projection != null) {
+      this.analyticsLogger.logProjectionChanged(projection.projectionType);
+    }
+    this.notifyProjectionChanged(projection);
+  }
+
+  notifyProjectionPositionsUpdated() {
+    this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated();
+  }
+
+  /**
+   * Gets the current view of the embedding and saves it as a State object.
+   */
+  getCurrentState(): State {
+    const state = new State();
+
+    // Save the individual datapoint projections.
+    state.projections = [];
+    for (let i = 0; i < this.dataSet.points.length; i++) {
+      const point = this.dataSet.points[i];
+      const projections: {[key: string]: number} = {};
+      const keys = Object.keys(point.projections);
+      for (let j = 0; j < keys.length; ++j) {
+        projections[keys[j]] = point.projections[keys[j]];
+      }
+      state.projections.push(projections);
+    }
+    state.selectedProjection = this.projection.projectionType;
+    state.dataSetDimensions = this.dataSet.dim;
+    state.tSNEIteration = this.dataSet.tSNEIteration;
+    state.selectedPoints = this.selectedPointIndices;
+    state.filteredPoints = this.dataSetFilterIndices;
+    this.projectorScatterPlotAdapter.populateBookmarkFromUI(state);
+    state.selectedColorOptionName = this.dataPanel.selectedColorOptionName;
+    state.forceCategoricalColoring = this.dataPanel.forceCategoricalColoring;
+    state.selectedLabelOption = this.selectedLabelOption;
+    this.projectionsPanel.populateBookmarkFromUI(state);
+    return state;
+  }
+
+  /** Loads a State object into the world. */
+  loadState(state: State) {
+    this.setProjection(null);
+    {
+      this.projectionsPanel.disablePolymerChangesTriggerReprojection();
+      if (this.dataSetBeforeFilter != null) {
+        this.resetFilterDataset();
+      }
+      if (state.filteredPoints != null) {
+        this.filterDataset(state.filteredPoints);
+      }
+      this.projectionsPanel.enablePolymerChangesTriggerReprojection();
+    }
+    for (let i = 0; i < state.projections.length; i++) {
+      const point = this.dataSet.points[i];
+      const projection = state.projections[i];
+      const keys = Object.keys(projection);
+      for (let j = 0; j < keys.length; ++j) {
+        point.projections[keys[j]] = projection[keys[j]];
+      }
+    }
+    this.dataSet.hasTSNERun = (state.selectedProjection === 'tsne');
+    this.dataSet.tSNEIteration = state.tSNEIteration;
+    this.projectionsPanel.restoreUIFromBookmark(state);
+    this.inspectorPanel.restoreUIFromBookmark(state);
+    this.dataPanel.selectedColorOptionName = state.selectedColorOptionName;
+    this.dataPanel.setForceCategoricalColoring(
+        !!state.forceCategoricalColoring);
+    this.selectedLabelOption = state.selectedLabelOption;
+    this.projectorScatterPlotAdapter.restoreUIFromBookmark(state);
+    {
+      const dimensions = stateGetAccessorDimensions(state);
+      const components =
+          data.getProjectionComponents(state.selectedProjection, dimensions);
+      const projection = new Projection(
+          state.selectedProjection, components, dimensions.length,
+          this.dataSet);
+      this.setProjection(projection);
+    }
+    this.notifySelectionChanged(state.selectedPoints);
+  }
+}
+
+document.registerElement(Projector.prototype.is, Projector);

From 7c8fffaf5dca165759862322f868c21eda693da1 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 09:41:57 -0800
Subject: [PATCH 12/27] Add GrpcDebugWrapperSession to the public debug API.
 Change: 153843609

---
 tensorflow/python/debug/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index d4a84f62cde..01e36f754c2 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -25,6 +25,7 @@ See the @{$python/tfdbg} guide.
 @@has_inf_or_nan
 @@DumpingDebugHook
 @@DumpingDebugWrapperSession
+@@GrpcDebugWrapperSession
 @@LocalCLIDebugHook
 @@LocalCLIDebugWrapperSession
 @@WatchOptions
@@ -46,6 +47,7 @@ from tensorflow.python.debug.lib.debug_utils import watch_graph_with_blacklists
 
 from tensorflow.python.debug.wrappers.dumping_wrapper import DumpingDebugWrapperSession
 from tensorflow.python.debug.wrappers.framework import WatchOptions
+from tensorflow.python.debug.wrappers.grpc_wrapper import GrpcDebugWrapperSession
 from tensorflow.python.debug.wrappers.hooks import DumpingDebugHook
 from tensorflow.python.debug.wrappers.hooks import LocalCLIDebugHook
 from tensorflow.python.debug.wrappers.local_cli_wrapper import LocalCLIDebugWrapperSession

From fa8f9da8f2607bde3701815b27ba1a67d71525ad Mon Sep 17 00:00:00 2001
From: Pete Warden <petewarden@google.com>
Date: Fri, 21 Apr 2017 10:08:59 -0800
Subject: [PATCH 13/27] Add Mfcc op to TensorFlow for speech feature generation
 Change: 153847440

---
 tensorflow/core/kernels/BUILD                 | 119 ++++++++++
 tensorflow/core/kernels/mfcc.cc               |  67 ++++++
 tensorflow/core/kernels/mfcc.h                |  76 +++++++
 tensorflow/core/kernels/mfcc_dct.cc           |  82 +++++++
 tensorflow/core/kernels/mfcc_dct.h            |  44 ++++
 tensorflow/core/kernels/mfcc_dct_test.cc      |  55 +++++
 .../core/kernels/mfcc_mel_filterbank.cc       | 204 ++++++++++++++++++
 tensorflow/core/kernels/mfcc_mel_filterbank.h |  65 ++++++
 .../core/kernels/mfcc_mel_filterbank_test.cc  |  92 ++++++++
 tensorflow/core/kernels/mfcc_op.cc            | 111 ++++++++++
 tensorflow/core/kernels/mfcc_op_test.cc       |  77 +++++++
 tensorflow/core/kernels/mfcc_test.cc          |  92 ++++++++
 tensorflow/core/ops/audio_ops.cc              |  50 +++++
 13 files changed, 1134 insertions(+)
 create mode 100644 tensorflow/core/kernels/mfcc.cc
 create mode 100644 tensorflow/core/kernels/mfcc.h
 create mode 100644 tensorflow/core/kernels/mfcc_dct.cc
 create mode 100644 tensorflow/core/kernels/mfcc_dct.h
 create mode 100644 tensorflow/core/kernels/mfcc_dct_test.cc
 create mode 100644 tensorflow/core/kernels/mfcc_mel_filterbank.cc
 create mode 100644 tensorflow/core/kernels/mfcc_mel_filterbank.h
 create mode 100644 tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
 create mode 100644 tensorflow/core/kernels/mfcc_op.cc
 create mode 100644 tensorflow/core/kernels/mfcc_op_test.cc
 create mode 100644 tensorflow/core/kernels/mfcc_test.cc

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c9ec61328fd..e32f51a3a2a 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3692,11 +3692,130 @@ tf_cuda_cc_test(
     ],
 )
 
+cc_library(
+    name = "mfcc_dct",
+    srcs = ["mfcc_dct.cc"],
+    hdrs = ["mfcc_dct.h"],
+    copts = tf_copts(),
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "mfcc_dct_test",
+    size = "small",
+    srcs = ["mfcc_dct_test.cc"],
+    deps = [
+        ":mfcc_dct",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:lib_test_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//third_party/eigen3",
+    ],
+)
+
+cc_library(
+    name = "mfcc_mel_filterbank",
+    srcs = ["mfcc_mel_filterbank.cc"],
+    hdrs = ["mfcc_mel_filterbank.h"],
+    copts = tf_copts(),
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "mfcc_mel_filterbank_test",
+    size = "small",
+    srcs = ["mfcc_mel_filterbank_test.cc"],
+    deps = [
+        ":mfcc_mel_filterbank",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:lib_test_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//third_party/eigen3",
+    ],
+)
+
+cc_library(
+    name = "mfcc",
+    srcs = ["mfcc.cc"],
+    hdrs = ["mfcc.h"],
+    copts = tf_copts(),
+    deps = [
+        ":mfcc_dct",
+        ":mfcc_mel_filterbank",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "mfcc_test",
+    size = "small",
+    srcs = ["mfcc_test.cc"],
+    deps = [
+        ":mfcc",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:lib_test_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "mfcc_op",
+    prefix = "mfcc_op",
+    deps = [
+        ":mfcc",
+        "//tensorflow/core:audio_ops_op_lib",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+    alwayslink = 1,
+)
+
+tf_cuda_cc_test(
+    name = "mfcc_op_test",
+    size = "small",
+    srcs = ["mfcc_op_test.cc"],
+    deps = [
+        ":mfcc_op",
+        ":ops_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:client_session",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 cc_library(
     name = "audio",
     deps = [
         ":decode_wav_op",
         ":encode_wav_op",
+        ":mfcc_op",
         ":spectrogram_op",
     ],
 )
diff --git a/tensorflow/core/kernels/mfcc.cc b/tensorflow/core/kernels/mfcc.cc
new file mode 100644
index 00000000000..2793005aa26
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc.cc
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+
+#include "tensorflow/core/kernels/mfcc.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+const double kDefaultUpperFrequencyLimit = 4000;
+const double kDefaultLowerFrequencyLimit = 20;
+const double kFilterbankFloor = 1e-12;
+const int kDefaultFilterbankChannelCount = 40;
+const int kDefaultDCTCoefficientCount = 13;
+
+Mfcc::Mfcc() : initialized_(false),
+               lower_frequency_limit_(kDefaultLowerFrequencyLimit),
+               upper_frequency_limit_(kDefaultUpperFrequencyLimit),
+               filterbank_channel_count_(kDefaultFilterbankChannelCount),
+               dct_coefficient_count_(kDefaultDCTCoefficientCount) { }
+
+bool Mfcc::Initialize(int input_length,
+                      double input_sample_rate) {
+  bool initialized = mel_filterbank_.Initialize(input_length,
+                                                input_sample_rate,
+                                                filterbank_channel_count_,
+                                                lower_frequency_limit_,
+                                                upper_frequency_limit_);
+  initialized &= dct_.Initialize(filterbank_channel_count_,
+                                 dct_coefficient_count_);
+  initialized_ = initialized;
+  return initialized;
+}
+
+void Mfcc::Compute(const std::vector<double>& spectrogram_frame,
+                   std::vector<double>* output) const {
+  if (!initialized_) {
+    LOG(ERROR) << "Mfcc not initialized.";
+    return;
+  }
+  std::vector<double> working;
+  mel_filterbank_.Compute(spectrogram_frame, &working);
+  for (int i = 0; i < working.size(); ++i) {
+    double val = working[i];
+    if (val < kFilterbankFloor) {
+      val = kFilterbankFloor;
+    }
+    working[i] = log(val);
+  }
+  dct_.Compute(working, output);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc.h b/tensorflow/core/kernels/mfcc.h
new file mode 100644
index 00000000000..c39f1049909
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Basic class for computing MFCCs from spectrogram slices.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
+
+#include <vector>
+
+#include "tensorflow/core/kernels/mfcc_dct.h"
+#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class Mfcc {
+ public:
+  Mfcc();
+  bool Initialize(int input_length,
+                  double input_sample_rate);
+
+  // Input is a single magnitude spectrogram frame. The input spectrum
+  // is filtered into bands using a triangular mel filterbank and a
+  // discrete cosine transform (DCT) of the values is taken. Output is
+  // populated with the lowest dct_coefficient_count of these values.
+  void Compute(const std::vector<double>& spectrogram_frame,
+               std::vector<double>* output) const;
+
+  void set_upper_frequency_limit(double upper_frequency_limit) {
+    CHECK(!initialized_) << "Set frequency limits before calling Initialize.";
+    upper_frequency_limit_ = upper_frequency_limit;
+  }
+
+  void set_lower_frequency_limit(double lower_frequency_limit) {
+    CHECK(!initialized_) << "Set frequency limits before calling Initialize.";
+    lower_frequency_limit_ = lower_frequency_limit;
+  }
+
+  void set_filterbank_channel_count(int filterbank_channel_count) {
+    CHECK(!initialized_) << "Set channel count before calling Initialize.";
+    filterbank_channel_count_ = filterbank_channel_count;
+  }
+
+  void set_dct_coefficient_count(int dct_coefficient_count) {
+    CHECK(!initialized_) << "Set coefficient count before calling Initialize.";
+    dct_coefficient_count_ = dct_coefficient_count;
+  }
+
+ private:
+  MfccMelFilterbank mel_filterbank_;
+  MfccDct dct_;
+  bool initialized_;
+  double lower_frequency_limit_;
+  double upper_frequency_limit_;
+  int filterbank_channel_count_;
+  int dct_coefficient_count_;
+  TF_DISALLOW_COPY_AND_ASSIGN(Mfcc);
+};
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_H_
diff --git a/tensorflow/core/kernels/mfcc_dct.cc b/tensorflow/core/kernels/mfcc_dct.cc
new file mode 100644
index 00000000000..aa67a8d6499
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_dct.cc
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/mfcc_dct.h"
+
+#include <math.h>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+MfccDct::MfccDct() : initialized_(false) {}
+
+bool MfccDct::Initialize(int input_length, int coefficient_count) {
+  coefficient_count_ = coefficient_count;
+  input_length_ = input_length;
+
+  if (coefficient_count_ < 1) {
+    LOG(ERROR) << "Coefficient count must be positive.";
+    return false;
+  }
+
+  if (input_length < 1) {
+    LOG(ERROR) << "Input length must be positive.";
+    return false;
+  }
+
+  if (coefficient_count_ > input_length_) {
+    LOG(ERROR) << "Coefficient count must be less than or equal to "
+               << "input length.";
+    return false;
+  }
+
+  cosines_.resize(coefficient_count_);
+  double fnorm = sqrt(2.0 / input_length_);
+  // Some platforms don't have M_PI, so define a local constant here.
+  const double pi = std::atan(1) * 4;
+  double arg = pi / input_length_;
+  for (int i = 0; i < coefficient_count_; ++i) {
+    cosines_[i].resize(input_length_);
+    for (int j = 0; j < input_length_; ++j) {
+      cosines_[i][j] = fnorm * cos(i * arg * (j + 0.5));
+    }
+  }
+  initialized_ = true;
+  return true;
+}
+
+void MfccDct::Compute(const std::vector<double> &input,
+                      std::vector<double> *output) const {
+  if (!initialized_) {
+    LOG(ERROR) << "DCT not initialized.";
+    return;
+  }
+
+  output->resize(coefficient_count_);
+  int length = input.size();
+  if (length > input_length_) {
+    length = input_length_;
+  }
+
+  for (int i = 0; i < coefficient_count_; ++i) {
+    double sum = 0.0;
+    for (int j = 0; j < length; ++j) {
+      sum += cosines_[i][j] * input[j];
+    }
+    (*output)[i] = sum;
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_dct.h b/tensorflow/core/kernels/mfcc_dct.h
new file mode 100644
index 00000000000..4fa3c01628d
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_dct.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Basic minimal DCT class for MFCC speech processing.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class MfccDct {
+ public:
+  MfccDct();
+  bool Initialize(int input_length, int coefficient_count);
+  void Compute(const std::vector<double>& input,
+               std::vector<double>* output) const;
+
+ private:
+  bool initialized_;
+  int coefficient_count_;
+  int input_length_;
+  std::vector<std::vector<double> > cosines_;
+  TF_DISALLOW_COPY_AND_ASSIGN(MfccDct);
+};
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_DCT_H_
diff --git a/tensorflow/core/kernels/mfcc_dct_test.cc b/tensorflow/core/kernels/mfcc_dct_test.cc
new file mode 100644
index 00000000000..7526278fe9e
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_dct_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/mfcc_dct.h"
+
+#include <vector>
+
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+TEST(MfccDctTest, AgreesWithMatlab) {
+  // This test verifies the DCT against MATLAB's dct function.
+  MfccDct dct;
+  std::vector<double> input = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+  const int kCoefficientCount = 6;
+  ASSERT_TRUE(dct.Initialize(input.size(), kCoefficientCount));
+  std::vector<double> output;
+  dct.Compute(input, &output);
+  // Note, the matlab dct function divides the first coefficient by
+  // sqrt(2), whereas we don't, so we multiply the first element of
+  // the matlab result by sqrt(2) to get the expected values below.
+  std::vector<double> expected = {12.1243556530, -4.1625617959, 0.0,
+                                  -0.4082482905, 0.0,           -0.0800788912};
+  ASSERT_EQ(output.size(), kCoefficientCount);
+  for (int i = 0; i < kCoefficientCount; ++i) {
+    EXPECT_NEAR(output[i], expected[i], 1e-10);
+  }
+}
+
+TEST(MfccDctTest, InitializeFailsOnInvalidInput) {
+  MfccDct dct1;
+  EXPECT_FALSE(dct1.Initialize(-50, 1));
+  MfccDct dct2;
+  EXPECT_FALSE(dct1.Initialize(10, -4));
+  MfccDct dct3;
+  EXPECT_FALSE(dct1.Initialize(-1, -1));
+  MfccDct dct4;
+  EXPECT_FALSE(dct1.Initialize(20, 21));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.cc b/tensorflow/core/kernels/mfcc_mel_filterbank.cc
new file mode 100644
index 00000000000..d68c60280d9
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This code resamples the FFT bins, and smooths then with triangle-shaped
+// weights to create a mel-frequency filter bank. For filter i centered at f_i,
+// there is a triangular weighting of the FFT bins that extends from
+// filter f_i-1 (with a value of zero at the left edge of the triangle) to f_i
+// (where the filter value is 1) to f_i+1 (where the filter values returns to
+// zero).
+
+// Note: this code fails if you ask for too many channels.  The algorithm used
+// here assumes that each FFT bin contributes to at most two channels: the
+// right side of a triangle for channel i, and the left side of the triangle
+// for channel i+1.  If you ask for so many channels that some of the
+// resulting mel triangle filters are smaller than a single FFT bin, these
+// channels may end up with no contributing FFT bins.  The resulting mel
+// spectrum output will have some channels that are always zero.
+
+#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
+
+#include <math.h>
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+MfccMelFilterbank::MfccMelFilterbank() : initialized_(false) {}
+
+bool MfccMelFilterbank::Initialize(int input_length,
+                               double input_sample_rate,
+                               int output_channel_count,
+                               double lower_frequency_limit,
+                               double upper_frequency_limit) {
+  num_channels_ = output_channel_count;
+  sample_rate_  = input_sample_rate;
+  input_length_ = input_length;
+
+  if (num_channels_ < 1) {
+    LOG(ERROR) << "Number of filterbank channels must be positive.";
+    return false;
+  }
+
+  if (sample_rate_ <= 0) {
+    LOG(ERROR) << "Sample rate must be positive.";
+    return false;
+  }
+
+  if (input_length < 2) {
+    LOG(ERROR) << "Input length must greater than 1.";
+    return false;
+  }
+
+  if (lower_frequency_limit <= 0) {
+    LOG(ERROR) << "Lower frequency limit must be positive.";
+    return false;
+  }
+
+  if (upper_frequency_limit <= lower_frequency_limit) {
+    LOG(ERROR) << "Upper frequency limit must be greater than "
+               << "lower frequency limit.";
+    return false;
+  }
+
+  // An extra center frequency is computed at the top to get the upper
+  // limit on the high side of the final triangular filter.
+  center_frequencies_.resize(num_channels_ + 1);
+  const double mel_low = FreqToMel(lower_frequency_limit);
+  const double mel_hi = FreqToMel(upper_frequency_limit);
+  const double mel_span = mel_hi - mel_low;
+  const double mel_spacing = mel_span / static_cast<double>(num_channels_ + 1);
+  for (int i = 0; i < num_channels_ + 1; ++i) {
+    center_frequencies_[i] = mel_low + (mel_spacing * (i + 1));
+  }
+
+  // Always exclude DC; emulate HTK.
+  const double hz_per_sbin = 0.5 * sample_rate_ /
+      static_cast<double>(input_length_ - 1);
+  start_index_ = static_cast<int>(1.5 + (lower_frequency_limit /
+                                           hz_per_sbin));
+  end_index_ = static_cast<int>(upper_frequency_limit / hz_per_sbin);
+
+  // Maps the input spectrum bin indices to filter bank channels/indices. For
+  // each FFT bin, band_mapper tells us which channel this bin contributes to
+  // on the right side of the triangle.  Thus this bin also contributes to the
+  // left side of the next channel's triangle response.
+  band_mapper_.resize(input_length_);
+  int channel = 0;
+  for (int i = 0; i < input_length_; ++i) {
+    double melf = FreqToMel(i * hz_per_sbin);
+    if ((i < start_index_) || (i > end_index_)) {
+      band_mapper_[i] = -2;  // Indicate an unused Fourier coefficient.
+    } else {
+      while ((center_frequencies_[channel] < melf) &&
+             (channel < num_channels_)) {
+        ++channel;
+      }
+      band_mapper_[i] = channel - 1;  // Can be == -1
+    }
+  }
+
+  // Create the weighting functions to taper the band edges.  The contribution
+  // of any one FFT bin is based on its distance along the continuum between two
+  // mel-channel center frequencies.  This bin contributes weights_[i] to the
+  // current channel and 1-weights_[i] to the next channel.
+  weights_.resize(input_length_);
+  for (int i = 0; i < input_length_; ++i) {
+    channel = band_mapper_[i];
+    if ((i < start_index_) || (i > end_index_)) {
+      weights_[i] = 0.0;
+    } else {
+      if (channel >= 0) {
+        weights_[i] = (center_frequencies_[channel + 1] -
+                       FreqToMel(i * hz_per_sbin)) /
+            (center_frequencies_[channel + 1] - center_frequencies_[channel]);
+      } else {
+        weights_[i] = (center_frequencies_[0] - FreqToMel(i * hz_per_sbin)) /
+            (center_frequencies_[0] - mel_low);
+      }
+    }
+  }
+  // Check the sum of FFT bin weights for every mel band to identify
+  // situations where the mel bands are so narrow that they don't get
+  // significant weight on enough (or any) FFT bins -- i.e., too many
+  // mel bands have been requested for the given FFT size.
+  std::vector<int> bad_channels;
+  for (int c = 0; c < num_channels_; ++c) {
+    float band_weights_sum = 0.0;
+    for (int i = 0; i < input_length_; ++i) {
+      if (band_mapper_[i] == c - 1) {
+        band_weights_sum += (1.0 - weights_[i]);
+      } else if (band_mapper_[i] == c) {
+        band_weights_sum += weights_[i];
+      }
+    }
+    // The lowest mel channels have the fewest FFT bins and the lowest
+    // weights sum.  But given that the target gain at the center frequency
+    // is 1.0, if the total sum of weights is 0.5, we're in bad shape.
+    if (band_weights_sum < 0.5) {
+      bad_channels.push_back(c);
+    }
+  }
+  if (!bad_channels.empty()) {
+    LOG(ERROR) << "Missing " << bad_channels.size() << " bands " <<
+        " starting at " << bad_channels[0] <<
+        " in mel-frequency design. " <<
+        "Perhaps too many channels or " <<
+        "not enough frequency resolution in spectrum. (" <<
+        "input_length: " << input_length <<
+        " input_sample_rate: " << input_sample_rate <<
+        " output_channel_count: " << output_channel_count <<
+        " lower_frequency_limit: " << lower_frequency_limit <<
+        " upper_frequency_limit: " << upper_frequency_limit;
+  }
+  initialized_ = true;
+  return true;
+}
+
+// Compute the mel spectrum from the squared-magnitude FFT input by taking the
+// square root, then summing FFT magnitudes under triangular integration windows
+// whose widths increase with frequency.
+void MfccMelFilterbank::Compute(const std::vector<double> &input,
+                            std::vector<double> *output) const {
+  if (!initialized_) {
+    LOG(ERROR) << "Mel Filterbank not initialized.";
+    return;
+  }
+
+  if (input.size() <= end_index_) {
+    LOG(ERROR) << "Input too short to compute filterbank";
+    return;
+  }
+
+  // Ensure output is right length and reset all values.
+  output->assign(num_channels_, 0.0);
+
+  for (int i = start_index_; i <= end_index_; i++) {  // For each FFT bin
+    double spec_val = sqrt(input[i]);
+    double weighted = spec_val * weights_[i];
+    int channel = band_mapper_[i];
+    if (channel >= 0)
+      (*output)[channel] += weighted;  // Right side of triangle, downward slope
+    channel++;
+    if (channel < num_channels_)
+      (*output)[channel] += spec_val - weighted;  // Left side of triangle
+  }
+}
+
+double MfccMelFilterbank::FreqToMel(double freq) const {
+  return 1127.0 * log(1.0 + (freq / 700.0));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank.h b/tensorflow/core/kernels/mfcc_mel_filterbank.h
new file mode 100644
index 00000000000..33ea1bdb5bc
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Basic class for applying a mel-scale filterbank to an input.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
+
+#include <vector>
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class MfccMelFilterbank {
+ public:
+  MfccMelFilterbank();
+  bool Initialize(int input_length,  // Number of unique FFT bins fftsize/2+1.
+                  double input_sample_rate,
+                  int output_channel_count,
+                  double lower_frequency_limit,
+                  double upper_frequency_limit);
+
+  // Takes a magnitude spectrogram slice as input, computes a
+  // traingular mel filterbank and places the result in output.
+  void Compute(const std::vector<double>& input,
+               std::vector<double>* output) const;
+
+ private:
+  double FreqToMel(double freq) const;
+  bool initialized_;
+  int num_channels_;
+  double sample_rate_;
+  int input_length_;
+  std::vector<double> center_frequencies_;  // In mel, for each mel channel.
+
+  // Each FFT bin b contributes to two triangular mel channels, with
+  // proportion weights_[b] going into mel channel band_mapper_[b], and
+  // proportion (1 - weights_[b]) going into channel band_mapper_[b] + 1.
+  // Thus, weights_ contains the weighting applied to each FFT bin for the
+  // upper-half of the triangular band.
+  std::vector<double> weights_;  // Right-side weight for this fft  bin.
+
+  // FFT bin i contributes to the upper side of mel channel band_mapper_[i]
+  std::vector<int> band_mapper_;
+  int start_index_;  // Lowest FFT bin used to calculate mel spectrum.
+  int end_index_;  // Highest FFT bin used to calculate mel spectrum.
+
+  TF_DISALLOW_COPY_AND_ASSIGN(MfccMelFilterbank);
+};
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_MFCC_MEL_FILTERBANK_H_
diff --git a/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
new file mode 100644
index 00000000000..c3a7e779403
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_mel_filterbank_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/mfcc_mel_filterbank.h"
+
+#include <vector>
+
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+TEST(MfccMelFilterbankTest, AgreesWithPythonGoldenValues) {
+  // This test verifies the Mel filterbank against "golden values".
+  // Golden values are from an independent Python Mel implementation.
+  MfccMelFilterbank filterbank;
+
+  std::vector<double> input;
+  const int kSampleCount = 513;
+  for (int i = 0; i < kSampleCount; ++i) {
+    input.push_back(i + 1);
+  }
+  const int kChannelCount = 20;
+  filterbank.Initialize(input.size(),
+                        22050 /* sample rate */,
+                        kChannelCount /* channels */,
+                        20.0 /*  lower frequency limit */,
+                        4000.0 /* upper frequency limit */);
+
+  std::vector<double> output;
+  filterbank.Compute(input, &output);
+
+  std::vector<double> expected = {
+      7.38894574,   10.30330648, 13.72703292,  17.24158686,  21.35253118,
+      25.77781089,  31.30624108, 37.05877236,  43.9436536,   51.80306637,
+      60.79867148,  71.14363376, 82.90910141,  96.50069158,  112.08428368,
+      129.96721968, 150.4277597, 173.74997634, 200.86037462, 231.59802942};
+
+  ASSERT_EQ(output.size(), kChannelCount);
+
+  for (int i = 0; i < kChannelCount; ++i) {
+    EXPECT_NEAR(output[i], expected[i], 1e-04);
+  }
+}
+
+TEST(MfccMelFilterbankTest, IgnoresExistingContentOfOutputVector) {
+  // Test for bug where the output vector was not cleared before
+  // accumulating next frame's weighted spectral values.
+  MfccMelFilterbank filterbank;
+
+  const int kSampleCount = 513;
+  std::vector<double> input;
+  std::vector<double> output;
+
+  filterbank.Initialize(kSampleCount,
+                        22050 /* sample rate */,
+                        20 /* channels */,
+                        20.0 /*  lower frequency limit */,
+                        4000.0 /* upper frequency limit */);
+
+
+  // First call with nonzero input value, and an empty output vector,
+  // will resize the output and fill it with the correct, nonzero outputs.
+  input.assign(kSampleCount, 1.0);
+  filterbank.Compute(input, &output);
+  for (const double value : output) {
+    EXPECT_LE(0.0, value);
+  }
+
+  // Second call with zero input should also generate zero output.  However,
+  // the output vector now is already the correct size, but full of nonzero
+  // values.  Make sure these don't affect the output.
+  input.assign(kSampleCount, 0.0);
+  filterbank.Compute(input, &output);
+  for (const double value : output) {
+    EXPECT_EQ(0.0, value);
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_op.cc b/tensorflow/core/kernels/mfcc_op.cc
new file mode 100644
index 00000000000..02643857c1f
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_op.cc
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/audio_ops.cc
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/mfcc.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Create a speech fingerpring from spectrogram data.
+class MfccOp : public OpKernel {
+ public:
+  explicit MfccOp(OpKernelConstruction* context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("upper_frequency_limit",
+                                             &upper_frequency_limit_));
+    OP_REQUIRES_OK(context, context->GetAttr("lower_frequency_limit",
+                                             &lower_frequency_limit_));
+    OP_REQUIRES_OK(context, context->GetAttr("filterbank_channel_count",
+                                             &filterbank_channel_count_));
+    OP_REQUIRES_OK(context, context->GetAttr("dct_coefficient_count",
+                                             &dct_coefficient_count_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& spectrogram = context->input(0);
+    OP_REQUIRES(context, spectrogram.dims() == 3,
+                errors::InvalidArgument("spectrogram must be 3-dimensional",
+                                        spectrogram.shape().DebugString()));
+    const Tensor& sample_rate_tensor = context->input(1);
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(sample_rate_tensor.shape()),
+                errors::InvalidArgument(
+                    "Input sample_rate should be a scalar tensor, got ",
+                    sample_rate_tensor.shape().DebugString(), " instead."));
+    const int32 sample_rate = sample_rate_tensor.scalar<int32>()();
+
+    const int spectrogram_channels = spectrogram.dim_size(2);
+    const int spectrogram_samples = spectrogram.dim_size(1);
+    const int audio_channels = spectrogram.dim_size(0);
+
+    Mfcc mfcc;
+    mfcc.set_upper_frequency_limit(upper_frequency_limit_);
+    mfcc.set_lower_frequency_limit(lower_frequency_limit_);
+    mfcc.set_filterbank_channel_count(filterbank_channel_count_);
+    mfcc.set_dct_coefficient_count(dct_coefficient_count_);
+    OP_REQUIRES(context, mfcc.Initialize(spectrogram_channels, sample_rate),
+                errors::InvalidArgument(
+                    "Mfcc initialization failed for channel count ",
+                    spectrogram_channels, " and sample rate ", sample_rate));
+
+    Tensor* output_tensor = nullptr;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0,
+                       TensorShape({audio_channels, spectrogram_samples,
+                                    dct_coefficient_count_}),
+                       &output_tensor));
+
+    const float* spectrogram_flat = spectrogram.flat<float>().data();
+    float* output_flat = output_tensor->flat<float>().data();
+
+    for (int audio_channel = 0; audio_channel < audio_channels;
+         ++audio_channel) {
+      for (int spectrogram_sample = 0; spectrogram_sample < spectrogram_samples;
+           ++spectrogram_sample) {
+        const float* sample_data =
+            spectrogram_flat +
+            (audio_channel * spectrogram_samples * spectrogram_channels) +
+            (spectrogram_sample * spectrogram_channels);
+        std::vector<double> mfcc_input(sample_data,
+                                       sample_data + spectrogram_channels);
+        std::vector<double> mfcc_output;
+        mfcc.Compute(mfcc_input, &mfcc_output);
+        DCHECK_EQ(dct_coefficient_count_, mfcc_output.size());
+        float* output_data =
+            output_flat +
+            (audio_channel * spectrogram_samples * dct_coefficient_count_) +
+            (spectrogram_sample * dct_coefficient_count_);
+        for (int i = 0; i < dct_coefficient_count_; ++i) {
+          output_data[i] = mfcc_output[i];
+        }
+      }
+    }
+  }
+
+ private:
+  float upper_frequency_limit_;
+  float lower_frequency_limit_;
+  int32 filterbank_channel_count_;
+  int32 dct_coefficient_count_;
+};
+REGISTER_KERNEL_BUILDER(Name("Mfcc").Device(DEVICE_CPU), MfccOp);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_op_test.cc b/tensorflow/core/kernels/mfcc_op_test.cc
new file mode 100644
index 00000000000..d16171d5265
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_op_test.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/ops/audio_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/math_ops.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+using namespace ops;  // NOLINT(build/namespaces)
+
+TEST(MfccOpTest, SimpleTest) {
+  Scope root = Scope::NewRootScope();
+
+  Tensor spectrogram_tensor(DT_FLOAT, TensorShape({1, 1, 513}));
+  test::FillIota<float>(&spectrogram_tensor, 1.0f);
+
+  Output spectrogram_const_op = Const(root.WithOpName("spectrogram_const_op"),
+                                      Input::Initializer(spectrogram_tensor));
+
+  Output sample_rate_const_op =
+      Const(root.WithOpName("sample_rate_const_op"), 22050);
+
+  Mfcc mfcc_op = Mfcc(root.WithOpName("mfcc_op"), spectrogram_const_op,
+                      sample_rate_const_op);
+
+  TF_ASSERT_OK(root.status());
+
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+
+  TF_EXPECT_OK(
+      session.Run(ClientSession::FeedType(), {mfcc_op.output}, &outputs));
+
+  const Tensor& mfcc_tensor = outputs[0];
+
+  EXPECT_EQ(3, mfcc_tensor.dims());
+  EXPECT_EQ(13, mfcc_tensor.dim_size(2));
+  EXPECT_EQ(1, mfcc_tensor.dim_size(1));
+  EXPECT_EQ(1, mfcc_tensor.dim_size(0));
+
+  test::ExpectTensorNear<float>(
+      mfcc_tensor,
+      test::AsTensor<float>(
+          {29.13970072, -6.41568601, -0.61903012, -0.96778652, -0.26819878,
+           -0.40907028, -0.15614748, -0.23203119, -0.10481487, -0.1543029,
+           -0.0769791, -0.10806114, -0.06047613},
+          TensorShape({1, 1, 13})),
+      1e-3);
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mfcc_test.cc b/tensorflow/core/kernels/mfcc_test.cc
new file mode 100644
index 00000000000..9ab726e5b9c
--- /dev/null
+++ b/tensorflow/core/kernels/mfcc_test.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/mfcc.h"
+
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+TEST(MfccTest, AgreesWithPythonGoldenValues) {
+  Mfcc mfcc;
+  std::vector<double> input;
+  const int kSampleCount = 513;
+  for (int i = 0; i < kSampleCount; ++i) {
+    input.push_back(i + 1);
+  }
+
+  ASSERT_TRUE(mfcc.Initialize(input.size(), 22050 /*sample rate*/));
+
+  std::vector<double> output;
+  mfcc.Compute(input, &output);
+
+  std::vector<double> expected = {29.13970072, -6.41568601, -0.61903012,
+                             -0.96778652, -0.26819878, -0.40907028,
+                             -0.15614748, -0.23203119, -0.10481487,
+                             -0.1543029,  -0.0769791,  -0.10806114,
+                             -0.06047613};
+
+  ASSERT_EQ(expected.size(), output.size());
+  for (int i = 0; i < output.size(); ++i) {
+    EXPECT_NEAR(output[i], expected[i], 1e-04);
+  }
+}
+
+TEST(MfccTest, AvoidsNansWithZeroInput) {
+  Mfcc mfcc;
+  std::vector<double> input;
+  const int kSampleCount = 513;
+  for (int i = 0; i < kSampleCount; ++i) {
+    input.push_back(0.0);
+  }
+
+  ASSERT_TRUE(mfcc.Initialize(input.size(), 22050 /*sample rate*/));
+
+  std::vector<double> output;
+  mfcc.Compute(input, &output);
+
+  int expected_size = 13;
+  ASSERT_EQ(expected_size, output.size());
+  for (const double value : output) {
+    EXPECT_FALSE(isnan(value));
+  }
+}
+
+TEST(MfccTest, SimpleInputSaneResult) {
+  Mfcc mfcc;
+  mfcc.set_lower_frequency_limit(125.0);
+  mfcc.set_upper_frequency_limit(3800.0);
+  mfcc.set_filterbank_channel_count(40);
+  mfcc.set_dct_coefficient_count(40);
+  const int kSpectrogramSize = 129;
+  std::vector<double> input(kSpectrogramSize, 0.0);
+
+  // Simulate a low-frequency sinusoid from the spectrogram.
+  const int kHotBin = 10;
+  input[kHotBin] = 1.0;
+  ASSERT_TRUE(mfcc.Initialize(input.size(), 8000));
+
+  std::vector<double> output;
+  mfcc.Compute(input, &output);
+
+  // For a single low-frequency input, output beyond c_0 should look like
+  // a slow cosine, with a slight delay.  Largest value will be c_1.
+  EXPECT_EQ(output.begin() + 1, std::max_element(output.begin(), output.end()));
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/ops/audio_ops.cc b/tensorflow/core/ops/audio_ops.cc
index 2f55e45e377..02b13a455ce 100644
--- a/tensorflow/core/ops/audio_ops.cc
+++ b/tensorflow/core/ops/audio_ops.cc
@@ -100,6 +100,26 @@ Status SpectrogramShapeFn(InferenceContext* c) {
   return Status::OK();
 }
 
+Status MfccShapeFn(InferenceContext* c) {
+  ShapeHandle spectrogram;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram));
+  ShapeHandle unused;
+  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+
+  int32 dct_coefficient_count;
+  TF_RETURN_IF_ERROR(
+      c->GetAttr("dct_coefficient_count", &dct_coefficient_count));
+
+  DimensionHandle spectrogram_channels = c->Dim(spectrogram, 0);
+  DimensionHandle spectrogram_length = c->Dim(spectrogram, 1);
+
+  DimensionHandle output_channels = c->MakeDim(dct_coefficient_count);
+
+  c->set_output(0, c->MakeShape({spectrogram_channels, spectrogram_length,
+                                 output_channels}));
+  return Status::OK();
+}
+
 }  // namespace
 
 REGISTER_OP("DecodeWav")
@@ -200,4 +220,34 @@ magnitude_squared: Whether to return the squared magnitude or just the
 spectrogram: 3D representation of the audio frequencies as an image.
 )doc");
 
+REGISTER_OP("Mfcc")
+    .Input("spectrogram: float")
+    .Input("sample_rate: int32")
+    .Attr("upper_frequency_limit: float = 4000")
+    .Attr("lower_frequency_limit: float = 20")
+    .Attr("filterbank_channel_count: int = 40")
+    .Attr("dct_coefficient_count: int = 13")
+    .Output("output: float")
+    .SetShapeFn(MfccShapeFn)
+    .Doc(R"doc(
+Transforms a spectrogram into a form that's useful for speech recognition.
+
+Mel Frequency Cepstral Coefficients are a way of representing audio data that's
+been effective as an input feature for machine learning. They are created by
+taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
+higher frequencies that are less significant to the human ear. They have a long
+history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+is a good resource to learn more.
+
+spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
+  set to true.
+sample_rate: How many samples per second the source audio used.
+upper_frequency_limit: The highest frequency to use when calculating the
+  ceptstrum.
+lower_frequency_limit: The lowest frequency to use when calculating the
+  ceptstrum.
+filterbank_channel_count: Resolution of the Mel bank used internally.
+dct_coefficient_count: How many output channels to produce per time slice.
+)doc");
+
 }  // namespace tensorflow

From 19922ea85779961f93ca1d2e1bfdb71a70044a54 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 10:12:29 -0800
Subject: [PATCH 14/27] [TF:XLA] This flag is unnecessary, as we can get the
 path to ptxas from tensorflow::CudaRoot(). Change: 153847918

---
 tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc | 1 -
 tensorflow/compiler/xla/service/gpu/gpu_compiler.cc        | 4 ++--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
index e79d3635095..7d3ad60aea4 100644
--- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
@@ -38,7 +38,6 @@ static void AllocateFlags() {
   flags = new GpuCompilerFlags;
   flags->xla_gpu_embed_ir = false;
   flags->xla_cuda_data_dir = "./cuda_sdk_lib";
-  flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas";
   flag_list = new std::vector<tensorflow::Flag>({
       tensorflow::Flag(
           "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 161d0033a3a..43960cd3a8f 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -187,8 +187,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
 
 // Invokes the ptxas tool on the given PTX string, and dumps its output.
 void DumpPtxasInfo(const string& ptx) {
-  legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
-  const string ptxas_path = flags->xla_ptxas_path;
+  const string ptxas_path =
+      tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas");
   // Do not log PTX stats if ptxas is not found at the given path.
   if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) {
     LOG(WARNING)

From d524194bad2cbe0dbd88abec13ef43015c146b23 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 10:32:58 -0800
Subject: [PATCH 15/27] Update ops-related pbtxt files. Change: 153850549

---
 .../core/ops/compat/ops_history.v1.pbtxt      | 43 ++++++++++++++++
 tensorflow/core/ops/ops.pbtxt                 | 51 +++++++++++++++++++
 2 files changed, 94 insertions(+)

diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 50e49713140..1781f778b49 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11316,6 +11316,49 @@ op {
     }
   }
 }
+op {
+  name: "Mfcc"
+  input_arg {
+    name: "spectrogram"
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sample_rate"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "upper_frequency_limit"
+    type: "float"
+    default_value {
+      f: 4000
+    }
+  }
+  attr {
+    name: "lower_frequency_limit"
+    type: "float"
+    default_value {
+      f: 20
+    }
+  }
+  attr {
+    name: "filterbank_channel_count"
+    type: "int"
+    default_value {
+      i: 40
+    }
+  }
+  attr {
+    name: "dct_coefficient_count"
+    type: "int"
+    default_value {
+      i: 13
+    }
+  }
+}
 op {
   name: "Min"
   input_arg {
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 9725b6538de..6d28cb7e840 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -11539,6 +11539,57 @@ op {
   summary: "V2 format specific: merges the metadata files of sharded checkpoints.  The"
   description: "result is one logical checkpoint, with one physical metadata file and renamed\ndata files.\n\nIntended for \"grouping\" multiple checkpoints in a sharded checkpoint setup.\n\nIf delete_old_dirs is true, attempts to delete recursively the dirname of each\npath in the input checkpoint_prefixes.  This is useful when those paths are non\nuser-facing temporary locations."
 }
+op {
+  name: "Mfcc"
+  input_arg {
+    name: "spectrogram"
+    description: "Typically produced by the Spectrogram op, with magnitude_squared\nset to true."
+    type: DT_FLOAT
+  }
+  input_arg {
+    name: "sample_rate"
+    description: "How many samples per second the source audio used."
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    type: DT_FLOAT
+  }
+  attr {
+    name: "upper_frequency_limit"
+    type: "float"
+    default_value {
+      f: 4000
+    }
+    description: "The highest frequency to use when calculating the\nceptstrum."
+  }
+  attr {
+    name: "lower_frequency_limit"
+    type: "float"
+    default_value {
+      f: 20
+    }
+    description: "The lowest frequency to use when calculating the\nceptstrum."
+  }
+  attr {
+    name: "filterbank_channel_count"
+    type: "int"
+    default_value {
+      i: 40
+    }
+    description: "Resolution of the Mel bank used internally."
+  }
+  attr {
+    name: "dct_coefficient_count"
+    type: "int"
+    default_value {
+      i: 13
+    }
+    description: "How many output channels to produce per time slice."
+  }
+  summary: "Transforms a spectrogram into a form that\'s useful for speech recognition."
+  description: "Mel Frequency Cepstral Coefficients are a way of representing audio data that\'s\nbeen effective as an input feature for machine learning. They are created by\ntaking the spectrum of a spectrogram (a \'cepstrum\'), and discarding some of the\nhigher frequencies that are less significant to the human ear. They have a long\nhistory in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum\nis a good resource to learn more."
+}
 op {
   name: "Min"
   input_arg {

From c3bf39b7a6c3cc41f209ac863c764498b503d4f5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 10:38:34 -0800
Subject: [PATCH 16/27] Go: Update generated wrapper functions for TensorFlow
 ops. Change: 153851292

---
 tensorflow/go/op/wrappers.go | 77 ++++++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index e832690e183..b21e8fd4481 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2456,6 +2456,83 @@ func ParallelConcat(scope *Scope, values []tf.Output, shape tf.Shape) (output tf
 	return op.Output(0)
 }
 
+// MfccAttr is an optional argument to Mfcc.
+type MfccAttr func(optionalAttr)
+
+// MfccUpperFrequencyLimit sets the optional upper_frequency_limit attribute to value.
+//
+// value: The highest frequency to use when calculating the
+// ceptstrum.
+// If not specified, defaults to 4000
+func MfccUpperFrequencyLimit(value float32) MfccAttr {
+	return func(m optionalAttr) {
+		m["upper_frequency_limit"] = value
+	}
+}
+
+// MfccLowerFrequencyLimit sets the optional lower_frequency_limit attribute to value.
+//
+// value: The lowest frequency to use when calculating the
+// ceptstrum.
+// If not specified, defaults to 20
+func MfccLowerFrequencyLimit(value float32) MfccAttr {
+	return func(m optionalAttr) {
+		m["lower_frequency_limit"] = value
+	}
+}
+
+// MfccFilterbankChannelCount sets the optional filterbank_channel_count attribute to value.
+//
+// value: Resolution of the Mel bank used internally.
+// If not specified, defaults to 40
+func MfccFilterbankChannelCount(value int64) MfccAttr {
+	return func(m optionalAttr) {
+		m["filterbank_channel_count"] = value
+	}
+}
+
+// MfccDctCoefficientCount sets the optional dct_coefficient_count attribute to value.
+//
+// value: How many output channels to produce per time slice.
+// If not specified, defaults to 13
+func MfccDctCoefficientCount(value int64) MfccAttr {
+	return func(m optionalAttr) {
+		m["dct_coefficient_count"] = value
+	}
+}
+
+// Transforms a spectrogram into a form that's useful for speech recognition.
+//
+// Mel Frequency Cepstral Coefficients are a way of representing audio data that's
+// been effective as an input feature for machine learning. They are created by
+// taking the spectrum of a spectrogram (a 'cepstrum'), and discarding some of the
+// higher frequencies that are less significant to the human ear. They have a long
+// history in the speech recognition world, and https://en.wikipedia.org/wiki/Mel-frequency_cepstrum
+// is a good resource to learn more.
+//
+// Arguments:
+//	spectrogram: Typically produced by the Spectrogram op, with magnitude_squared
+// set to true.
+//	sample_rate: How many samples per second the source audio used.
+func Mfcc(scope *Scope, spectrogram tf.Output, sample_rate tf.Output, optional ...MfccAttr) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "Mfcc",
+		Input: []tf.Input{
+			spectrogram, sample_rate,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // UniqueAttr is an optional argument to Unique.
 type UniqueAttr func(optionalAttr)
 

From 8e5041918f2e709ded94e63fb1779d6bb363becb Mon Sep 17 00:00:00 2001
From: Charles Nicholson <nicholsonc@google.com>
Date: Fri, 21 Apr 2017 10:59:14 -0800
Subject: [PATCH 17/27] Introduce TFDecorator, a base class for Python
 TensorFlow decorators. Provides basic introspection and "unwrap" services,
 allowing tooling code to fully 'understand' the wrapped object. Change:
 153854044

---
 .../distributions/python/ops/distribution.py  |   6 +-
 .../python/ops/kullback_leibler.py            |   7 +-
 .../contrib/framework/python/ops/arg_scope.py |  11 +-
 .../keras/python/keras/backend_test.py        |   5 +-
 .../keras/python/keras/engine/topology.py     |   8 +-
 .../contrib/keras/python/keras/layers/core.py |   4 +-
 .../keras/python/keras/layers/wrappers.py     |   4 +-
 .../keras/python/keras/testing_utils.py       |   5 +-
 .../keras/python/keras/utils/generic_utils.py |   6 +-
 .../python/keras/wrappers/scikit_learn.py     |   6 +-
 .../labeled_tensor/python/ops/_typecheck.py   |   4 +-
 .../learn/python/learn/dataframe/transform.py |   5 +-
 .../python/learn/estimators/estimator.py      |   8 +-
 .../learn/estimators/estimator_test_utils.py  |   4 +-
 .../learn/python/learn/estimators/head.py     |   6 +-
 .../learn/python/learn/export_strategy.py     |  10 +-
 .../contrib/learn/python/learn/metric_spec.py |   6 +-
 .../contrib/learn/python/learn/monitors.py    |   4 +-
 tensorflow/contrib/specs/python/specs.py      |  16 +-
 .../docs_src/api_guides/python/index.md       |   1 +
 tensorflow/python/BUILD                       |  33 ++
 tensorflow/python/__init__.py                 |   4 +-
 tensorflow/python/debug/BUILD                 |   1 +
 .../python/debug/cli/analyzer_cli_test.py     |   6 +-
 .../python/debug/lib/source_utils_test.py     |  86 ++---
 tensorflow/python/estimator/estimator.py      |   9 +-
 tensorflow/python/framework/contrib_test.py   |   9 +-
 tensorflow/python/framework/function.py       |  12 +-
 tensorflow/python/framework/op_def_library.py |   5 +-
 tensorflow/python/framework/ops.py            |  24 +-
 tensorflow/python/layers/base.py              |   1 -
 tensorflow/python/ops/variable_scope.py       |   8 +-
 tensorflow/python/platform/benchmark.py       |   5 +-
 tensorflow/python/platform/googletest.py      |  21 +-
 tensorflow/python/platform/resource_loader.py |   8 +-
 tensorflow/python/util/all_util.py            |   6 +-
 tensorflow/python/util/deprecation.py         |  34 +-
 tensorflow/python/util/tf_contextlib.py       |  36 ++
 tensorflow/python/util/tf_contextlib_test.py  |  92 +++++
 tensorflow/python/util/tf_decorator.py        | 167 +++++++++
 tensorflow/python/util/tf_decorator_test.py   | 243 +++++++++++++
 tensorflow/python/util/tf_inspect.py          | 141 ++++++++
 tensorflow/python/util/tf_inspect_test.py     | 327 ++++++++++++++++++
 .../tools/api/golden/tensorflow.-graph.pbtxt  |  10 +-
 .../tensorflow.gfile.-fast-g-file.pbtxt       |   2 +-
 .../api/golden/tensorflow.gfile.-g-file.pbtxt |   2 +-
 .../api/golden/tensorflow.gfile.-open.pbtxt   |   2 +-
 tensorflow/tools/api/golden/tensorflow.pbtxt  |  18 +-
 .../api/lib/python_object_to_proto_visitor.py |  18 +-
 tensorflow/tools/common/public_api.py         |   5 +-
 tensorflow/tools/common/traverse.py           |  13 +-
 .../tools/docs/doc_generator_visitor.py       |  19 +-
 tensorflow/tools/docs/generate.py             |   5 +-
 tensorflow/tools/docs/generate_1_0.py         |  20 +-
 tensorflow/tools/docs/generate_lib.py         |  63 ++--
 tensorflow/tools/docs/parser.py               |  72 ++--
 tensorflow/tools/docs/parser_test.py          |  32 +-
 57 files changed, 1352 insertions(+), 333 deletions(-)
 create mode 100644 tensorflow/python/util/tf_contextlib.py
 create mode 100644 tensorflow/python/util/tf_contextlib_test.py
 create mode 100644 tensorflow/python/util/tf_decorator.py
 create mode 100644 tensorflow/python/util/tf_decorator_test.py
 create mode 100644 tensorflow/python/util/tf_inspect.py
 create mode 100644 tensorflow/python/util/tf_inspect_test.py

diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py
index 95d6c233886..0b7ffbd792e 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import abc
 import contextlib
-import inspect
 import types
 
 import numpy as np
@@ -33,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_inspect
 
 
 _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
@@ -154,12 +154,12 @@ class _DistributionMeta(abc.ABCMeta):
       if class_special_attr_value is None:
         # No _special method available, no need to update the docstring.
         continue
-      class_special_attr_docstring = inspect.getdoc(class_special_attr_value)
+      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
       if not class_special_attr_docstring:
         # No docstring to append.
         continue
       class_attr_value = _copy_fn(base_attr_value)
-      class_attr_docstring = inspect.getdoc(base_attr_value)
+      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
       if class_attr_docstring is None:
         raise ValueError(
             "Expected base class fn to contain a docstring: %s.%s"
diff --git a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
index bb94a876809..335fe7a5e2a 100644
--- a/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
+++ b/tensorflow/contrib/distributions/python/ops/kullback_leibler.py
@@ -18,12 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_inspect
 
 
 _DIVERGENCES = {}
@@ -31,8 +30,8 @@ _DIVERGENCES = {}
 
 def _registered_kl(type_a, type_b):
   """Get the KL function registered for classes a and b."""
-  hierarchy_a = inspect.getmro(type_a)
-  hierarchy_b = inspect.getmro(type_b)
+  hierarchy_a = tf_inspect.getmro(type_a)
+  hierarchy_b = tf_inspect.getmro(type_b)
   dist_to_children = None
   kl_fn = None
   for mro_to_a, parent_a in enumerate(hierarchy_a):
diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py
index ad84cd681aa..9c194ec202a 100644
--- a/tensorflow/contrib/framework/python/ops/arg_scope.py
+++ b/tensorflow/contrib/framework/python/ops/arg_scope.py
@@ -61,8 +61,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import contextlib
-import functools
+
+from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_decorator
 
 __all__ = ['arg_scope',
            'add_arg_scope',
@@ -106,7 +107,7 @@ def _add_op(op):
     _DECORATED_OPS[key_op] = _kwarg_names(op)
 
 
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def arg_scope(list_ops_or_scope, **kwargs):
   """Stores the default arguments for the given set of list_ops.
 
@@ -170,7 +171,6 @@ def add_arg_scope(func):
   Returns:
     A tuple with the decorated function func_with_args().
   """
-  @functools.wraps(func)
   def func_with_args(*args, **kwargs):
     current_scope = _current_arg_scope()
     current_args = kwargs
@@ -181,8 +181,7 @@ def add_arg_scope(func):
     return func(*args, **current_args)
   _add_op(func)
   setattr(func_with_args, '_key_op', _key_op(func))
-  setattr(func_with_args, '__doc__', func.__doc__)
-  return func_with_args
+  return tf_decorator.make_decorator(func, func_with_args)
 
 
 def has_arg_scope(func):
diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py
index fd9db1f3273..2da5aee58e5 100644
--- a/tensorflow/contrib/keras/python/keras/backend_test.py
+++ b/tensorflow/contrib/keras/python/keras/backend_test.py
@@ -18,12 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 import numpy as np
 
 from tensorflow.contrib.keras.python import keras
 from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
 
 
 def compare_single_input_op_to_numpy(keras_op,
@@ -207,7 +206,7 @@ class BackendLinearAlgebraTest(test.TestCase):
         compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
                                          keras_kwargs={'axis': -1},
                                          np_kwargs={'axis': -1})
-        if 'keepdims' in inspect.getargspec(keras_op).args:
+        if 'keepdims' in tf_inspect.getargspec(keras_op).args:
           compare_single_input_op_to_numpy(keras_op, np_op,
                                            input_shape=(4, 7, 5),
                                            keras_kwargs={'axis': 1,
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 0d1812aaa2f..7848e5982dd 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import json
 import os
 import re
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils import conv_utils
 from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
 
 
 # pylint: disable=g-import-not-at-top
@@ -584,7 +584,7 @@ class Layer(object):
       user_kwargs = copy.copy(kwargs)
       if not _is_all_none(previous_mask):
         # The previous layer generated a mask.
-        if 'mask' in inspect.getargspec(self.call).args:
+        if 'mask' in tf_inspect.getargspec(self.call).args:
           if 'mask' not in kwargs:
             # If mask is explicitly passed to __call__,
             # we should override the default mask.
@@ -2166,7 +2166,7 @@ class Container(Layer):
               kwargs = {}
             if len(computed_data) == 1:
               computed_tensor, computed_mask = computed_data[0]
-              if 'mask' in inspect.getargspec(layer.call).args:
+              if 'mask' in tf_inspect.getargspec(layer.call).args:
                 if 'mask' not in kwargs:
                   kwargs['mask'] = computed_mask
               output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
@@ -2177,7 +2177,7 @@ class Container(Layer):
             else:
               computed_tensors = [x[0] for x in computed_data]
               computed_masks = [x[1] for x in computed_data]
-              if 'mask' in inspect.getargspec(layer.call).args:
+              if 'mask' in tf_inspect.getargspec(layer.call).args:
                 if 'mask' not in kwargs:
                   kwargs['mask'] = computed_masks
               output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py
index 8dd55aaa2e6..32ada176a4f 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import types as python_types
 
 import numpy as np
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
 
 
 class Masking(Layer):
@@ -595,7 +595,7 @@ class Lambda(Layer):
 
   def call(self, inputs, mask=None):
     arguments = self.arguments
-    arg_spec = inspect.getargspec(self.function)
+    arg_spec = tf_inspect.getargspec(self.function)
     if 'mask' in arg_spec.args:
       arguments['mask'] = mask
     return self.function(inputs, **arguments)
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
index a322696514c..ce6458fd0c8 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -20,12 +20,12 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 
 from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.engine import InputSpec
 from tensorflow.contrib.keras.python.keras.engine import Layer
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
 
 
 class Wrapper(Layer):
@@ -284,7 +284,7 @@ class Bidirectional(Wrapper):
 
   def call(self, inputs, training=None, mask=None):
     kwargs = {}
-    func_args = inspect.getargspec(self.layer.call).args
+    func_args = tf_inspect.getargspec(self.layer.call).args
     if 'training' in func_args:
       kwargs['training'] = training
     if 'mask' in func_args:
diff --git a/tensorflow/contrib/keras/python/keras/testing_utils.py b/tensorflow/contrib/keras/python/keras/testing_utils.py
index baba5447d99..bf6f661adff 100644
--- a/tensorflow/contrib/keras/python/keras/testing_utils.py
+++ b/tensorflow/contrib/keras/python/keras/testing_utils.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 import numpy as np
 
 from tensorflow.contrib.keras.python import keras
+from tensorflow.python.util import tf_inspect
 
 
 def get_test_data(train_samples,
@@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
   layer.set_weights(weights)
 
   # test and instantiation from weights
-  if 'weights' in inspect.getargspec(layer_cls.__init__):
+  if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
     kwargs['weights'] = weights
     layer = layer_cls(**kwargs)
 
diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
index 4c95c314b16..27cc23f232d 100644
--- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import marshal
 import sys
 import time
@@ -26,6 +25,8 @@ import types as python_types
 import numpy as np
 import six
 
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 _GLOBAL_CUSTOM_OBJECTS = {}
 
@@ -116,6 +117,7 @@ def get_custom_objects():
 
 
 def serialize_keras_object(instance):
+  _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
   if hasattr(instance, 'get_config'):
@@ -149,7 +151,7 @@ def deserialize_keras_object(identifier,
       if cls is None:
         raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
     if hasattr(cls, 'from_config'):
-      arg_spec = inspect.getargspec(cls.from_config)
+      arg_spec = tf_inspect.getargspec(cls.from_config)
       if 'custom_objects' in arg_spec.args:
         custom_objects = custom_objects or {}
         return cls.from_config(
diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
index 323c31aee83..9f8cea375b7 100644
--- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
+++ b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
@@ -19,13 +19,13 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import types
 
 import numpy as np
 
 from tensorflow.contrib.keras.python.keras.models import Sequential
 from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
+from tensorflow.python.util import tf_inspect
 
 
 class BaseWrapper(object):
@@ -97,7 +97,7 @@ class BaseWrapper(object):
 
     legal_params = []
     for fn in legal_params_fns:
-      legal_params += inspect.getargspec(fn)[0]
+      legal_params += tf_inspect.getargspec(fn)[0]
     legal_params = set(legal_params)
 
     for params_name in params:
@@ -182,7 +182,7 @@ class BaseWrapper(object):
     """
     override = override or {}
     res = {}
-    fn_args = inspect.getargspec(fn)[0]
+    fn_args = tf_inspect.getargspec(fn)[0]
     for name, value in self.sk_params.items():
       if name in fn_args:
         res.update({name: value})
diff --git a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
index 4a939cb22c5..80fa17ec1f7 100644
--- a/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
+++ b/tensorflow/contrib/labeled_tensor/python/ops/_typecheck.py
@@ -24,9 +24,9 @@ from __future__ import print_function
 
 import collections
 import functools
-import inspect
 import re
 
+from tensorflow.python.util import tf_inspect
 
 # used for register_type_abbreviation and _type_repr below.
 _TYPE_ABBREVIATIONS = {}
@@ -230,7 +230,7 @@ def accepts(*types):
 
   def check_accepts(f):
     """Check the types."""
-    spec = inspect.getargspec(f)
+    spec = tf_inspect.getargspec(f)
 
     num_function_arguments = len(spec.args)
     if len(types) != num_function_arguments:
diff --git a/tensorflow/contrib/learn/python/learn/dataframe/transform.py b/tensorflow/contrib/learn/python/learn/dataframe/transform.py
index c28da59ac76..33be68e46a5 100644
--- a/tensorflow/contrib/learn/python/learn/dataframe/transform.py
+++ b/tensorflow/contrib/learn/python/learn/dataframe/transform.py
@@ -24,11 +24,12 @@ from abc import abstractmethod
 from abc import abstractproperty
 
 import collections
-import inspect
 
 from .series import Series
 from .series import TransformedSeries
 
+from tensorflow.python.util import tf_inspect
+
 
 def _make_list_of_series(x):
   """Converts `x` into a list of `Series` if possible.
@@ -120,7 +121,7 @@ class Transform(object):
   def parameters(self):
     """A dict of names to values of properties marked with `@parameter`."""
     property_param_names = [name
-                            for name, func in inspect.getmembers(type(self))
+                            for name, func in tf_inspect.getmembers(type(self))
                             if (hasattr(func, "fget") and hasattr(
                                 getattr(func, "fget"), "is_parameter"))]
     return {name: getattr(self, name) for name in property_param_names}
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 89fbe768402..8a92809a0ce 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -21,7 +21,6 @@ from __future__ import print_function
 
 import abc
 import copy
-import inspect
 import os
 import tempfile
 
@@ -70,6 +69,8 @@ from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
 from tensorflow.python.training import summary_io
 from tensorflow.python.util import compat
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
 AS_ITERABLE_DATE = '2016-09-15'
@@ -185,14 +186,15 @@ def _model_fn_args(fn):
   Raises:
     ValueError: if partial function has positionally bound arguments
   """
+  _, fn = tf_decorator.unwrap(fn)
   if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
     # Handle functools.partial and similar objects.
     return tuple([
-        arg for arg in inspect.getargspec(fn.func).args[len(fn.args):]
+        arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
         if arg not in set(fn.keywords.keys())
     ])
   # Handle function.
-  return tuple(inspect.getargspec(fn).args)
+  return tuple(tf_inspect.getargspec(fn).args)
 
 
 def _get_replica_device_setter(config):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py
index eb0cf51e098..fd47710e301 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
+from tensorflow.python.util import tf_inspect
 
 
 def assert_estimator_contract(tester, estimator_class):
@@ -31,7 +31,7 @@ def assert_estimator_contract(tester, estimator_class):
     tester: A tf.test.TestCase.
     estimator_class: 'type' object of pre-canned estimator.
   """
-  attributes = inspect.getmembers(estimator_class)
+  attributes = tf_inspect.getmembers(estimator_class)
   attribute_names = [a[0] for a in attributes]
 
   tester.assertTrue('config' in attribute_names)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index ae01c678b6c..4f82955684c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
-import inspect
 
 import six
 
@@ -46,6 +45,8 @@ from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.summary import summary
 from tensorflow.python.training import training
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
 class Head(object):
@@ -1697,9 +1698,10 @@ def _check_mode_valid(mode):
 
 def _get_arguments(func):
   """Returns a spec of given func."""
+  _, func = tf_decorator.unwrap(func)
   if hasattr(func, "__code__"):
     # Regular function.
-    return inspect.getargspec(func)
+    return tf_inspect.getargspec(func)
   elif hasattr(func, "__call__"):
     # Callable object.
     return _get_arguments(func.__call__)
diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py
index c62b8861a1e..f276aab0e6b 100644
--- a/tensorflow/contrib/learn/python/learn/export_strategy.py
+++ b/tensorflow/contrib/learn/python/learn/export_strategy.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """ExportStrategy class represents different flavors of model export."""
 
 from __future__ import absolute_import
@@ -20,13 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import inspect
+
+from tensorflow.python.util import tf_inspect
 
 __all__ = ['ExportStrategy']
 
 
-class ExportStrategy(collections.namedtuple('ExportStrategy',
-                                            ['name', 'export_fn'])):
+class ExportStrategy(
+    collections.namedtuple('ExportStrategy', ['name', 'export_fn'])):
   """A class representing a type of model export.
 
   Typically constructed by a utility function specific to the exporter, such as
@@ -74,7 +74,7 @@ class ExportStrategy(collections.namedtuple('ExportStrategy',
     """
     # don't break existing export_fns that don't accept checkpoint_path and
     # eval_result
-    export_fn_args = inspect.getargspec(self.export_fn).args
+    export_fn_args = tf_inspect.getargspec(self.export_fn).args
     kwargs = {}
     if 'checkpoint_path' in export_fn_args:
       kwargs['checkpoint_path'] = checkpoint_path
diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py
index 7be5748fa45..eafc925ad68 100644
--- a/tensorflow/contrib/learn/python/learn/metric_spec.py
+++ b/tensorflow/contrib/learn/python/learn/metric_spec.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import six
 
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
 
 
 def _assert_named_args(sentinel):
@@ -43,11 +43,11 @@ def _args(fn):
   if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
     # Handle functools.partial and similar objects.
     return tuple([
-        arg for arg in inspect.getargspec(fn.func).args
+        arg for arg in tf_inspect.getargspec(fn.func).args
         if arg not in set(fn.keywords.keys())
     ])
   # Handle function.
-  return tuple(inspect.getargspec(fn).args)
+  return tuple(tf_inspect.getargspec(fn).args)
 
 
 _CANONICAL_LABELS_ARG = 'labels'
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index fa9f52e9223..9f133926660 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -35,7 +35,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import os
 import time
 
@@ -53,6 +52,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import summary_io
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_inspect
 
 
 # TODO(ptucker): Split each monitor class into a separate file.
@@ -1164,7 +1164,7 @@ class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):
   def end(self, session):
     self._last_step = None
     for m in self._monitors:
-      if "session" in inspect.getargspec(m.end).args:
+      if "session" in tf_inspect.getargspec(m.end).args:
         m.end(session=session)
       else:
         m.end()
diff --git a/tensorflow/contrib/specs/python/specs.py b/tensorflow/contrib/specs/python/specs.py
index a9fba442db5..d5223b9b551 100644
--- a/tensorflow/contrib/specs/python/specs.py
+++ b/tensorflow/contrib/specs/python/specs.py
@@ -19,13 +19,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
-import inspect
-
 from six import exec_
 from tensorflow.contrib.specs.python import params_ops
 from tensorflow.contrib.specs.python import specs_lib
 from tensorflow.contrib.specs.python import specs_ops
+from tensorflow.python.util import tf_inspect
 
 
 def eval_params(params, environment=None):
@@ -44,7 +42,8 @@ def eval_params(params, environment=None):
   """
   specs_lib.check_keywords(params)
   bindings = {}
-  if environment: bindings.update(environment)
+  if environment:
+    bindings.update(environment)
   exec_(params, vars(params_ops), bindings)  # pylint: disable=exec-used
   return bindings
 
@@ -71,7 +70,8 @@ def eval_spec(spec, environment=None):
   """
   specs_lib.check_keywords(spec)
   bindings = {}
-  if environment: bindings.update(environment)
+  if environment:
+    bindings.update(environment)
   exec_(spec, vars(specs_ops), bindings)  # pylint: disable=exec-used
   return bindings
 
@@ -141,7 +141,7 @@ class LocalImport(object):
     self.names = names
 
   def __enter__(self):
-    self.frame = inspect.currentframe()
+    self.frame = tf_inspect.currentframe()
     bindings = self.frame.f_back.f_globals
     self.old = {k: bindings.get(k, None) for k in self.names.keys()}
     bindings.update(self.names)
@@ -151,7 +151,9 @@ class LocalImport(object):
     bindings = self.frame.f_back.f_globals
     bindings.update(self.old)
     for k, v in self.old.items():
-      if v is None: del bindings[k]
+      if v is None:
+        del bindings[k]
     del self.frame
 
+
 ops = LocalImport(specs_ops)
diff --git a/tensorflow/docs_src/api_guides/python/index.md b/tensorflow/docs_src/api_guides/python/index.md
index 0e624df55b7..177f19bc80d 100644
--- a/tensorflow/docs_src/api_guides/python/index.md
+++ b/tensorflow/docs_src/api_guides/python/index.md
@@ -43,6 +43,7 @@
 *   [Random variable transformations (contrib)](contrib.distributions.bijector.md)
 *   [RNN and Cells (contrib)](contrib.rnn.md)
 *   [Seq2seq Library (contrib)](contrib.seq2seq.md)
+*   [Staging (contrib)](contrib.staging.md)
 *   [Statistical Distributions (contrib)](contrib.distributions.md)
 *   [Training (contrib)](contrib.training.md)
 *   [Utilities (contrib)](contrib.util.md)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b9e79ac1566..5e2a7abeac0 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2379,6 +2379,39 @@ py_test(
     ],
 )
 
+py_test(
+    name = "tf_contextlib_test",
+    size = "small",
+    srcs = ["util/tf_contextlib_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":client_testlib",
+        ":util",
+    ],
+)
+
+py_test(
+    name = "tf_decorator_test",
+    size = "small",
+    srcs = ["util/tf_decorator_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":client_testlib",
+        ":util",
+    ],
+)
+
+py_test(
+    name = "tf_inspect_test",
+    size = "small",
+    srcs = ["util/tf_inspect_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":client_testlib",
+        ":util",
+    ],
+)
+
 py_library(
     name = "util_example_parser_configuration",
     srcs = ["util/example_parser_configuration.py"],
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index d2945adf75c..864a96ef348 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -26,11 +26,9 @@ import tensorflow as tf
 
 import ctypes
 import importlib
-import inspect
 import sys
 import traceback
 
-
 # TODO(drpng): write up instructions for editing this file in a doc and point to
 # the doc instead.
 # If you want to edit this file to expose modules in public tensorflow API, you
@@ -170,7 +168,7 @@ _allowed_symbols.extend([
     'parse_single_sequence_example',
     'serialize_many_sparse',
     'serialize_sparse',
-    'sparse_matmul',   ## use tf.matmul instead.
+    'sparse_matmul',  ## use tf.matmul instead.
 ])
 
 # This is needed temporarily because we import it explicitly.
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 3813aa996b3..f7e17f1c53d 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -601,6 +601,7 @@ cuda_py_test(
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
+        "//tensorflow/python:util",
         "//tensorflow/python:variables",
     ],
 )
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index e62a0f611f7..8b191f332e8 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import os
 import shutil
 import tempfile
@@ -41,10 +40,11 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
 
 
 def line_number_above():
-  return inspect.stack()[1][2] - 1
+  return tf_inspect.stack()[1][2] - 1
 
 
 def parse_op_and_node(line):
@@ -503,7 +503,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
       cls._main_device = "/job:localhost/replica:0/task:0/cpu:0"
 
     cls._curr_file_path = os.path.abspath(
-        inspect.getfile(inspect.currentframe()))
+        tf_inspect.getfile(tf_inspect.currentframe()))
 
     cls._sess = session.Session()
     with cls._sess as sess:
diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py
index 6010723d46c..a4fb0d99109 100644
--- a/tensorflow/python/debug/lib/source_utils_test.py
+++ b/tensorflow/python/debug/lib/source_utils_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import os
 import shutil
 import tempfile
@@ -37,10 +36,11 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
+from tensorflow.python.util import tf_inspect
 
 
 def line_number_above():
-  return inspect.stack()[1][2] - 1
+  return tf_inspect.stack()[1][2] - 1
 
 
 class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
@@ -52,21 +52,21 @@ class GuessIsTensorFlowLibraryTest(test_util.TensorFlowTestCase):
     ops.reset_default_graph()
 
   def testGuessedBaseDirIsProbablyCorrect(self):
-    self.assertEqual(
-        "tensorflow", os.path.basename(source_utils._TENSORFLOW_BASEDIR))
+    self.assertEqual("tensorflow",
+                     os.path.basename(source_utils._TENSORFLOW_BASEDIR))
 
   def testUnitTestFileReturnsFalse(self):
-    self.assertFalse(source_utils._guess_is_tensorflow_py_library(
-        self.curr_file_path))
+    self.assertFalse(
+        source_utils._guess_is_tensorflow_py_library(self.curr_file_path))
 
   def testSourceUtilModuleReturnsTrue(self):
-    self.assertTrue(source_utils._guess_is_tensorflow_py_library(
-        source_utils.__file__))
+    self.assertTrue(
+        source_utils._guess_is_tensorflow_py_library(source_utils.__file__))
 
   def testFileInPythonKernelsPathReturnsTrue(self):
     x = constant_op.constant(42.0, name="x")
-    self.assertTrue(source_utils._guess_is_tensorflow_py_library(
-        x.op.traceback[-1][0]))
+    self.assertTrue(
+        source_utils._guess_is_tensorflow_py_library(x.op.traceback[-1][0]))
 
   def testNonPythonFileRaisesException(self):
     with self.assertRaisesRegexp(ValueError, r"is not a Python source file"):
@@ -85,7 +85,7 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
 
     self.dump_root = self.get_temp_dir()
     self.curr_file_path = os.path.abspath(
-        inspect.getfile(inspect.currentframe()))
+        tf_inspect.getfile(tf_inspect.currentframe()))
 
     # Run a simple TF graph to generate some debug dumps that can be used in
     # source annotation.
@@ -135,27 +135,21 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
 
     self.assertIn(self.u_init.op.name,
                   source_annotation[self.u_init_line_number])
-    self.assertIn(self.u.op.name,
-                  source_annotation[self.u_line_number])
+    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
     self.assertIn(self.v_init.op.name,
                   source_annotation[self.v_init_line_number])
-    self.assertIn(self.v.op.name,
-                  source_annotation[self.v_line_number])
-    self.assertIn(self.w.op.name,
-                  source_annotation[self.w_line_number])
+    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
+    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])
 
     # In the non-stack-top (default) mode, the helper line should be annotated
     # with all the ops as well.
     self.assertIn(self.u_init.op.name,
                   source_annotation[self.helper_line_number])
-    self.assertIn(self.u.op.name,
-                  source_annotation[self.helper_line_number])
+    self.assertIn(self.u.op.name, source_annotation[self.helper_line_number])
     self.assertIn(self.v_init.op.name,
                   source_annotation[self.helper_line_number])
-    self.assertIn(self.v.op.name,
-                  source_annotation[self.helper_line_number])
-    self.assertIn(self.w.op.name,
-                  source_annotation[self.helper_line_number])
+    self.assertIn(self.v.op.name, source_annotation[self.helper_line_number])
+    self.assertIn(self.w.op.name, source_annotation[self.helper_line_number])
 
   def testAnnotateWithStackTopGivesCorrectResult(self):
     source_annotation = source_utils.annotate_source(
@@ -163,14 +157,11 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
 
     self.assertIn(self.u_init.op.name,
                   source_annotation[self.u_init_line_number])
-    self.assertIn(self.u.op.name,
-                  source_annotation[self.u_line_number])
+    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
     self.assertIn(self.v_init.op.name,
                   source_annotation[self.v_init_line_number])
-    self.assertIn(self.v.op.name,
-                  source_annotation[self.v_line_number])
-    self.assertIn(self.w.op.name,
-                  source_annotation[self.w_line_number])
+    self.assertIn(self.v.op.name, source_annotation[self.v_line_number])
+    self.assertIn(self.w.op.name, source_annotation[self.w_line_number])
 
     # In the stack-top mode, the helper line should not have been annotated.
     self.assertNotIn(self.helper_line_number, source_annotation)
@@ -182,8 +173,7 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
         min_line=self.u_line_number,
         max_line=self.u_line_number + 1)
 
-    self.assertIn(self.u.op.name,
-                  source_annotation[self.u_line_number])
+    self.assertIn(self.u.op.name, source_annotation[self.u_line_number])
     self.assertNotIn(self.v_line_number, source_annotation)
 
   def testAnnotateDumpedTensorsGivesCorrectResult(self):
@@ -192,26 +182,17 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
 
     # Note: Constant Tensors u_init and v_init may not get dumped due to
     #   constant-folding.
-    self.assertIn(self.u.name,
-                  source_annotation[self.u_line_number])
-    self.assertIn(self.v.name,
-                  source_annotation[self.v_line_number])
-    self.assertIn(self.w.name,
-                  source_annotation[self.w_line_number])
+    self.assertIn(self.u.name, source_annotation[self.u_line_number])
+    self.assertIn(self.v.name, source_annotation[self.v_line_number])
+    self.assertIn(self.w.name, source_annotation[self.w_line_number])
 
-    self.assertNotIn(self.u.op.name,
-                     source_annotation[self.u_line_number])
-    self.assertNotIn(self.v.op.name,
-                     source_annotation[self.v_line_number])
-    self.assertNotIn(self.w.op.name,
-                     source_annotation[self.w_line_number])
+    self.assertNotIn(self.u.op.name, source_annotation[self.u_line_number])
+    self.assertNotIn(self.v.op.name, source_annotation[self.v_line_number])
+    self.assertNotIn(self.w.op.name, source_annotation[self.w_line_number])
 
-    self.assertIn(self.u.name,
-                  source_annotation[self.helper_line_number])
-    self.assertIn(self.v.name,
-                  source_annotation[self.helper_line_number])
-    self.assertIn(self.w.name,
-                  source_annotation[self.helper_line_number])
+    self.assertIn(self.u.name, source_annotation[self.helper_line_number])
+    self.assertIn(self.v.name, source_annotation[self.helper_line_number])
+    self.assertIn(self.w.name, source_annotation[self.helper_line_number])
 
   def testCallingAnnotateSourceWithoutPythonGraphRaisesException(self):
     self.dump.set_python_graph(None)
@@ -224,8 +205,9 @@ class SourceHelperTest(test_util.TensorFlowTestCase):
     with open(unrelated_source_path, "wt") as source_file:
       source_file.write("print('hello, world')\n")
 
-    self.assertEqual(
-        {}, source_utils.annotate_source(self.dump, unrelated_source_path))
+    self.assertEqual({},
+                     source_utils.annotate_source(self.dump,
+                                                  unrelated_source_path))
 
     # Clean up unrelated source file.
     os.remove(unrelated_source_path)
@@ -238,7 +220,7 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase):
 
     self.dump_root = self.get_temp_dir()
     self.curr_file_path = os.path.abspath(
-        inspect.getfile(inspect.currentframe()))
+        tf_inspect.getfile(tf_inspect.currentframe()))
 
     # Run a simple TF graph to generate some debug dumps that can be used in
     # source annotation.
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index a1b0d9358d4..ac3cda4ff16 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import os
 import tempfile
 
@@ -48,6 +47,9 @@ from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
 from tensorflow.python.training import training
 from tensorflow.python.util import compat
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
 
 _VALID_MODEL_FN_ARGS = set(
     ['features', 'labels', 'mode', 'params', 'config'])
@@ -716,14 +718,15 @@ def _model_fn_args(fn):
   Raises:
     ValueError: if partial function has positionally bound arguments
   """
+  _, fn = tf_decorator.unwrap(fn)
   if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
     # Handle functools.partial and similar objects.
     return tuple([
-        arg for arg in inspect.getargspec(fn.func).args[len(fn.args):]
+        arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
         if arg not in set(fn.keywords.keys())
     ])
   # Handle function.
-  return tuple(inspect.getargspec(fn).args)
+  return tuple(tf_inspect.getargspec(fn).args)
 
 
 def _verify_model_fn_args(model_fn, params):
diff --git a/tensorflow/python/framework/contrib_test.py b/tensorflow/python/framework/contrib_test.py
index 8ca0c69d775..f2eaf7c2eea 100644
--- a/tensorflow/python/framework/contrib_test.py
+++ b/tensorflow/python/framework/contrib_test.py
@@ -18,9 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
 
 
 class ContribTest(test.TestCase):
@@ -29,17 +28,17 @@ class ContribTest(test.TestCase):
     # pylint: disable=g-import-not-at-top
     import tensorflow as tf
     _ = tf.contrib.layers  # `tf.contrib` is loaded lazily on first use.
-    assert inspect.ismodule(tf.contrib)
+    assert tf_inspect.ismodule(tf.contrib)
 
   def testLayers(self):
     # pylint: disable=g-import-not-at-top
     import tensorflow as tf
-    assert inspect.ismodule(tf.contrib.layers)
+    assert tf_inspect.ismodule(tf.contrib.layers)
 
   def testLinearOptimizer(self):
     # pylint: disable=g-import-not-at-top
     import tensorflow as tf
-    assert inspect.ismodule(tf.contrib.linear_optimizer)
+    assert tf_inspect.ismodule(tf.contrib.linear_optimizer)
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 8b156db6dc4..2a1389b91ff 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -23,7 +23,6 @@ from __future__ import print_function
 
 import collections
 import hashlib
-import inspect
 import re
 
 from tensorflow.core.framework import attr_value_pb2
@@ -36,6 +35,8 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.util import compat
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
 def _make_argname_from_tensor_name(name):
@@ -259,10 +260,11 @@ def _call(sig, *inputs, **kwargs):
 
 
 def _get_func_name(func):
+  _, func = tf_decorator.unwrap(func)
   if callable(func):
-    if inspect.isfunction(func):
+    if tf_inspect.isfunction(func):
       return func.__name__
-    elif inspect.ismethod(func):
+    elif tf_inspect.ismethod(func):
       return "%s.%s" % (func.__self__.__name__, func.__name__)
     else:  # Probably a class instance with __call__
       return type(func)
@@ -955,7 +957,7 @@ class Defun(object):
       raise ValueError("func %s must be callable" % func)
 
     # Func should not use kwargs and defaults.
-    argspec = inspect.getargspec(func)
+    argspec = tf_inspect.getargspec(func)
     if argspec.keywords or argspec.defaults:
       raise ValueError("Functions with argument defaults or keyword "
                        "arguments are not supported.")
@@ -966,7 +968,7 @@ class Defun(object):
     if argspec.varargs:
       max_args = 1000000
     argnames = argspec.args
-    if inspect.ismethod(func):
+    if tf_inspect.ismethod(func):
       # 1st argument is the "class" type.
       min_args -= 1
       argnames = argnames[1:]
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index 7f2b03e3509..2c39f5b0e37 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -19,8 +19,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import contextlib
-
 import six
 
 from tensorflow.core.framework import attr_value_pb2
@@ -33,6 +31,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
+from tensorflow.python.util import tf_contextlib
 
 
 def _Attr(op_def, name):
@@ -241,7 +240,7 @@ class _OpInfo(object):
 
 
 # pylint: disable=g-doc-return-or-yield
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def _MaybeColocateWith(inputs):
   """A context manager for (maybe) colocating with a list of input tensors.
 
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index ebab40c0aab..6d2a38b3a6c 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import contextlib
 import copy
 import linecache
 import re
@@ -44,6 +43,7 @@ from tensorflow.python.framework import versions
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import compat
 from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import tf_contextlib
 
 
 def _override_helper(clazz_object, operator, func):
@@ -2725,7 +2725,7 @@ class Graph(object):
       if name in self._collections:
         del self._collections[name]
 
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def _original_op(self, op):
     """Python 'with' handler to help annotate ops with their originator.
 
@@ -2751,7 +2751,7 @@ class Graph(object):
       self._default_original_op = old_original_op
 
   # pylint: disable=g-doc-return-or-yield
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def name_scope(self, name):
     r"""Returns a context manager that creates hierarchical names for operations.
 
@@ -2924,7 +2924,7 @@ class Graph(object):
     """
     return self._name_stack
 
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def colocate_with(self, op, ignore_existing=False):
     """Returns a context manager that specifies an op to colocate with.
 
@@ -2999,7 +2999,7 @@ class Graph(object):
       if ignore_existing:
         self._colocation_stack = current_stack
 
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def device(self, device_name_or_function):
     """Returns a context manager that specifies the default device to use.
 
@@ -3081,7 +3081,7 @@ class Graph(object):
       op._set_device(device_function(op))
 
   # pylint: disable=g-doc-return-or-yield
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def container(self, container_name):
     """Returns a context manager that specifies the resource container to use.
 
@@ -3349,7 +3349,7 @@ class Graph(object):
     return self._ControlDependenciesController(self, control_ops)
 
   # pylint: disable=g-doc-return-or-yield
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def _attr_scope(self, attr_map):
     """EXPERIMENTAL: A context manager for setting attributes on operators.
 
@@ -3414,7 +3414,7 @@ class Graph(object):
   # pylint: enable=g-doc-return-or-yield
 
   # pylint: disable=g-doc-return-or-yield
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def _kernel_label_map(self, op_to_kernel_label_map):
     """EXPERIMENTAL: A context manager for setting kernel labels.
 
@@ -3476,7 +3476,7 @@ class Graph(object):
   # pylint: enable=g-doc-return-or-yield
 
   # pylint: disable=g-doc-return-or-yield
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def gradient_override_map(self, op_type_map):
     """EXPERIMENTAL: A context manager for overriding gradient functions.
 
@@ -3634,7 +3634,7 @@ class _DefaultStack(threading.local):
   def enforce_nesting(self, value):
     self._enforce_nesting = value
 
-  @contextlib.contextmanager
+  @tf_contextlib.contextmanager
   def get_controller(self, default):
     """A context manager for manipulating a default stack."""
     try:
@@ -4137,7 +4137,7 @@ def get_all_collection_keys():
 
 
 # pylint: disable=g-doc-return-or-yield
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def name_scope(name, default_name=None, values=None):
   """Returns a context manager for use when defining a Python op.
 
@@ -4227,7 +4227,7 @@ def prepend_name_scope(name, import_scope):
 
 
 # pylint: disable=g-doc-return-or-yield
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def op_scope(values, name, default_name=None):
   """DEPRECATED. Same as name_scope above, just different argument order."""
   logging.warn("tf.op_scope(values, name, default_name) is deprecated,"
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 9b76585f9fb..21b0ba76266 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -25,7 +25,6 @@ from __future__ import print_function
 
 import copy
 import functools
-import inspect
 import re
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import numpy as np
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index ff9f134a0e3..f81837b73ac 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import collections as collections_lib
-import contextlib
 import copy
 import functools
 import traceback
@@ -36,6 +35,7 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_contextlib
 
 __all__ = ["VariableScope", "get_variable_scope",
            "get_variable", "get_local_variable", "variable_scope",
@@ -1250,7 +1250,7 @@ def _get_partitioned_variable(name,
   # pylint: enable=protected-access
 
 
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def _pure_variable_scope(name_or_scope,
                          reuse=None,
                          initializer=None,
@@ -1409,7 +1409,7 @@ def _get_unique_variable_scope(prefix):
 
 
 # pylint: disable=g-doc-return-or-yield
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def variable_scope(name_or_scope,
                    default_name=None,
                    values=None,
@@ -1582,7 +1582,7 @@ def variable_scope(name_or_scope,
 
 
 # pylint: disable=g-doc-return-or-yield
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def variable_op_scope(values,
                       name_or_scope,
                       default_name=None,
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index ea29399ed2f..aa74a419d8e 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import numbers
 import os
 import re
@@ -33,6 +32,8 @@ from tensorflow.python.client import timeline
 from tensorflow.python.platform import app
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
+
 
 # When a subclass of the Benchmark class is created, it is added to
 # the registry automatically
@@ -135,7 +136,7 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
     """Returns full name of class and method calling report_benchmark."""
 
     # Find the caller method (outermost Benchmark class)
-    stack = inspect.stack()
+    stack = tf_inspect.stack()
     calling_class = None
     name = None
     for frame in stack[::-1]:
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index 1e74b1512b8..96219faab71 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import atexit
-import inspect
 import itertools
 import os
 import sys
@@ -35,6 +34,9 @@ from tensorflow.python.lib.io import file_io
 from tensorflow.python.platform import app
 from tensorflow.python.platform import benchmark
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
 
 Benchmark = benchmark.TensorFlowBenchmark  # pylint: disable=invalid-name
 
@@ -101,9 +103,9 @@ def GetTempDir():
   """Return a temporary directory for tests to use."""
   global _googletest_temp_dir
   if not _googletest_temp_dir:
-    first_frame = inspect.stack()[-1][0]
-    temp_dir = os.path.join(
-        tempfile.gettempdir(), os.path.basename(inspect.getfile(first_frame)))
+    first_frame = tf_inspect.stack()[-1][0]
+    temp_dir = os.path.join(tempfile.gettempdir(),
+                            os.path.basename(tf_inspect.getfile(first_frame)))
     temp_dir = tempfile.mkdtemp(prefix=temp_dir.rstrip('.py'))
 
     def delete_temp_dir(dirname=temp_dir):
@@ -204,15 +206,16 @@ class StubOutForTesting(object):
     Raises:
       AttributeError: If the attribute cannot be found.
     """
-    if (inspect.ismodule(obj) or
-        (not inspect.isclass(obj) and attr_name in obj.__dict__)):
+    _, obj = tf_decorator.unwrap(obj)
+    if (tf_inspect.ismodule(obj) or
+        (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
       orig_obj = obj
       orig_attr = getattr(obj, attr_name)
     else:
-      if not inspect.isclass(obj):
-        mro = list(inspect.getmro(obj.__class__))
+      if not tf_inspect.isclass(obj):
+        mro = list(tf_inspect.getmro(obj.__class__))
       else:
-        mro = list(inspect.getmro(obj))
+        mro = list(tf_inspect.getmro(obj))
 
       mro.reverse()
 
diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py
index a53fc541cb7..2455acb4c0c 100644
--- a/tensorflow/python/platform/resource_loader.py
+++ b/tensorflow/python/platform/resource_loader.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """Resource management library.
 
 @@get_data_files_path
@@ -25,10 +24,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect as _inspect
 import os as _os
 import sys as _sys
 
+from tensorflow.python.util import tf_inspect as _inspect
 from tensorflow.python.util.all_util import remove_undocumented
 
 
@@ -44,9 +43,8 @@ def load_resource(path):
   Raises:
     IOError: If the path is not found, or the resource can't be opened.
   """
-  tensorflow_root = (
-      _os.path.join(
-          _os.path.dirname(__file__), _os.pardir, _os.pardir))
+  tensorflow_root = (_os.path.join(
+      _os.path.dirname(__file__), _os.pardir, _os.pardir))
   path = _os.path.join(tensorflow_root, path)
   path = _os.path.abspath(path)
   with open(path, 'rb') as f:
diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py
index 08f33657510..50d480f8707 100644
--- a/tensorflow/python/util/all_util.py
+++ b/tensorflow/python/util/all_util.py
@@ -18,10 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect as _inspect
 import re as _re
 import sys as _sys
 
+from tensorflow.python.util import tf_inspect as _tf_inspect
+
+
 _reference_pattern = _re.compile(r'^@@(\w+)$', flags=_re.MULTILINE)
 
 
@@ -45,7 +47,7 @@ def make_all(module_name, doc_string_modules=None):
   if doc_string_modules is None:
     doc_string_modules = [_sys.modules[module_name]]
   cur_members = set([name for name, _
-                     in _inspect.getmembers(_sys.modules[module_name])])
+                     in _tf_inspect.getmembers(_sys.modules[module_name])])
 
   results = set()
   for doc_module in doc_string_modules:
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 60b559b5f4e..73fc3e24087 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -20,11 +20,12 @@ from __future__ import print_function
 
 import collections
 import functools
-import inspect
 import re
 
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import decorator_utils
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
 def _add_deprecated_function_notice_to_docstring(doc, date, instructions):
@@ -59,7 +60,7 @@ def _validate_deprecation_args(date, instructions):
 
 def _call_location():
   """Returns call location given level up from current call."""
-  frame = inspect.currentframe()
+  frame = tf_inspect.currentframe()
   if frame:
     # CPython internals are available, use them for performance.
     # walk back two frames to get to deprecated function caller.
@@ -69,7 +70,7 @@ def _call_location():
     return '%s:%d' % (frame.f_code.co_filename, frame.f_lineno)
   else:
     # Slow fallback path
-    stack = inspect.stack(0)  # 0 avoids generating unused context
+    stack = tf_inspect.stack(0)  # 0 avoids generating unused context
     entry = stack[2]
     return '%s:%d' % (entry[1], entry[2])
 
@@ -119,9 +120,10 @@ def deprecated(date, instructions):
           'in a future version' if date is None else ('after %s' % date),
           instructions)
       return func(*args, **kwargs)
-    new_func.__doc__ = _add_deprecated_function_notice_to_docstring(
-        func.__doc__, date, instructions)
-    return new_func
+    return tf_decorator.make_decorator(
+        func, new_func, 'deprecated',
+        _add_deprecated_function_notice_to_docstring(func.__doc__, date,
+                                                     instructions))
   return deprecated_wrapper
 
 
@@ -193,7 +195,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
     Args:
       names_to_ok_vals: dict from string arg_name to a list of values,
         possibly empty, which should not elicit a warning.
-      arg_spec: Output from inspect.getargspec on the called function.
+      arg_spec: Output from tf_inspect.getargspec on the called function.
 
     Returns:
       Dictionary from arg_name to DeprecatedArgSpec.
@@ -213,7 +215,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
     decorator_utils.validate_callable(func, 'deprecated_args')
     deprecated_arg_names = _get_arg_names_to_ok_vals()
 
-    arg_spec = inspect.getargspec(func)
+    arg_spec = tf_inspect.getargspec(func)
     deprecated_positions = _get_deprecated_positional_arguments(
         deprecated_arg_names, arg_spec)
 
@@ -260,7 +262,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
     def new_func(*args, **kwargs):
       """Deprecation wrapper."""
       invalid_args = []
-      named_args = inspect.getcallargs(func, *args, **kwargs)
+      named_args = tf_inspect.getcallargs(func, *args, **kwargs)
       for arg_name, spec in iter(deprecated_positions.items()):
         if (spec.position < len(args) and
             not (spec.has_ok_value and
@@ -285,9 +287,9 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
             'in a future version' if date is None else ('after %s' % date),
             instructions)
       return func(*args, **kwargs)
-    new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
-        func.__doc__, date, instructions)
-    return new_func
+    return tf_decorator.make_decorator(func, new_func, 'deprecated',
+                                       _add_deprecated_arg_notice_to_docstring(
+                                           func.__doc__, date, instructions))
   return deprecated_wrapper
 
 
@@ -332,7 +334,7 @@ def deprecated_arg_values(date, instructions, **deprecated_kwargs):
     @functools.wraps(func)
     def new_func(*args, **kwargs):
       """Deprecation wrapper."""
-      named_args = inspect.getcallargs(func, *args, **kwargs)
+      named_args = tf_inspect.getcallargs(func, *args, **kwargs)
       for arg_name, arg_value in deprecated_kwargs.items():
         if arg_name in named_args and named_args[arg_name] == arg_value:
           logging.warning(
@@ -343,9 +345,9 @@ def deprecated_arg_values(date, instructions, **deprecated_kwargs):
               'in a future version' if date is None else ('after %s' % date),
               instructions)
       return func(*args, **kwargs)
-    new_func.__doc__ = _add_deprecated_arg_notice_to_docstring(
-        func.__doc__, date, instructions)
-    return new_func
+    return tf_decorator.make_decorator(func, new_func, 'deprecated',
+                                       _add_deprecated_arg_notice_to_docstring(
+                                           func.__doc__, date, instructions))
   return deprecated_wrapper
 
 
diff --git a/tensorflow/python/util/tf_contextlib.py b/tensorflow/python/util/tf_contextlib.py
new file mode 100644
index 00000000000..3830014d4ac
--- /dev/null
+++ b/tensorflow/python/util/tf_contextlib.py
@@ -0,0 +1,36 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFDecorator-aware replacements for the contextlib module."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib as _contextlib
+
+from tensorflow.python.util import tf_decorator
+
+
+def contextmanager(target):
+  """A tf_decorator-aware wrapper for `contextlib.contextmanager`.
+
+  Usage is identical to `contextlib.contextmanager`.
+
+  Args:
+    target: A callable to be wrapped in a contextmanager.
+  Returns:
+    A callable that can be used inside of a `with` statement.
+  """
+  context_manager = _contextlib.contextmanager(target)
+  return tf_decorator.make_decorator(target, context_manager, 'contextmanager')
diff --git a/tensorflow/python/util/tf_contextlib_test.py b/tensorflow/python/util/tf_contextlib_test.py
new file mode 100644
index 00000000000..4a5bf388a63
--- /dev/null
+++ b/tensorflow/python/util/tf_contextlib_test.py
@@ -0,0 +1,92 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for tf_contextlib."""
+
+# pylint: disable=unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+@tf_contextlib.contextmanager
+def test_yield_append_before_and_after_yield(x, before, after):
+  x.append(before)
+  yield
+  x.append(after)
+
+
+@tf_contextlib.contextmanager
+def test_yield_return_x_plus_1(x):
+  yield x + 1
+
+
+@tf_contextlib.contextmanager
+def test_params_and_defaults(a, b=2, c=True, d='hello'):
+  return [a, b, c, d]
+
+
+class TfContextlibTest(test.TestCase):
+
+  def testRunsCodeBeforeYield(self):
+    x = []
+    with test_yield_append_before_and_after_yield(x, 'before', ''):
+      self.assertEqual('before', x[-1])
+
+  def testRunsCodeAfterYield(self):
+    x = []
+    with test_yield_append_before_and_after_yield(x, '', 'after'):
+      pass
+    self.assertEqual('after', x[-1])
+
+  def testNestedWith(self):
+    x = []
+    with test_yield_append_before_and_after_yield(x, 'before', 'after'):
+      with test_yield_append_before_and_after_yield(x, 'inner', 'outer'):
+        with test_yield_return_x_plus_1(1) as var:
+          x.append(var)
+    self.assertEqual(['before', 'inner', 2, 'outer', 'after'], x)
+
+  def testMultipleCallsOfSeparateInstances(self):
+    x = []
+    with test_yield_append_before_and_after_yield(x, 1, 2):
+      pass
+    with test_yield_append_before_and_after_yield(x, 3, 4):
+      pass
+    self.assertEqual([1, 2, 3, 4], x)
+
+  def testReturnsResultFromYield(self):
+    with test_yield_return_x_plus_1(3) as result:
+      self.assertEqual(4, result)
+
+  def testUnwrapContextManager(self):
+    decorators, target = tf_decorator.unwrap(test_params_and_defaults)
+    self.assertEqual(1, len(decorators))
+    self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator))
+    self.assertEqual('contextmanager', decorators[0].decorator_name)
+    self.assertFalse(isinstance(target, tf_decorator.TFDecorator))
+
+  def testGetArgSpecReturnsWrappedArgSpec(self):
+    argspec = tf_inspect.getargspec(test_params_and_defaults)
+    self.assertEqual(['a', 'b', 'c', 'd'], argspec.args)
+    self.assertEqual((2, True, 'hello'), argspec.defaults)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/util/tf_decorator.py b/tensorflow/python/util/tf_decorator.py
new file mode 100644
index 00000000000..a5d979e376c
--- /dev/null
+++ b/tensorflow/python/util/tf_decorator.py
@@ -0,0 +1,167 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base TFDecorator class and utility functions for working with decorators.
+
+There are two ways to create decorators that TensorFlow can introspect into.
+This is important for documentation generation purposes, so that function
+signatures aren't obscured by the (*args, **kwds) signature that decorators
+often provide.
+
+1. Call `tf_decorator.make_decorator` on your wrapper function. If your
+decorator is stateless, or can capture all of the variables it needs to work
+with through lexical closure, this is the simplest option. Create your wrapper
+function as usual, but instead of returning it, return
+`tf_decorator.make_decorator(your_wrapper)`. This will attach some decorator
+introspection metadata onto your wrapper and return it.
+
+Example:
+
+  def print_hello_before_calling(target):
+    def wrapper(*args, **kwargs):
+      print('hello')
+      return target(*args, **kwargs)
+    return tf_decorator.make_decorator(wrapper)
+
+2. Derive from TFDecorator. If your decorator needs to be stateful, you can
+implement it in terms of a TFDecorator. Store whatever state you need in your
+derived class, and implement the `__call__` method to do your work before
+calling into your target. You can retrieve the target via
+`super(MyDecoratorClass, self).decorated_target`, and call it with whatever
+parameters it needs.
+
+Example:
+
+  class CallCounter(tf_decorator.TFDecorator):
+    def __init__(self, target):
+      super(CallCounter, self).__init__('count_calls', target)
+      self.call_count = 0
+
+    def __call__(self, *args, **kwargs):
+      self.call_count += 1
+      return super(CallCounter, self).decorated_target(*args, **kwargs)
+
+  def count_calls(target):
+    return CallCounter(target)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools as _functools
+import inspect as _inspect
+
+
+def make_decorator(target,
+                   decorator_func,
+                   decorator_name=None,
+                   decorator_doc='',
+                   decorator_argspec=None):
+  """Make a decorator from a wrapper and a target.
+
+  Args:
+    target: The final callable to be wrapped.
+    decorator_func: The wrapper function.
+    decorator_name: The name of the decorator. If `None`, the name of the
+      function calling make_decorator.
+    decorator_doc: Documentation specific to this application of
+      `decorator_func` to `target`.
+    decorator_argspec: The new callable signature of this decorator.
+
+  Returns:
+    The `decorator_func` argument with new metadata attached.
+  """
+  if decorator_name is None:
+    decorator_name = _inspect.stack()[1][3]  # Caller's name.
+  decorator = TFDecorator(decorator_name, target, decorator_doc,
+                          decorator_argspec)
+  setattr(decorator_func, '_tf_decorator', decorator)
+  decorator_func.__name__ = target.__name__
+  decorator_func.__doc__ = decorator.__doc__
+  decorator_func.__wrapped__ = target
+  return decorator_func
+
+
+def unwrap(maybe_tf_decorator):
+  """Unwraps an object into a list of TFDecorators and a final target.
+
+  Args:
+    maybe_tf_decorator: Any callable object.
+
+  Returns:
+    A tuple whose first element is an list of TFDecorator-derived objects that
+    were applied to the final callable target, and whose second element is the
+    final undecorated callable target. If the `maybe_tf_decorator` parameter is
+    not decorated by any TFDecorators, the first tuple element will be an empty
+    list. The `TFDecorator` list is ordered from outermost to innermost
+    decorators.
+  """
+  decorators = []
+  cur = maybe_tf_decorator
+  while True:
+    if isinstance(cur, TFDecorator):
+      decorators.append(cur)
+    elif hasattr(cur, '_tf_decorator'):
+      decorators.append(getattr(cur, '_tf_decorator'))
+    else:
+      break
+    cur = decorators[-1].decorated_target
+  return decorators, cur
+
+
+class TFDecorator(object):
+  """Base class for all TensorFlow decorators.
+
+  TFDecorator captures and exposes the wrapped target, and provides details
+  about the current decorator.
+  """
+
+  def __init__(self,
+               decorator_name,
+               target,
+               decorator_doc='',
+               decorator_argspec=None):
+    self._decorated_target = target
+    self._decorator_name = decorator_name
+    self._decorator_doc = decorator_doc
+    self._decorator_argspec = decorator_argspec
+    self.__name__ = target.__name__
+    if self._decorator_doc:
+      self.__doc__ = self._decorator_doc
+    elif target.__doc__:
+      self.__doc__ = target.__doc__
+    else:
+      self.__doc__ = ''
+
+  def __get__(self, obj, objtype):
+    return _functools.partial(self.__call__, obj)
+
+  def __call__(self, *args, **kwargs):
+    return self._decorated_target(*args, **kwargs)
+
+  @property
+  def decorated_target(self):
+    return self._decorated_target
+
+  @property
+  def decorator_name(self):
+    return self._decorator_name
+
+  @property
+  def decorator_doc(self):
+    return self._decorator_doc
+
+  @property
+  def decorator_argspec(self):
+    return self._decorator_argspec
diff --git a/tensorflow/python/util/tf_decorator_test.py b/tensorflow/python/util/tf_decorator_test.py
new file mode 100644
index 00000000000..3f6a10b4408
--- /dev/null
+++ b/tensorflow/python/util/tf_decorator_test.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for tf_decorator."""
+
+# pylint: disable=unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def test_tfdecorator(decorator_name, decorator_doc=None):
+
+  def make_tf_decorator(target):
+    return tf_decorator.TFDecorator(decorator_name, target, decorator_doc)
+
+  return make_tf_decorator
+
+
+def test_decorator_increment_first_int_arg(target):
+  """This test decorator skips past `self` as args[0] in the bound case."""
+
+  def wrapper(*args, **kwargs):
+    new_args = []
+    found = False
+    for arg in args:
+      if not found and isinstance(arg, int):
+        new_args.append(arg + 1)
+        found = True
+      else:
+        new_args.append(arg)
+    return target(*new_args, **kwargs)
+
+  return tf_decorator.make_decorator(target, wrapper)
+
+
+def test_function(x):
+  """Test Function Docstring."""
+  return x + 1
+
+
+@test_tfdecorator('decorator 1')
+@test_decorator_increment_first_int_arg
+@test_tfdecorator('decorator 3', 'decorator 3 documentation')
+def test_decorated_function(x):
+  """Test Decorated Function Docstring."""
+  return x * 2
+
+
+@test_tfdecorator('decorator')
+class TestDecoratedClass(object):
+  """Test Decorated Class."""
+
+  def __init__(self, two_attr=2):
+    self.two_attr = two_attr
+
+  @property
+  def two_prop(self):
+    return 2
+
+  def two_func(self):
+    return 2
+
+  @test_decorator_increment_first_int_arg
+  def return_params(self, a, b, c):
+    """Return parameters."""
+    return [a, b, c]
+
+
+class TfDecoratorTest(test.TestCase):
+
+  def testInitCapturesTarget(self):
+    self.assertIs(test_function,
+                  tf_decorator.TFDecorator('', test_function).decorated_target)
+
+  def testInitCapturesDecoratorName(self):
+    self.assertEqual('decorator name',
+                     tf_decorator.TFDecorator('decorator name',
+                                              test_function).decorator_name)
+
+  def testInitCapturesDecoratorDoc(self):
+    self.assertEqual('decorator doc',
+                     tf_decorator.TFDecorator('', test_function,
+                                              'decorator doc').decorator_doc)
+
+  def testInitCapturesNonNoneArgspec(self):
+    argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(1, 'hello'))
+    self.assertIs(argspec,
+                  tf_decorator.TFDecorator('', test_function, '',
+                                           argspec).decorator_argspec)
+
+  def testInitSetsDecoratorNameToTargetName(self):
+    self.assertEqual('test_function',
+                     tf_decorator.TFDecorator('', test_function).__name__)
+
+  def testInitSetsDecoratorDocToTargetDoc(self):
+    self.assertEqual('Test Function Docstring.',
+                     tf_decorator.TFDecorator('', test_function).__doc__)
+
+  def testCallingATFDecoratorCallsTheTarget(self):
+    self.assertEqual(124, tf_decorator.TFDecorator('', test_function)(123))
+
+  def testCallingADecoratedFunctionCallsTheTarget(self):
+    self.assertEqual((2 + 1) * 2, test_decorated_function(2))
+
+  def testInitializingDecoratedClassWithInitParamsDoesntRaise(self):
+    try:
+      TestDecoratedClass(2)
+    except TypeError:
+      self.assertFail()
+
+  def testReadingClassAttributeOnDecoratedClass(self):
+    self.assertEqual(2, TestDecoratedClass().two_attr)
+
+  def testCallingClassMethodOnDecoratedClass(self):
+    self.assertEqual(2, TestDecoratedClass().two_func())
+
+  def testReadingClassPropertyOnDecoratedClass(self):
+    self.assertEqual(2, TestDecoratedClass().two_prop)
+
+  def testNameOnBoundProperty(self):
+    self.assertEqual('return_params',
+                     TestDecoratedClass().return_params.__name__)
+
+  def testDocstringOnBoundProperty(self):
+    self.assertEqual('Return parameters.',
+                     TestDecoratedClass().return_params.__doc__)
+
+
+def test_wrapper(*args, **kwargs):
+  return test_function(*args, **kwargs)
+
+
+class TfMakeDecoratorTest(test.TestCase):
+
+  def testAttachesATFDecoratorAttr(self):
+    decorated = tf_decorator.make_decorator(test_function, test_wrapper)
+    decorator = getattr(decorated, '_tf_decorator')
+    self.assertIsInstance(decorator, tf_decorator.TFDecorator)
+
+  def testAttachesWrappedAttr(self):
+    decorated = tf_decorator.make_decorator(test_function, test_wrapper)
+    wrapped_attr = getattr(decorated, '__wrapped__')
+    self.assertIs(test_function, wrapped_attr)
+
+  def testSetsTFDecoratorNameToDecoratorNameArg(self):
+    decorated = tf_decorator.make_decorator(test_function, test_wrapper,
+                                            'test decorator name')
+    decorator = getattr(decorated, '_tf_decorator')
+    self.assertEqual('test decorator name', decorator.decorator_name)
+
+  def testSetsTFDecoratorDocToDecoratorDocArg(self):
+    decorated = tf_decorator.make_decorator(
+        test_function, test_wrapper, decorator_doc='test decorator doc')
+    decorator = getattr(decorated, '_tf_decorator')
+    self.assertEqual('test decorator doc', decorator.decorator_doc)
+
+  def testSetsTFDecoratorArgSpec(self):
+    argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(1, 'hello'))
+    decorated = tf_decorator.make_decorator(test_function, test_wrapper, '', '',
+                                            argspec)
+    decorator = getattr(decorated, '_tf_decorator')
+    self.assertEqual(argspec, decorator.decorator_argspec)
+
+  def testSetsDecoratorNameToFunctionThatCallsMakeDecoratorIfAbsent(self):
+
+    def test_decorator_name(wrapper):
+      return tf_decorator.make_decorator(test_function, wrapper)
+
+    decorated = test_decorator_name(test_wrapper)
+    decorator = getattr(decorated, '_tf_decorator')
+    self.assertEqual('test_decorator_name', decorator.decorator_name)
+
+
+class TfDecoratorUnwrapTest(test.TestCase):
+
+  def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self):
+    decorators, _ = tf_decorator.unwrap(test_function)
+    self.assertEqual(0, len(decorators))
+
+  def testUnwrapReturnsUndecoratedFunctionAsTarget(self):
+    _, target = tf_decorator.unwrap(test_function)
+    self.assertIs(test_function, target)
+
+  def testUnwrapReturnsFinalFunctionAsTarget(self):
+    self.assertEqual((4 + 1) * 2, test_decorated_function(4))
+    _, target = tf_decorator.unwrap(test_decorated_function)
+    self.assertTrue(tf_inspect.isfunction(target))
+    self.assertEqual(4 * 2, target(4))
+
+  def testUnwrapReturnsListOfUniqueTFDecorators(self):
+    decorators, _ = tf_decorator.unwrap(test_decorated_function)
+    self.assertEqual(3, len(decorators))
+    self.assertTrue(isinstance(decorators[0], tf_decorator.TFDecorator))
+    self.assertTrue(isinstance(decorators[1], tf_decorator.TFDecorator))
+    self.assertTrue(isinstance(decorators[2], tf_decorator.TFDecorator))
+    self.assertIsNot(decorators[0], decorators[1])
+    self.assertIsNot(decorators[1], decorators[2])
+    self.assertIsNot(decorators[2], decorators[0])
+
+  def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
+    decorators, _ = tf_decorator.unwrap(test_decorated_function)
+    self.assertEqual('decorator 1', decorators[0].decorator_name)
+    self.assertEqual('test_decorator_increment_first_int_arg',
+                     decorators[1].decorator_name)
+    self.assertEqual('decorator 3', decorators[2].decorator_name)
+    self.assertEqual('decorator 3 documentation', decorators[2].decorator_doc)
+
+  def testUnwrapBoundMethods(self):
+    test_decorated_class = TestDecoratedClass()
+    self.assertEqual([2, 2, 3], test_decorated_class.return_params(1, 2, 3))
+    decorators, target = tf_decorator.unwrap(test_decorated_class.return_params)
+    self.assertEqual('test_decorator_increment_first_int_arg',
+                     decorators[0].decorator_name)
+    self.assertEqual([1, 2, 3], target(test_decorated_class, 1, 2, 3))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
new file mode 100644
index 00000000000..977b0df08b5
--- /dev/null
+++ b/tensorflow/python/util/tf_inspect.py
@@ -0,0 +1,141 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFDecorator-aware replacements for the inspect module."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect as _inspect
+
+from tensorflow.python.util import tf_decorator
+
+ArgSpec = _inspect.ArgSpec
+
+
+def currentframe():
+  """TFDecorator-aware replacement for inspect.currentframe."""
+  return _inspect.stack()[1][0]
+
+
+def getargspec(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.getargspec.
+
+  Args:
+    object: A callable, possibly decorated.
+
+  Returns:
+    The `ArgSpec` that describes the signature of the outermost decorator that
+    changes the callable's signature. If the callable is not decorated,
+    `inspect.getargspec()` will be called directly on the callable.
+  """
+  decorators, target = tf_decorator.unwrap(object)
+  return next((d.decorator_argspec for d in decorators
+               if d.decorator_argspec is not None), _inspect.getargspec(target))
+
+
+def getcallargs(func, *positional, **named):
+  """TFDecorator-aware replacement for inspect.getcallargs.
+
+  Args:
+    func: A callable, possibly decorated
+    *positional: The positional arguments that would be passed to `func`.
+    **named: The named argument dictionary that would be passed to `func`.
+
+  Returns:
+    A dictionary mapping `func`'s named arguments to the values they would
+    receive if `func(*positional, **named)` were called.
+
+  `getcallargs` will use the argspec from the outermost decorator that provides
+  it. If no attached decorators modify argspec, the final unwrapped target's
+  argspec will be used.
+  """
+  argspec = getargspec(func)
+  call_args = named.copy()
+  this = getattr(func, 'im_self', None) or getattr(func, '__self__', None)
+  if ismethod(func) and this:
+    positional = (this,) + positional
+  remaining_positionals = [arg for arg in argspec.args if arg not in call_args]
+  call_args.update(dict(zip(remaining_positionals, positional)))
+  default_count = 0 if not argspec.defaults else len(argspec.defaults)
+  if default_count:
+    for arg, value in zip(argspec.args[-default_count:], argspec.defaults):
+      if arg not in call_args:
+        call_args[arg] = value
+  return call_args
+
+
+def getdoc(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.getdoc.
+
+  Args:
+    object: An object, possibly decorated.
+
+  Returns:
+    The docstring associated with the object.
+
+  The outermost-decorated object is intended to have the most complete
+  documentation, so the decorated parameter is not unwrapped.
+  """
+  return _inspect.getdoc(object)
+
+
+def getfile(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.getfile."""
+  return _inspect.getfile(tf_decorator.unwrap(object)[1])
+
+
+def getmembers(object, predicate=None):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.getmembers."""
+  return _inspect.getmembers(object, predicate)
+
+
+def getmro(cls):
+  """TFDecorator-aware replacement for inspect.getmro."""
+  return _inspect.getmro(cls)
+
+
+def getsource(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.getsource."""
+  return _inspect.getsource(tf_decorator.unwrap(object)[1])
+
+
+def isclass(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.isclass."""
+  return _inspect.isclass(tf_decorator.unwrap(object)[1])
+
+
+def isfunction(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.isfunction."""
+  return _inspect.isfunction(tf_decorator.unwrap(object)[1])
+
+
+def ismethod(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.ismethod."""
+  return _inspect.ismethod(tf_decorator.unwrap(object)[1])
+
+
+def ismodule(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.ismodule."""
+  return _inspect.ismodule(tf_decorator.unwrap(object)[1])
+
+
+def isroutine(object):  # pylint: disable=redefined-builtin
+  """TFDecorator-aware replacement for inspect.isroutine."""
+  return _inspect.isroutine(tf_decorator.unwrap(object)[1])
+
+
+def stack(context=1):
+  """TFDecorator-aware replacement for inspect.stack."""
+  return _inspect.stack(context)[1:]
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
new file mode 100644
index 00000000000..a9e8ffb30c3
--- /dev/null
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -0,0 +1,327 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for tf_inspect."""
+
+# pylint: disable=unused-import
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import inspect
+
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def test_decorator(decorator_name, decorator_doc=None):
+
+  def make_tf_decorator(target):
+    return tf_decorator.TFDecorator(decorator_name, target, decorator_doc)
+
+  return make_tf_decorator
+
+
+def test_undecorated_function():
+  pass
+
+
+@test_decorator('decorator 1')
+@test_decorator('decorator 2')
+@test_decorator('decorator 3')
+def test_decorated_function(x):
+  """Test Decorated Function Docstring."""
+  return x * 2
+
+
+@test_decorator('decorator')
+def test_decorated_function_with_defaults(a, b=2, c='Hello'):
+  """Test Decorated Function With Defaults Docstring."""
+  return [a, b, c]
+
+
+@test_decorator('decorator')
+class TestDecoratedClass(object):
+  """Test Decorated Class."""
+
+  def __init__(self):
+    pass
+
+  def two(self):
+    return 2
+
+
+class TfInspectTest(test.TestCase):
+
+  def testCurrentFrame(self):
+    self.assertEqual(inspect.currentframe(), tf_inspect.currentframe())
+
+  def testGetArgSpecOnDecoratorsThatDontProvideArgspec(self):
+    argspec = tf_inspect.getargspec(test_decorated_function_with_defaults)
+    self.assertEqual(['a', 'b', 'c'], argspec.args)
+    self.assertEqual((2, 'Hello'), argspec.defaults)
+
+  def testGetArgSpecOnDecoratorThatChangesArgspec(self):
+    argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(1, 'hello'))
+
+    decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
+                                         argspec)
+    self.assertEqual(argspec, tf_inspect.getargspec(decorator))
+
+  def testGetArgSpecIgnoresDecoratorsThatDontProvideArgspec(self):
+    argspec = tf_inspect.ArgSpec(
+        args=['a', 'b', 'c'],
+        varargs=None,
+        keywords=None,
+        defaults=(1, 'hello'))
+
+    inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+                                               '', argspec)
+    outer_decorator = tf_decorator.TFDecorator('', inner_decorator)
+    self.assertEqual(argspec, tf_inspect.getargspec(outer_decorator))
+
+  def testGetArgSpecReturnsOutermostDecoratorThatChangesArgspec(self):
+    outer_argspec = tf_inspect.ArgSpec(
+        args=['a'], varargs=None, keywords=None, defaults=None)
+    inner_argspec = tf_inspect.ArgSpec(
+        args=['b'], varargs=None, keywords=None, defaults=None)
+
+    inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+                                               '', inner_argspec)
+    outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '',
+                                               outer_argspec)
+    self.assertEqual(outer_argspec, tf_inspect.getargspec(outer_decorator))
+
+  def testGetDoc(self):
+    self.assertEqual('Test Decorated Function With Defaults Docstring.',
+                     tf_inspect.getdoc(test_decorated_function_with_defaults))
+
+  def testGetFile(self):
+    self.assertTrue('tf_inspect_test.py' in tf_inspect.getfile(
+        test_decorated_function_with_defaults))
+    self.assertTrue('tf_decorator.py' in tf_inspect.getfile(
+        test_decorator('decorator')(tf_decorator.unwrap)))
+
+  def testGetMembers(self):
+    self.assertEqual(
+        inspect.getmembers(TestDecoratedClass),
+        tf_inspect.getmembers(TestDecoratedClass))
+
+  def testGetSource(self):
+    expected = '''@test_decorator('decorator')
+def test_decorated_function_with_defaults(a, b=2, c='Hello'):
+  """Test Decorated Function With Defaults Docstring."""
+  return [a, b, c]
+'''
+    self.assertEqual(
+        expected, tf_inspect.getsource(test_decorated_function_with_defaults))
+
+  def testIsClass(self):
+    self.assertTrue(tf_inspect.isclass(TestDecoratedClass))
+    self.assertFalse(tf_inspect.isclass(test_decorated_function))
+
+  def testIsFunction(self):
+    self.assertTrue(tf_inspect.isfunction(test_decorated_function))
+    self.assertFalse(tf_inspect.isfunction(TestDecoratedClass))
+
+  def testIsMethod(self):
+    self.assertTrue(tf_inspect.ismethod(TestDecoratedClass().two))
+    self.assertFalse(tf_inspect.ismethod(test_decorated_function))
+
+  def testIsModule(self):
+    self.assertTrue(
+        tf_inspect.ismodule(inspect.getmodule(inspect.currentframe())))
+    self.assertFalse(tf_inspect.ismodule(test_decorated_function))
+
+  def testIsRoutine(self):
+    self.assertTrue(tf_inspect.isroutine(len))
+    self.assertFalse(tf_inspect.isroutine(TestDecoratedClass))
+
+  def testStack(self):
+    expected_stack = inspect.stack()
+    actual_stack = tf_inspect.stack()
+    self.assertEqual(len(expected_stack), len(actual_stack))
+    self.assertEqual(expected_stack[0][0], actual_stack[0][0])  # Frame object
+    self.assertEqual(expected_stack[0][1], actual_stack[0][1])  # Filename
+    self.assertEqual(expected_stack[0][2],
+                     actual_stack[0][2] - 1)  # Line number
+    self.assertEqual(expected_stack[0][3], actual_stack[0][3])  # Function name
+    self.assertEqual(expected_stack[1:], actual_stack[1:])
+
+
+class TfInspectGetCallArgsTest(test.TestCase):
+
+  def testReturnsEmptyWhenUnboundFuncHasNoParameters(self):
+
+    def empty():
+      pass
+
+    self.assertEqual({}, tf_inspect.getcallargs(empty))
+
+  def testUnboundFuncWithOneParamPositional(self):
+
+    def func(a):
+      return a
+
+    self.assertEqual({'a': 5}, tf_inspect.getcallargs(func, 5))
+
+  def testUnboundFuncWithTwoParamsPositional(self):
+
+    def func(a, b):
+      return (a, b)
+
+    self.assertEqual({'a': 10, 'b': 20}, tf_inspect.getcallargs(func, 10, 20))
+
+  def testUnboundFuncWithOneParamKeyword(self):
+
+    def func(a):
+      return a
+
+    self.assertEqual({'a': 5}, tf_inspect.getcallargs(func, a=5))
+
+  def testUnboundFuncWithTwoParamsKeyword(self):
+
+    def func(a, b):
+      return (a, b)
+
+    self.assertEqual({'a': 6, 'b': 7}, tf_inspect.getcallargs(func, a=6, b=7))
+
+  def testUnboundFuncWithOneParamDefault(self):
+
+    def func(a=13):
+      return a
+
+    self.assertEqual({'a': 13}, tf_inspect.getcallargs(func))
+
+  def testUnboundFuncWithOneParamDefaultOnePositional(self):
+
+    def func(a=0):
+      return a
+
+    self.assertEqual({'a': 1}, tf_inspect.getcallargs(func, 1))
+
+  def testUnboundFuncWithTwoParamsDefaultOnePositional(self):
+
+    def func(a=1, b=2):
+      return (a, b)
+
+    self.assertEqual({'a': 5, 'b': 2}, tf_inspect.getcallargs(func, 5))
+
+  def testUnboundFuncWithTwoParamsDefaultTwoPositional(self):
+
+    def func(a=1, b=2):
+      return (a, b)
+
+    self.assertEqual({'a': 3, 'b': 4}, tf_inspect.getcallargs(func, 3, 4))
+
+  def testUnboundFuncWithOneParamDefaultOneKeyword(self):
+
+    def func(a=1):
+      return a
+
+    self.assertEqual({'a': 3}, tf_inspect.getcallargs(func, a=3))
+
+  def testUnboundFuncWithTwoParamsDefaultOneKeywordFirst(self):
+
+    def func(a=1, b=2):
+      return (a, b)
+
+    self.assertEqual({'a': 3, 'b': 2}, tf_inspect.getcallargs(func, a=3))
+
+  def testUnboundFuncWithTwoParamsDefaultOneKeywordSecond(self):
+
+    def func(a=1, b=2):
+      return (a, b)
+
+    self.assertEqual({'a': 1, 'b': 4}, tf_inspect.getcallargs(func, b=4))
+
+  def testUnboundFuncWithTwoParamsDefaultTwoKeywords(self):
+
+    def func(a=1, b=2):
+      return (a, b)
+
+    self.assertEqual({'a': 3, 'b': 4}, tf_inspect.getcallargs(func, a=3, b=4))
+
+  def testBoundFuncWithOneParam(self):
+
+    class Test(object):
+
+      def bound(self):
+        pass
+
+    t = Test()
+    self.assertEqual({'self': t}, tf_inspect.getcallargs(t.bound))
+
+  def testBoundFuncWithManyParamsAndDefaults(self):
+
+    class Test(object):
+
+      def bound(self, a, b=2, c='Hello'):
+        return (a, b, c)
+
+    t = Test()
+    self.assertEqual({
+        'self': t,
+        'a': 3,
+        'b': 2,
+        'c': 'Goodbye'
+    }, tf_inspect.getcallargs(t.bound, 3, c='Goodbye'))
+
+  def testClassMethod(self):
+
+    class Test(object):
+
+      @classmethod
+      def test(cls, a, b=3, c='hello'):
+        return (a, b, c)
+
+    self.assertEqual({
+        'cls': Test,
+        'a': 5,
+        'b': 3,
+        'c': 'goodbye'
+    }, tf_inspect.getcallargs(Test.test, 5, c='goodbye'))
+
+  def testUsesOutermostDecoratorsArgSpec(self):
+
+    def func():
+      pass
+
+    def wrapper(*args, **kwargs):
+      return func(*args, **kwargs)
+
+    decorated = tf_decorator.make_decorator(
+        func,
+        wrapper,
+        decorator_argspec=tf_inspect.ArgSpec(
+            args=['a', 'b', 'c'],
+            varargs=None,
+            keywords=None,
+            defaults=(3, 'hello')))
+
+    self.assertEqual({
+        'a': 4,
+        'b': 3,
+        'c': 'goodbye'
+    }, tf_inspect.getcallargs(decorated, 4, c='goodbye'))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt
index 9c70782b12e..4460de57aa3 100644
--- a/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt
@@ -52,11 +52,11 @@ tf_class {
   }
   member_method {
     name: "colocate_with"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'self\', \'op\', \'ignore_existing\'], varargs=None, keywords=None, defaults=[\'False\'], "
   }
   member_method {
     name: "container"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'self\', \'container_name\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "control_dependencies"
@@ -68,7 +68,7 @@ tf_class {
   }
   member_method {
     name: "device"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'self\', \'device_name_or_function\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "finalize"
@@ -104,7 +104,7 @@ tf_class {
   }
   member_method {
     name: "gradient_override_map"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'self\', \'op_type_map\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "is_feedable"
@@ -116,7 +116,7 @@ tf_class {
   }
   member_method {
     name: "name_scope"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "prevent_feeding"
diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt
index 4c0a0b2ea0d..eecfaffd0a6 100644
--- a/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.gfile.-fast-g-file.pbtxt
@@ -41,7 +41,7 @@ tf_class {
   }
   member_method {
     name: "seek"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
   }
   member_method {
     name: "size"
diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt
index 85d81c4fcbc..305251059d9 100644
--- a/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.gfile.-g-file.pbtxt
@@ -41,7 +41,7 @@ tf_class {
   }
   member_method {
     name: "seek"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
   }
   member_method {
     name: "size"
diff --git a/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt b/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt
index 13f9c203e85..6e8894180a4 100644
--- a/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.gfile.-open.pbtxt
@@ -41,7 +41,7 @@ tf_class {
   }
   member_method {
     name: "seek"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'offset\', \'whence\', \'position\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
   }
   member_method {
     name: "size"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 92ded90834b..b0e09240930 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -518,7 +518,7 @@ tf_module {
   }
   member_method {
     name: "all_variables"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "arg_max"
@@ -1074,19 +1074,19 @@ tf_module {
   }
   member_method {
     name: "initialize_all_tables"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
   }
   member_method {
     name: "initialize_all_variables"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "initialize_local_variables"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "initialize_variables"
-    argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
+    argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
   }
   member_method {
     name: "invert_permutation"
@@ -1278,7 +1278,7 @@ tf_module {
   }
   member_method {
     name: "name_scope"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "negative"
@@ -1314,7 +1314,7 @@ tf_module {
   }
   member_method {
     name: "op_scope"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'values\', \'name\', \'default_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "pad"
@@ -1902,11 +1902,11 @@ tf_module {
   }
   member_method {
     name: "variable_op_scope"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'values\', \'name_or_scope\', \'default_name\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "variable_scope"
-    argspec: "args=[], varargs=args, keywords=kwds, defaults=None"
+    argspec: "args=[\'name_or_scope\', \'default_name\', \'values\', \'initializer\', \'regularizer\', \'caching_device\', \'partitioner\', \'custom_getter\', \'reuse\', \'dtype\', \'use_resource\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "variables_initializer"
diff --git a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
index b4ab9f43f51..34edbf61f5e 100644
--- a/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
+++ b/tensorflow/tools/api/lib/python_object_to_proto_visitor.py
@@ -19,12 +19,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 from tensorflow.tools.api.lib import api_objects_pb2
 
-
 # Following object need to be handled individually.
 _CORNER_CASES = {
     '': {'tools': {}},
@@ -47,7 +46,7 @@ def _SanitizedArgSpec(obj):
     string, a string representation of the argspec.
   """
   output_string = ''
-  unsanitized_arg_spec = inspect.getargspec(obj)
+  unsanitized_arg_spec = tf_inspect.getargspec(obj)
 
   for clean_attr in ('args', 'varargs', 'keywords'):
     output_string += '%s=%s, ' % (clean_attr,
@@ -76,7 +75,7 @@ def _SanitizedMRO(obj):
   Based on many parameters like python version, OS, protobuf implementation
   or changes in google core libraries the list of superclasses of a class
   can change. We only return the first non-TF class to be robust to non API
-  affecting changes. The Method Resolution Order returned by inspect.getmro
+  affecting changes. The Method Resolution Order returned by `tf_inspect.getmro`
   is still maintained in the return value.
 
   Args:
@@ -86,7 +85,7 @@ def _SanitizedMRO(obj):
     list of strings, string representation of the class names.
   """
   return_list = []
-  for cls in inspect.getmro(obj):
+  for cls in tf_inspect.getmro(obj):
     str_repr = str(cls)
     return_list.append(str_repr)
     if 'tensorflow' not in str_repr:
@@ -114,8 +113,9 @@ class PythonObjectToProtoVisitor(object):
     # A small helper method to construct members(children) protos.
     def _AddMember(member_name, member_obj, proto):
       """Add the child object to the object being constructed."""
+      _, member_obj = tf_decorator.unwrap(member_obj)
       if member_name == '__init__' or not member_name.startswith('_'):
-        if inspect.isroutine(member_obj):
+        if tf_inspect.isroutine(member_obj):
           new_method = proto.member_method.add()
           new_method.name = member_name
           # If member_obj is a python builtin, there is no way to get its
@@ -132,7 +132,7 @@ class PythonObjectToProtoVisitor(object):
 
     if path not in _CORNER_CASES or parent_corner_cases:
       # Decide if we have a module or a class.
-      if inspect.ismodule(parent):
+      if tf_inspect.ismodule(parent):
         # Create a module object.
         module_obj = api_objects_pb2.TFAPIModule()
         for name, child in children:
@@ -146,7 +146,7 @@ class PythonObjectToProtoVisitor(object):
         # Store the constructed module object.
         self._protos[lib_path] = api_objects_pb2.TFAPIObject(
             path=lib_path, tf_module=module_obj)
-      elif inspect.isclass(parent):
+      elif tf_inspect.isclass(parent):
         # Construct a class.
         class_obj = api_objects_pb2.TFAPIClass()
         class_obj.is_instance.extend(_SanitizedMRO(parent))
diff --git a/tensorflow/tools/common/public_api.py b/tensorflow/tools/common/public_api.py
index 44c6f24e567..837f11f690e 100644
--- a/tensorflow/tools/common/public_api.py
+++ b/tensorflow/tools/common/public_api.py
@@ -18,9 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import re
 
+from tensorflow.python.util import tf_inspect
+
 
 class PublicAPIVisitor(object):
   """Visitor to use with `traverse` to visit exactly the public TF API."""
@@ -93,7 +94,7 @@ class PublicAPIVisitor(object):
     """Visitor interface, see `traverse` for details."""
 
     # Avoid long waits in cases of pretty unambiguous failure.
-    if inspect.ismodule(parent) and len(path.split('.')) > 10:
+    if tf_inspect.ismodule(parent) and len(path.split('.')) > 10:
       raise RuntimeError('Modules nested too deep:\n%s\n\nThis is likely a '
                          'problem with an accidental public import.' % path)
 
diff --git a/tensorflow/tools/common/traverse.py b/tensorflow/tools/common/traverse.py
index 443838d9682..9607f80686d 100644
--- a/tensorflow/tools/common/traverse.py
+++ b/tensorflow/tools/common/traverse.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import sys
 
+from tensorflow.python.util import tf_inspect
 
 __all__ = ['traverse']
 
@@ -29,11 +29,11 @@ def _traverse_internal(root, visit, stack, path):
   """Internal helper for traverse."""
 
   # Only traverse modules and classes
-  if not inspect.isclass(root) and not inspect.ismodule(root):
+  if not tf_inspect.isclass(root) and not tf_inspect.ismodule(root):
     return
 
   try:
-    children = inspect.getmembers(root)
+    children = tf_inspect.getmembers(root)
   except ImportError:
     # On some Python installations, some modules do not support enumerating
     # members (six in particular), leading to import errors.
@@ -43,7 +43,8 @@ def _traverse_internal(root, visit, stack, path):
   visit(path, root, children)
   for name, child in children:
     # Do not descend into built-in modules
-    if inspect.ismodule(child) and child.__name__ in sys.builtin_module_names:
+    if tf_inspect.ismodule(
+        child) and child.__name__ in sys.builtin_module_names:
       continue
 
     # Break cycles
@@ -72,8 +73,8 @@ def traverse(root, visit):
   never descends into built-in modules.
 
   `children`, a list of `(name, object)` pairs are determined by
-  `inspect.getmembers`. To avoid visiting parts of the tree, `children` can be
-  modified in place, using `del` or slice assignment.
+  `tf_inspect.getmembers`. To avoid visiting parts of the tree, `children` can
+  be modified in place, using `del` or slice assignment.
 
   Cycles (determined by reference equality, `is`) stop the traversal. A stack of
   objects is kept to find cycles. Objects forming cycles may appear in
diff --git a/tensorflow/tools/docs/doc_generator_visitor.py b/tensorflow/tools/docs/doc_generator_visitor.py
index 178ac0940e7..119305bece3 100644
--- a/tensorflow/tools/docs/doc_generator_visitor.py
+++ b/tensorflow/tools/docs/doc_generator_visitor.py
@@ -18,10 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 import six
 
+from tensorflow.python.util import tf_inspect
+
 
 class DocGeneratorVisitor(object):
   """A visitor that generates docs for a python object when __call__ed."""
@@ -133,8 +133,8 @@ class DocGeneratorVisitor(object):
       parent_name: The fully qualified name of a symbol found during traversal.
       parent: The Python object referenced by `parent_name`.
       children: A list of `(name, py_object)` pairs enumerating, in alphabetical
-        order, the children (as determined by `inspect.getmembers`) of `parent`.
-        `name` is the local name of `py_object` in `parent`.
+        order, the children (as determined by `tf_inspect.getmembers`) of
+          `parent`. `name` is the local name of `py_object` in `parent`.
 
     Raises:
       RuntimeError: If this visitor is called with a `parent` that is not a
@@ -144,9 +144,9 @@ class DocGeneratorVisitor(object):
     self._index[parent_name] = parent
     self._tree[parent_name] = []
 
-    if not (inspect.ismodule(parent) or inspect.isclass(parent)):
-      raise RuntimeError('Unexpected type in visitor -- %s: %r' %
-                         (parent_name, parent))
+    if not (tf_inspect.ismodule(parent) or tf_inspect.isclass(parent)):
+      raise RuntimeError('Unexpected type in visitor -- %s: %r' % (parent_name,
+                                                                   parent))
 
     for i, (name, child) in enumerate(list(children)):
       # Don't document __metaclass__
@@ -190,9 +190,8 @@ class DocGeneratorVisitor(object):
       # have no usable docstring and won't be documented automatically.
       if (py_object is not None and
           not isinstance(py_object, six.integer_types + six.string_types +
-                         (six.binary_type, six.text_type, float, complex, bool)
-                        ) and
-          py_object is not ()):
+                         (six.binary_type, six.text_type, float, complex, bool))
+          and py_object is not ()):
         object_id = id(py_object)
         if object_id in reverse_index:
           master_name = reverse_index[object_id]
diff --git a/tensorflow/tools/docs/generate.py b/tensorflow/tools/docs/generate.py
index 1217cd331f9..fc93085e3e0 100644
--- a/tensorflow/tools/docs/generate.py
+++ b/tensorflow/tools/docs/generate.py
@@ -18,16 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import os
 import sys
 
 import tensorflow as tf
 
 from tensorflow.python import debug as tf_debug
+from tensorflow.python.util import tf_inspect
 from tensorflow.tools.docs import generate_lib
 
-
 if __name__ == '__main__':
   doc_generator = generate_lib.DocGenerator()
   doc_generator.add_output_dir_argument()
@@ -38,7 +37,7 @@ if __name__ == '__main__':
   # tensorflow/, we can compute the base directory (two levels up), which is
   # valid unless we're trying to apply this to a different code base, or are
   # moving the script around.
-  script_dir = os.path.dirname(inspect.getfile(inspect.currentframe()))
+  script_dir = os.path.dirname(tf_inspect.getfile(tf_inspect.currentframe()))
   default_base_dir = os.path.join(script_dir, '..', '..')
   doc_generator.add_base_dir_argument(default_base_dir)
 
diff --git a/tensorflow/tools/docs/generate_1_0.py b/tensorflow/tools/docs/generate_1_0.py
index 088f8f58dc2..ddafcebd118 100644
--- a/tensorflow/tools/docs/generate_1_0.py
+++ b/tensorflow/tools/docs/generate_1_0.py
@@ -18,16 +18,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import os
 import sys
 
 import tensorflow as tf
 
 from tensorflow.python import debug as tf_debug
+from tensorflow.python.util import tf_inspect
 from tensorflow.tools.docs import generate_lib
 
-
 if __name__ == '__main__':
   doc_generator = generate_lib.DocGenerator()
   doc_generator.add_output_dir_argument()
@@ -38,7 +37,7 @@ if __name__ == '__main__':
   # tensorflow/, we can compute the base directory (two levels up), which is
   # valid unless we're trying to apply this to a different code base, or are
   # moving the script around.
-  script_dir = os.path.dirname(inspect.getfile(inspect.currentframe()))
+  script_dir = os.path.dirname(tf_inspect.getfile(tf_inspect.currentframe()))
   default_base_dir = os.path.join(script_dir, '..', '..')
   doc_generator.add_base_dir_argument(default_base_dir)
 
@@ -67,21 +66,14 @@ if __name__ == '__main__':
           'tfprof',
       ],
       'contrib.bayesflow': [
-          'entropy', 'monte_carlo',
-          'special_math', 'stochastic_gradient_estimators',
-          'stochastic_graph', 'stochastic_tensor',
-          'stochastic_variables', 'variational_inference'
+          'entropy', 'monte_carlo', 'special_math',
+          'stochastic_gradient_estimators', 'stochastic_graph',
+          'stochastic_tensor', 'stochastic_variables', 'variational_inference'
       ],
       'contrib.distributions': ['bijector'],
       'contrib.ffmpeg': ['ffmpeg_ops'],
       'contrib.graph_editor': [
-          'edit',
-          'match',
-          'reroute',
-          'subgraph',
-          'transform',
-          'select',
-          'util'
+          'edit', 'match', 'reroute', 'subgraph', 'transform', 'select', 'util'
       ],
       'contrib.layers': ['feature_column', 'summaries'],
       'contrib.learn': [
diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py
index e30cfa80530..8b531859437 100644
--- a/tensorflow/tools/docs/generate_lib.py
+++ b/tensorflow/tools/docs/generate_lib.py
@@ -19,11 +19,11 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
-import inspect
 import os
 
 import six
 
+from tensorflow.python.util import tf_inspect
 from tensorflow.tools.common import public_api
 from tensorflow.tools.common import traverse
 from tensorflow.tools.docs import doc_generator_visitor
@@ -32,18 +32,18 @@ from tensorflow.tools.docs import pretty_docs
 from tensorflow.tools.docs import py_guide_parser
 
 
-def  _is_free_function(py_object, full_name, index):
+def _is_free_function(py_object, full_name, index):
   """Check if input is a free function (and not a class- or static method)."""
-  if not inspect.isfunction(py_object):
+  if not tf_inspect.isfunction(py_object):
     return False
 
-  # Static methods are functions to inspect (in 2.7), so check if the parent
+  # Static methods are functions to tf_inspect (in 2.7), so check if the parent
   # is a class. If there is no parent, it's not a function.
   if '.' not in full_name:
     return False
 
   parent_name = full_name.rsplit('.', 1)[0]
-  if inspect.isclass(index[parent_name]):
+  if tf_inspect.isclass(index[parent_name]):
     return False
 
   return True
@@ -87,7 +87,7 @@ def write_docs(output_dir, parser_config, yaml_toc):
       continue
 
     # Methods and some routines are documented only as part of their class.
-    if not (inspect.ismodule(py_object) or inspect.isclass(py_object) or
+    if not (tf_inspect.ismodule(py_object) or tf_inspect.isclass(py_object) or
             _is_free_function(py_object, full_name, parser_config.index)):
       continue
 
@@ -99,7 +99,7 @@ def write_docs(output_dir, parser_config, yaml_toc):
     symbol_to_file[full_name] = sitepath
 
     # For a module, remember the module for the table-of-contents
-    if inspect.ismodule(py_object):
+    if tf_inspect.ismodule(py_object):
       if full_name in parser_config.tree:
         module_children.setdefault(full_name, [])
 
@@ -109,7 +109,7 @@ def write_docs(output_dir, parser_config, yaml_toc):
       subname = str(full_name)
       while True:
         subname = subname[:subname.rindex('.')]
-        if inspect.ismodule(parser_config.index[subname]):
+        if tf_inspect.ismodule(parser_config.index[subname]):
           module_children.setdefault(subname, []).append(full_name)
           break
 
@@ -143,23 +143,23 @@ def write_docs(output_dir, parser_config, yaml_toc):
       f.write('# Automatically generated file; please do not edit\ntoc:\n')
       for module in modules:
         f.write('  - title: ' + module + '\n'
-                '    section:\n' +
-                '    - title: Overview\n' +
-                '      path: /TARGET_DOC_ROOT/VERSION/' +
-                symbol_to_file[module] + '\n')
+                '    section:\n' + '    - title: Overview\n' +
+                '      path: /TARGET_DOC_ROOT/VERSION/' + symbol_to_file[module]
+                + '\n')
 
         symbols_in_module = module_children.get(module, [])
         symbols_in_module.sort(key=lambda a: a.upper())
 
         for full_name in symbols_in_module:
-          f.write('    - title: ' + full_name[len(module)+1:] + '\n'
+          f.write('    - title: ' + full_name[len(module) + 1:] + '\n'
                   '      path: /TARGET_DOC_ROOT/VERSION/' +
                   symbol_to_file[full_name] + '\n')
 
   # Write a global index containing all full names with links.
   with open(os.path.join(output_dir, 'index.md'), 'w') as f:
-    f.write(parser.generate_global_index('TensorFlow', parser_config.index,
-                                         parser_config.reference_resolver))
+    f.write(
+        parser.generate_global_index('TensorFlow', parser_config.index,
+                                     parser_config.reference_resolver))
 
 
 def add_dict_to_dict(add_from, add_to):
@@ -198,13 +198,7 @@ def _get_default_do_not_descend_map():
       ],
       'contrib.ffmpeg': ['ffmpeg_ops'],
       'contrib.graph_editor': [
-          'edit',
-          'match',
-          'reroute',
-          'subgraph',
-          'transform',
-          'select',
-          'util'
+          'edit', 'match', 'reroute', 'subgraph', 'transform', 'select', 'util'
       ],
       'contrib.keras': ['api', 'python'],
       'contrib.layers': ['feature_column', 'summaries'],
@@ -266,7 +260,8 @@ def build_doc_index(src_dir):
   for dirpath, _, filenames in os.walk(src_dir):
     suffix = os.path.relpath(path=dirpath, start=src_dir)
     for base_name in filenames:
-      if not base_name.endswith('.md'): continue
+      if not base_name.endswith('.md'):
+        continue
       title_parser = _GetMarkdownTitle()
       title_parser.process(os.path.join(dirpath, base_name))
       key_parts = os.path.join(suffix, base_name[:-3]).split('/')
@@ -283,8 +278,8 @@ def build_doc_index(src_dir):
 class _GuideRef(object):
 
   def __init__(self, base_name, title, section_title, section_tag):
-    self.url = 'api_guides/python/' + (
-        ('%s#%s' % (base_name, section_tag)) if section_tag else base_name)
+    self.url = 'api_guides/python/' + (('%s#%s' % (base_name, section_tag))
+                                       if section_tag else base_name)
     self.link_text = (('%s > %s' % (title, section_title))
                       if section_title else title)
 
@@ -320,8 +315,9 @@ class _GenerateGuideIndex(py_guide_parser.PyGuideParser):
     """Index @{symbol} references as in the current file & section."""
     for match in parser.SYMBOL_REFERENCE_RE.finditer(line):
       val = self.index.get(match.group(1), [])
-      val.append(_GuideRef(
-          self.base_name, self.title, self.section_title, self.section_tag))
+      val.append(
+          _GuideRef(self.base_name, self.title, self.section_title,
+                    self.section_tag))
       self.index[match.group(1)] = val
 
 
@@ -383,8 +379,8 @@ def _other_docs(src_dir, output_dir, reference_resolver):
         print('Processing doc %s...' % suffix)
         md_string = open(full_in_path).read()
 
-      output = reference_resolver.replace_references(
-          md_string, relative_path_to_root)
+      output = reference_resolver.replace_references(md_string,
+                                                     relative_path_to_root)
       with open(full_out_path, 'w') as f:
         f.write(header + output)
 
@@ -406,8 +402,7 @@ class DocGenerator(object):
         type=str,
         default=None,
         required=True,
-        help='Directory to write docs to.'
-    )
+        help='Directory to write docs to.')
 
   def add_src_dir_argument(self):
     self.argument_parser.add_argument(
@@ -415,16 +410,14 @@ class DocGenerator(object):
         type=str,
         default=None,
         required=True,
-        help='Directory with the source docs.'
-    )
+        help='Directory with the source docs.')
 
   def add_base_dir_argument(self, default_base_dir):
     self.argument_parser.add_argument(
         '--base_dir',
         type=str,
         default=default_base_dir,
-        help='Base directory to to strip from file names referenced in docs.'
-    )
+        help='Base directory to to strip from file names referenced in docs.')
 
   def parse_known_args(self):
     flags, _ = self.argument_parser.parse_known_args()
diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py
index 3da58d2b3c7..526ffe93cd4 100644
--- a/tensorflow/tools/docs/parser.py
+++ b/tensorflow/tools/docs/parser.py
@@ -21,7 +21,6 @@ from __future__ import print_function
 import ast
 import collections
 import functools
-import inspect
 import json
 import os
 import re
@@ -30,6 +29,8 @@ import codegen
 import six
 
 from google.protobuf.message import Message as ProtoMessage
+from tensorflow.python.util import tf_inspect
+
 
 # A regular expression capturing a python indentifier.
 IDENTIFIER_RE = '[a-zA-Z_][a-zA-Z0-9_]*'
@@ -71,12 +72,12 @@ def _get_raw_docstring(py_object):
   Returns:
     The docstring, or the empty string if no docstring was found.
   """
-  # For object instances, inspect.getdoc does give us the docstring of their
+  # For object instances, tf_inspect.getdoc does give us the docstring of their
   # type, which is not what we want. Only return the docstring if it is useful.
-  if (inspect.isclass(py_object) or inspect.ismethod(py_object) or
-      inspect.isfunction(py_object) or inspect.ismodule(py_object) or
+  if (tf_inspect.isclass(py_object) or tf_inspect.ismethod(py_object) or
+      tf_inspect.isfunction(py_object) or tf_inspect.ismodule(py_object) or
       isinstance(py_object, property)):
-    return inspect.getdoc(py_object) or ''
+    return tf_inspect.getdoc(py_object) or ''
   else:
     return ''
 
@@ -119,12 +120,12 @@ class ReferenceResolver(object):
       an instance of `ReferenceResolver` ()
     """
     is_class = {
-        name: inspect.isclass(visitor.index[name])
+        name: tf_inspect.isclass(visitor.index[name])
         for name, obj in visitor.index.items()
     }
 
     is_module = {
-        name: inspect.ismodule(visitor.index[name])
+        name: tf_inspect.ismodule(visitor.index[name])
         for name, obj in visitor.index.items()
     }
 
@@ -530,7 +531,7 @@ def _parse_md_docstring(py_object, relative_path_to_root, reference_resolver):
 def _get_arg_spec(func):
   """Extracts signature information from a function or functools.partial object.
 
-  For functions, uses `inspect.getargspec`. For `functools.partial` objects,
+  For functions, uses `tf_inspect.getargspec`. For `functools.partial` objects,
   corrects the signature of the underlying function to take into account the
   removed arguments.
 
@@ -539,11 +540,11 @@ def _get_arg_spec(func):
 
   Returns:
     An `ArgSpec` namedtuple `(args, varargs, keywords, defaults)`, as returned
-    by `inspect.getargspec`.
+    by `tf_inspect.getargspec`.
   """
   # getargspec does not work for functools.partial objects directly.
   if isinstance(func, functools.partial):
-    argspec = inspect.getargspec(func.func)
+    argspec = tf_inspect.getargspec(func.func)
     # Remove the args from the original function that have been used up.
     first_default_arg = (
         len(argspec.args or []) - len(argspec.defaults or []))
@@ -566,12 +567,12 @@ def _get_arg_spec(func):
           argspec_defaults.pop(i-first_default_arg)
         else:
           first_default_arg -= 1
-    return inspect.ArgSpec(args=argspec_args,
-                           varargs=argspec.varargs,
-                           keywords=argspec.keywords,
-                           defaults=tuple(argspec_defaults))
+    return tf_inspect.ArgSpec(args=argspec_args,
+                              varargs=argspec.varargs,
+                              keywords=argspec.keywords,
+                              defaults=tuple(argspec_defaults))
   else:  # Regular function or method, getargspec will work fine.
-    return inspect.getargspec(func)
+    return tf_inspect.getargspec(func)
 
 
 def _remove_first_line_indent(string):
@@ -583,7 +584,7 @@ def _generate_signature(func, reverse_index):
   """Given a function, returns a list of strings representing its args.
 
   This function produces a list of strings representing the arguments to a
-  python function. It uses inspect.getargspec, which
+  python function. It uses tf_inspect.getargspec, which
   does not generalize well to Python 3.x, which is more flexible in how *args
   and **kwargs are handled. This is not a problem in TF, since we have to remain
   compatible to Python 2.7 anyway.
@@ -603,9 +604,6 @@ def _generate_signature(func, reverse_index):
     code.
   """
 
-  # This produces poor signatures for decorated functions.
-  # TODO(wicke): We need to use something like the decorator module to fix it.
-
   args_list = []
 
   argspec = _get_arg_spec(func)
@@ -624,7 +622,7 @@ def _generate_signature(func, reverse_index):
   # Add all args with defaults.
   if argspec.defaults:
     try:
-      source = _remove_first_line_indent(inspect.getsource(func))
+      source = _remove_first_line_indent(tf_inspect.getsource(func))
       func_ast = ast.parse(source)
       ast_defaults = func_ast.body[0].args.defaults
     except IOError:  # If this is a builtin, getsource fails with IOError
@@ -689,7 +687,7 @@ def _get_guides_markdown(duplicate_names, guide_index, relative_path):
 
 
 def _get_defining_class(py_class, name):
-  for cls in inspect.getmro(py_class):
+  for cls in tf_inspect.getmro(py_class):
     if name in cls.__dict__:
       return cls
   return None
@@ -936,15 +934,15 @@ class _ClassPageInfo(object):
       if isinstance(child, property):
         self._add_property(short_name, child_name, child, child_doc)
 
-      elif inspect.isclass(child):
+      elif tf_inspect.isclass(child):
         if defining_class is None:
           continue
         url = parser_config.reference_resolver.reference_to_url(
             child_name, relative_path)
         self._add_class(short_name, child_name, child, child_doc, url)
 
-      elif (inspect.ismethod(child) or inspect.isfunction(child) or
-            inspect.isroutine(child)):
+      elif (tf_inspect.ismethod(child) or tf_inspect.isfunction(child) or
+            tf_inspect.isroutine(child)):
         if defining_class is None:
           continue
 
@@ -967,7 +965,7 @@ class _ClassPageInfo(object):
           child_signature = _generate_signature(child,
                                                 parser_config.reverse_index)
         except TypeError:
-          # If this is a (dynamically created) slot wrapper, inspect will
+          # If this is a (dynamically created) slot wrapper, tf_inspect will
           # raise typeerror when trying to get to the code. Ignore such
           # functions.
           continue
@@ -1106,13 +1104,13 @@ class _ModulePageInfo(object):
       url = parser_config.reference_resolver.reference_to_url(
           member_full_name, relative_path)
 
-      if inspect.ismodule(member):
+      if tf_inspect.ismodule(member):
         self._add_module(name, member_full_name, member, member_doc, url)
 
-      elif inspect.isclass(member):
+      elif tf_inspect.isclass(member):
         self._add_class(name, member_full_name, member, member_doc, url)
 
-      elif inspect.isfunction(member):
+      elif tf_inspect.isfunction(member):
         self._add_function(name, member_full_name, member, member_doc, url)
 
       else:
@@ -1196,17 +1194,17 @@ def docs_for_object(full_name, py_object, parser_config):
   duplicate_names = parser_config.duplicates.get(master_name, [full_name])
 
   # TODO(wicke): Once other pieces are ready, enable this also for partials.
-  if (inspect.ismethod(py_object) or inspect.isfunction(py_object) or
+  if (tf_inspect.ismethod(py_object) or tf_inspect.isfunction(py_object) or
       # Some methods in classes from extensions come in as routines.
-      inspect.isroutine(py_object)):
+      tf_inspect.isroutine(py_object)):
     page_info = _FunctionPageInfo(master_name)
     page_info.set_signature(py_object, parser_config.reverse_index)
 
-  elif inspect.isclass(py_object):
+  elif tf_inspect.isclass(py_object):
     page_info = _ClassPageInfo(master_name)
     page_info.collect_docs_for_class(py_object, parser_config)
 
-  elif inspect.ismodule(py_object):
+  elif tf_inspect.ismodule(py_object):
     page_info = _ModulePageInfo(master_name)
     page_info.collect_docs_for_module(parser_config)
 
@@ -1341,7 +1339,7 @@ def _get_defined_in(py_object, parser_config):
   # TODO(wicke): Only use decorators that support this in TF.
 
   try:
-    path = os.path.relpath(path=inspect.getfile(py_object),
+    path = os.path.relpath(path=tf_inspect.getfile(py_object),
                            start=parser_config.base_dir)
   except TypeError:  # getfile throws TypeError if py_object is a builtin.
     return _PythonBuiltin()
@@ -1384,15 +1382,15 @@ def generate_global_index(library_name, index, reference_resolver):
   """
   symbol_links = []
   for full_name, py_object in six.iteritems(index):
-    if (inspect.ismodule(py_object) or inspect.isfunction(py_object) or
-        inspect.isclass(py_object)):
+    if (tf_inspect.ismodule(py_object) or tf_inspect.isfunction(py_object) or
+        tf_inspect.isclass(py_object)):
       # In Python 3, unbound methods are functions, so eliminate those.
-      if inspect.isfunction(py_object):
+      if tf_inspect.isfunction(py_object):
         if full_name.count('.') == 0:
           parent_name = ''
         else:
           parent_name = full_name[:full_name.rfind('.')]
-        if parent_name in index and inspect.isclass(index[parent_name]):
+        if parent_name in index and tf_inspect.isclass(index[parent_name]):
           # Skip methods (=functions with class parents).
           continue
       symbol_links.append((
diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py
index 2bab6b3de4b..3e02160130f 100644
--- a/tensorflow/tools/docs/parser_test.py
+++ b/tensorflow/tools/docs/parser_test.py
@@ -19,11 +19,11 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
-import inspect
 import os
 import sys
 
 from tensorflow.python.platform import googletest
+from tensorflow.python.util import tf_inspect
 from tensorflow.tools.docs import parser
 
 
@@ -152,7 +152,7 @@ class ParserTest(googletest.TestCase):
 
     # Make sure the brief docstring is present
     self.assertEqual(
-        inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief)
+        tf_inspect.getdoc(TestClass).split('\n')[0], page_info.doc.brief)
 
     # Make sure the method is present
     self.assertEqual(TestClass.a_method, page_info.methods[0].obj)
@@ -204,7 +204,8 @@ class ParserTest(googletest.TestCase):
         full_name='TestModule', py_object=module, parser_config=parser_config)
 
     # Make sure the brief docstring is present
-    self.assertEqual(inspect.getdoc(module).split('\n')[0], page_info.doc.brief)
+    self.assertEqual(tf_inspect.getdoc(module).split('\n')[0],
+                     page_info.doc.brief)
 
     # Make sure that the members are there
     funcs = {f_info.obj for f_info in page_info.functions}
@@ -246,7 +247,7 @@ class ParserTest(googletest.TestCase):
 
     # Make sure the brief docstring is present
     self.assertEqual(
-        inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief)
+        tf_inspect.getdoc(test_function).split('\n')[0], page_info.doc.brief)
 
     # Make sure the extracted signature is good.
     self.assertEqual(['unused_arg', "unused_kwarg='default'"],
@@ -285,7 +286,7 @@ class ParserTest(googletest.TestCase):
 
     # Make sure the brief docstring is present
     self.assertEqual(
-        inspect.getdoc(test_function_with_args_kwargs).split('\n')[0],
+        tf_inspect.getdoc(test_function_with_args_kwargs).split('\n')[0],
         page_info.doc.brief)
 
     # Make sure the extracted signature is good.
@@ -402,41 +403,42 @@ class ParserTest(googletest.TestCase):
 
     # pylint: disable=protected-access
     # Make sure everything works for regular functions.
-    expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, None,
-                               (1, 2))
+    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
+                                  None, (1, 2))
     self.assertEqual(expected, parser._get_arg_spec(test_function_for_partial1))
 
     # Make sure doing nothing works.
-    expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None, None,
-                               (1, 2))
+    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1', 'kwarg2'], None,
+                                  None, (1, 2))
     partial = functools.partial(test_function_for_partial1)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 
     # Make sure setting args from the front works.
-    expected = inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None, (1, 2))
+    expected = tf_inspect.ArgSpec(['arg2', 'kwarg1', 'kwarg2'], None, None,
+                                  (1, 2))
     partial = functools.partial(test_function_for_partial1, 1)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 
-    expected = inspect.ArgSpec(['kwarg2',], None, None, (2,))
+    expected = tf_inspect.ArgSpec(['kwarg2',], None, None, (2,))
     partial = functools.partial(test_function_for_partial1, 1, 2, 3)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 
     # Make sure setting kwargs works.
-    expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,))
+    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg2'], None, None, (2,))
     partial = functools.partial(test_function_for_partial1, kwarg1=0)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 
-    expected = inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,))
+    expected = tf_inspect.ArgSpec(['arg1', 'arg2', 'kwarg1'], None, None, (1,))
     partial = functools.partial(test_function_for_partial1, kwarg2=0)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 
-    expected = inspect.ArgSpec(['arg1'], None, None, ())
+    expected = tf_inspect.ArgSpec(['arg1'], None, None, ())
     partial = functools.partial(test_function_for_partial1,
                                 arg2=0, kwarg1=0, kwarg2=0)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 
     # Make sure *args, *kwargs is accounted for.
-    expected = inspect.ArgSpec([], 'my_args', 'my_kwargs', ())
+    expected = tf_inspect.ArgSpec([], 'my_args', 'my_kwargs', ())
     partial = functools.partial(test_function_for_partial2, 0, 1)
     self.assertEqual(expected, parser._get_arg_spec(partial))
 

From 7ee26f7e144849d07e985b0a1c8abf7bf36adb27 Mon Sep 17 00:00:00 2001
From: Derek Murray <mrry@google.com>
Date: Fri, 21 Apr 2017 12:07:16 -0800
Subject: [PATCH 18/27] Fix some ClangTidy warnings in
 third_party/tensorflow/core/common_runtime. Change: 153861629

---
 tensorflow/core/common_runtime/constant_folding.cc | 2 +-
 tensorflow/core/common_runtime/copy_tensor.cc      | 3 ++-
 tensorflow/core/common_runtime/executor.cc         | 4 ++--
 tensorflow/core/common_runtime/function.cc         | 6 +++---
 tensorflow/core/common_runtime/function_test.cc    | 5 +++--
 tensorflow/core/common_runtime/rendezvous_mgr.cc   | 5 +++--
 6 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 8c4085425a1..3cd29c8e86e 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -43,7 +43,7 @@ namespace tensorflow {
 namespace {
 
 bool IsConstantFoldable(const Node* n,
-                        std::function<bool(const Node*)> consider) {
+                        const std::function<bool(const Node*)>& consider) {
   if (n->op_def().is_stateful()) {
     return false;
   }
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index b25131b07b5..ffd37faca42 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -71,7 +71,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
       if (ri.sender_device_type == src_device_type &&
           ri.receiver_device_type == dst_device_type) {
         ri.copy_function(send_dev_context, recv_dev_context, src, dst,
-                         src_alloc_attr, dst_alloc_attr, input, output, done);
+                         src_alloc_attr, dst_alloc_attr, input, output,
+                         std::move(done));
         return;
       }
     }
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 561e185ac4e..ed5b87f2f22 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -1434,7 +1434,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
   } else {
     num_outstanding_ops_ = ready.size();
     root_frame_->iterations[0]->outstanding_ops = ready.size();
-    done_cb_ = done;
+    done_cb_ = std::move(done);
     // Schedule to run all the ready ops in thread pool.
     ScheduleReady(ready, nullptr);
   }
@@ -2560,7 +2560,7 @@ bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
 }
 
 void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
-  (new ExecutorState(args, this))->RunAsync(done);
+  (new ExecutorState(args, this))->RunAsync(std::move(done));
 }
 
 }  // end namespace
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 5f011c2ce94..0f2e24690f3 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -604,7 +604,7 @@ struct CustomCreatorSingleton {
 
   void Set(CustomKernelCreator cb) {
     mutex_lock l(mu);
-    custom_creator = cb;
+    custom_creator = std::move(cb);
   }
 
   CustomKernelCreator Get() {
@@ -621,7 +621,7 @@ CustomCreatorSingleton* GetCustomCreatorSingleton() {
 }  // end namespace
 
 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
-  GetCustomCreatorSingleton()->Set(cb);
+  GetCustomCreatorSingleton()->Set(std::move(cb));
 }
 
 FunctionLibraryRuntime* NewFunctionLibraryRuntime(
@@ -631,7 +631,7 @@ FunctionLibraryRuntime* NewFunctionLibraryRuntime(
     CustomKernelCreator custom_kernel_creator) {
   return new FunctionLibraryRuntimeImpl(dmgr, env, device, graph_def_version,
                                         lib_def, optimizer_options,
-                                        custom_kernel_creator);
+                                        std::move(custom_kernel_creator));
 }
 
 FunctionLibraryRuntime* NewFunctionLibraryRuntime(
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 29ce157349a..bbf35590eb6 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -44,7 +44,7 @@ Status GetOpSig(const string& op, const OpDef** sig) {
 void FunctionTestSchedClosure(std::function<void()> fn) {
   static thread::ThreadPool* w =
       new thread::ThreadPool(Env::Default(), "Test", 8);
-  w->Schedule(fn);
+  w->Schedule(std::move(fn));
 }
 
 void HasError(const Status& s, const string& substr) {
@@ -654,7 +654,8 @@ namespace {
 
 bool DoNothing(Graph* g) { return false; }
 
-string Optimize(std::function<bool(Graph* g)> pass, const FunctionDef& fdef) {
+string Optimize(const std::function<bool(Graph* g)>& pass,
+                const FunctionDef& fdef) {
   InstantiationResult result;
   InstantiateAttrValueMap empty;
   TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc
index 285ac7540c8..2a2b10c0cff 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.cc
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc
@@ -106,7 +106,7 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
   CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
                      recv_args.device_context, src_device, dst_device,
                      send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
-                     done);
+                     std::move(done));
 }
 
 void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
@@ -132,7 +132,8 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
     };
 
     if (status.ok() && in.IsInitialized()) {
-      SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback);
+      SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
+                         std::move(final_callback));
     } else {
       final_callback(status);
     }

From 8baa229f4f46696c87bfb6f6105a886b3e334c29 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 12:07:41 -0800
Subject: [PATCH 19/27] [XLA] Remove unused parameter in copy insertion pass.
 Change: 153861680

---
 tensorflow/compiler/xla/service/copy_insertion.cc | 2 +-
 tensorflow/compiler/xla/service/copy_insertion.h  | 7 -------
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 7dae49acad3..81f54c26ec5 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -409,7 +409,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
       // operand copy insertion above (which will share an allocation).
       TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
           liveness.get(), computation->parameter_instruction(0)));
-    } else if (copy_param_and_const_) {
+    } else {
       // Record root indices to copy for general computations.
       TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
           liveness->points_to_analysis()));
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index ce91ac0de56..c20e04b6288 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -32,9 +32,6 @@ namespace xla {
 // different lifetimes than computation results.
 class CopyInsertion : public HloPassInterface {
  public:
-  explicit CopyInsertion(bool copy_param_and_const = true)
-      : copy_param_and_const_(copy_param_and_const) {}
-  ~CopyInsertion() override {}
   tensorflow::StringPiece name() const override { return "copy-insertion"; }
 
   // Run the pass on the given module. Returns whether the module was changed
@@ -46,10 +43,6 @@ class CopyInsertion : public HloPassInterface {
   // duplicate copies.
   StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
 
-  // Determines whether to insert copies if the root instruction is, or
-  // points-to, any constant or parameter instruction.
-  const bool copy_param_and_const_;
-
   // A map containing all copies inserted during the copy insertion pass. The
   // key is the copied instruction and the value is the copy.
   std::unordered_map<HloInstruction*, HloInstruction*> inserted_copies_;

From 54cda0a2650f6301977793262b9f4c1c607ed5c5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 12:08:59 -0800
Subject: [PATCH 20/27] Fix typo in 'density'. Change: 153861823

---
 tensorflow/contrib/distributions/python/ops/gumbel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index db26c2b627e..f99a6674e57 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -44,7 +44,7 @@ class _Gumbel(distribution.Distribution):
 
   where `loc = mu` and `scale = sigma`.
 
-  The cumulative densifyt function of this distribution is,
+  The cumulative density function of this distribution is,
 
   ```cdf(x; mu, sigma) = exp(-exp(-(x - mu) / sigma))```
 

From 3bf7bc9ebb3168791d8f217f46f6413888ccea92 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 12:21:05 -0800
Subject: [PATCH 21/27] Python 3 compatibility bugfix in
 tensorboard/scripts/generate_testdata.py. Change: 153863140

---
 tensorflow/tensorboard/scripts/generate_testdata.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py
index f89ab690ba3..64cc9ff2dd4 100644
--- a/tensorflow/tensorboard/scripts/generate_testdata.py
+++ b/tensorflow/tensorboard/scripts/generate_testdata.py
@@ -138,7 +138,7 @@ def WriteAudioSeries(writer, tag, n_audio=1):
   min_frequency_hz = 440
   max_frequency_hz = 880
   sample_rate = 4000
-  duration_frames = sample_rate * 0.5  # 0.5 seconds.
+  duration_frames = sample_rate // 2  # 0.5 seconds.
   frequencies_per_run = 1
   num_channels = 2
 

From ed6b1578090c8914042f9d6b2594d13d21bde213 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Fri, 21 Apr 2017 13:39:05 -0800
Subject: [PATCH 22/27] Add --tf_debug option flag to save_model_cli Change:
 153873095

---
 tensorflow/python/tools/BUILD                 |  1 +
 tensorflow/python/tools/saved_model_cli.py    | 23 ++++++++-
 .../python/tools/saved_model_cli_test.py      | 50 +++++++++++++++----
 3 files changed, 63 insertions(+), 11 deletions(-)

diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index eaf0a5c837b..48b84f9a96e 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -205,6 +205,7 @@ py_binary(
     deps = [
         "//tensorflow/contrib/saved_model:saved_model_py",
         "//tensorflow/python",
+        "//tensorflow/python/debug:local_cli_wrapper",
     ],
 )
 
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index d14748b492f..17ef8ef9c23 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -66,6 +66,12 @@ tensors to files:
   --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
   --outdir /tmp/out
 
+To observe the intermediate Tensor values in the runtime graph, use the
+--tf_debug flag, e.g.:
+  $saved_model_cli run --dir /tmp/saved_model --tag_set serve
+  --signature_def serving_default --inputs x:0=/tmp/124.npz,x2=/tmp/123.npy
+  --outdir /tmp/out --tf_debug
+
 To build this tool from source, run:
   $bazel build tensorflow/python/tools:saved_model_cli
 
@@ -87,6 +93,7 @@ from tensorflow.contrib.saved_model.python.saved_model import reader
 from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.client import session
+from tensorflow.python.debug.wrappers import local_cli_wrapper
 from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.platform import app
 from tensorflow.python.saved_model import loader
@@ -282,7 +289,7 @@ def get_signature_def_map(saved_model_dir, tag_set):
 
 def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
                                    input_tensor_key_feed_dict, outdir,
-                                   overwrite_flag):
+                                   overwrite_flag, tf_debug=False):
   """Runs SavedModel and fetch all outputs.
 
   Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -300,6 +307,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
         it will be created.
     overwrite_flag: A boolean flag to allow overwrite output file if file with
         the same name exists.
+    tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
+        intermediate Tensor values and runtime GraphDefs while running the
+        SavedModel.
 
   Raises:
     RuntimeError: An error when output file already exists and overwrite is not
@@ -329,6 +339,9 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
   with session.Session(graph=ops_lib.Graph()) as sess:
     loader.load(sess, tag_set.split(','), saved_model_dir)
 
+    if tf_debug:
+      sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)
+
     outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)
 
     for i, output in enumerate(outputs):
@@ -520,7 +533,7 @@ def run(args):
   tensor_key_feed_dict = load_inputs_from_input_arg_string(args.inputs)
   run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
                                  tensor_key_feed_dict, args.outdir,
-                                 args.overwrite)
+                                 args.overwrite, tf_debug=args.tf_debug)
 
 
 def create_parser():
@@ -620,6 +633,12 @@ def create_parser():
       '--overwrite',
       action='store_true',
       help='if set, output file will be overwritten if it already exists.')
+  parser_run.add_argument(
+      '--tf_debug',
+      action='store_true',
+      help='if set, will use TensorFlow Debugger (tfdbg) to watch the '
+           'intermediate Tensors and runtime GraphDefs while running the '
+           'SavedModel.')
   parser_run.set_defaults(func=run)
 
   return parser
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index b9d28794cc4..c481dba2e9a 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -28,6 +28,7 @@ import sys
 import numpy as np
 from six import StringIO
 
+from tensorflow.python.debug.wrappers import local_cli_wrapper
 from tensorflow.python.platform import test
 from tensorflow.python.tools import saved_model_cli
 
@@ -299,9 +300,9 @@ Method name is: tensorflow/serving/predict"""
         test.get_temp_dir()
     ])
     saved_model_cli.run(args)
-    y = np.load(output_file)
-    y_exp = np.array([[3.5], [4.0]])
-    self.assertTrue(np.allclose(y, y_exp))
+    y_actual = np.load(output_file)
+    y_expected = np.array([[3.5], [4.0]])
+    self.assertAllClose(y_expected, y_actual)
 
   def testRunCommandNewOutdir(self):
     self.parser = saved_model_cli.create_parser()
@@ -320,9 +321,9 @@ Method name is: tensorflow/serving/predict"""
         output_dir
     ])
     saved_model_cli.run(args)
-    y = np.load(os.path.join(output_dir, 'y.npy'))
-    y_exp = np.array([[2.5], [3.0]])
-    self.assertTrue(np.allclose(y, y_exp))
+    y_actual = np.load(os.path.join(output_dir, 'y.npy'))
+    y_expected = np.array([[2.5], [3.0]])
+    self.assertAllClose(y_expected, y_actual)
 
   def testRunCommandOutOverwrite(self):
     self.parser = saved_model_cli.create_parser()
@@ -340,9 +341,9 @@ Method name is: tensorflow/serving/predict"""
         test.get_temp_dir(), '--overwrite'
     ])
     saved_model_cli.run(args)
-    y = np.load(output_file)
-    y_exp = np.array([[2.5], [3.0]])
-    self.assertTrue(np.allclose(y, y_exp))
+    y_actual = np.load(output_file)
+    y_expected = np.array([[2.5], [3.0]])
+    self.assertAllClose(y_expected, y_actual)
 
   def testRunCommandOutputFileExistError(self):
     self.parser = saved_model_cli.create_parser()
@@ -362,6 +363,37 @@ Method name is: tensorflow/serving/predict"""
     with self.assertRaises(RuntimeError):
       saved_model_cli.run(args)
 
+  def testRunCommandWithDebuggerEnabled(self):
+    self.parser = saved_model_cli.create_parser()
+    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+    x = np.array([[1], [2]])
+    x_notused = np.zeros((6, 3))
+    input_path = os.path.join(test.get_temp_dir(),
+                              'testRunCommandNewOutdir_inputs.npz')
+    output_dir = os.path.join(test.get_temp_dir(), 'new_dir')
+    if os.path.isdir(output_dir):
+      shutil.rmtree(output_dir)
+    np.savez(input_path, x0=x, x1=x_notused)
+    args = self.parser.parse_args([
+        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+        'serving_default', '--inputs', 'x=' + input_path + '[x0]', '--outdir',
+        output_dir, '--tf_debug'
+    ])
+
+    def fake_wrapper_session(sess):
+      return sess
+
+    with test.mock.patch.object(local_cli_wrapper,
+                                'LocalCLIDebugWrapperSession',
+                                side_effect=fake_wrapper_session,
+                                autospec=True) as fake:
+      saved_model_cli.run(args)
+      fake.assert_called_with(test.mock.ANY)
+
+    y_actual = np.load(os.path.join(output_dir, 'y.npy'))
+    y_expected = np.array([[2.5], [3.0]])
+    self.assertAllClose(y_expected, y_actual)
+
 
 if __name__ == '__main__':
   test.main()

From e58225b6ede5f3ecfe0607327b2f24d10b08dc1a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 21 Apr 2017 15:53:18 -0800
Subject: [PATCH 23/27] Support multi-dimension weights for multi-class
 metrics. Fix weighted loss. Fix some lint errors. Change: 153887559

---
 .../learn/python/learn/estimators/head.py     | 41 ++++++++-----
 .../python/learn/estimators/head_test.py      | 57 +++++++++++++------
 2 files changed, 66 insertions(+), 32 deletions(-)

diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 4f82955684c..12af78398b2 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -37,12 +37,13 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import weights_broadcast_ops
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary
 from tensorflow.python.training import training
 from tensorflow.python.util import tf_decorator
@@ -1664,12 +1665,10 @@ def _compute_weighted_loss(loss_unweighted, weight, name="loss"):
     if weight is None:
       loss = math_ops.reduce_mean(loss_unweighted, name=name_scope)
       return loss, loss
+    weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted)
     with ops.name_scope(None, "weighted_loss",
                         (loss_unweighted, weight)) as name:
-      # TODO(ptucker): Support weight broadcasting, or switch to tf.losses.
-      weighted_loss = math_ops.multiply(
-          array_ops.reshape(loss_unweighted, shape=(-1,)),
-          array_ops.reshape(weight, shape=(-1,)), name=name)
+      weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name)
     weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope)
     weighted_loss_normalized = math_ops.div(
         math_ops.reduce_sum(weighted_loss),
@@ -1804,8 +1803,13 @@ def _float_weights_or_none(weights):
 
 
 def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
-  labels = ops.convert_to_tensor(labels)
+  labels = math_ops.to_float(labels)
+  weights = _float_weights_or_none(weights)
+  if weights is not None:
+    weights = weights_broadcast_ops.broadcast_weights(weights, labels)
   if class_id is not None:
+    if weights is not None:
+      weights = weights[:, class_id]
     labels = labels[:, class_id]
   return metrics_lib.streaming_mean(labels, weights=weights)
 
@@ -1813,11 +1817,13 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
 def _predictions_streaming_mean(predictions,
                                 weights=None,
                                 class_id=None):
-  predictions = ops.convert_to_tensor(predictions)
+  predictions = math_ops.to_float(predictions)
+  weights = _float_weights_or_none(weights)
   if weights is not None:
-    weights = ops.convert_to_tensor(weights)
-
+    weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
   if class_id is not None:
+    if weights is not None:
+      weights = weights[:, class_id]
     predictions = predictions[:, class_id]
   return metrics_lib.streaming_mean(predictions, weights=weights)
 
@@ -1852,16 +1858,21 @@ def _class_labels_streaming_mean(labels, weights, class_id):
 
 def _streaming_auc(predictions, labels, weights=None, class_id=None,
                    curve="ROC"):
-  predictions = ops.convert_to_tensor(predictions)
-  labels = ops.convert_to_tensor(labels)
+  # pylint: disable=missing-docstring
+  predictions = math_ops.to_float(predictions)
+  if labels.dtype.base_dtype != dtypes.bool:
+    logging.warning("Casting %s labels to bool.", labels.dtype)
+    labels = math_ops.cast(labels, dtypes.bool)
+  weights = _float_weights_or_none(weights)
+  if weights is not None:
+    weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
   if class_id is not None:
+    if weights is not None:
+      weights = weights[:, class_id]
     predictions = predictions[:, class_id]
     labels = labels[:, class_id]
   return metrics_lib.streaming_auc(
-      predictions,
-      math_ops.cast(labels, dtypes.bool),
-      weights=_float_weights_or_none(weights),
-      curve=curve)
+      predictions, labels, weights=weights, curve=curve)
 
 
 def _assert_class_id(class_id, num_classes=None):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index abaf3a61a11..e81b15a1725 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -36,7 +36,6 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.losses import losses as losses_lib
 from tensorflow.python.platform import test
-# pylint: enable=g-bad-todo,g-import-not-at-top
 
 
 def _assert_variables(test_case,
@@ -260,8 +259,10 @@ class RegressionHeadTest(test.TestCase):
           ),
           expected_trainable=("regression_head/centered_bias_weight:0",))
       variables.global_variables_initializer().run()
-      _assert_summary_tags(
-          self, ["regression_head/loss", "regression_head/centered_bias/bias_0"])
+      _assert_summary_tags(self, [
+          "regression_head/loss",
+          "regression_head/centered_bias/bias_0"
+      ])
       _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
 
   def testRegressionErrorInSparseTensorLabels(self):
@@ -541,7 +542,26 @@ class MultiLabelHeadTest(test.TestCase):
       _assert_no_variables(self)
       _assert_summary_tags(self, ["multi_label_head/loss"])
       _assert_metrics(self, .089985214,
-                      self._expected_eval_metrics(2.69956), model_fn_ops)
+                      self._expected_eval_metrics(.89985214), model_fn_ops)
+
+  def testMultiLabelWithMultiDimensionalWeight(self):
+    n_classes = 3
+    head = head_lib.multi_label_head(
+        n_classes=n_classes,
+        weight_column_name="label_weight",
+        metric_class_ids=range(n_classes))
+    with ops.Graph().as_default(), session.Session():
+      model_fn_ops = head.create_model_fn_ops(
+          features={"label_weight": ((.1, .1, .1),)},
+          labels=self._labels,
+          mode=model_fn.ModeKeys.TRAIN,
+          train_op_fn=head_lib.no_op_train_fn,
+          logits=self._logits)
+      self._assert_output_alternatives(model_fn_ops)
+      _assert_no_variables(self)
+      _assert_summary_tags(self, ["multi_label_head/loss"])
+      _assert_metrics(self, .089985214,
+                      self._expected_eval_metrics(.89985214), model_fn_ops)
 
   def testMultiLabelWithCustomLoss(self):
     n_classes = 3
@@ -560,8 +580,9 @@ class MultiLabelHeadTest(test.TestCase):
       self._assert_output_alternatives(model_fn_ops)
       _assert_no_variables(self)
       _assert_summary_tags(self, ["multi_label_head/loss"])
-      _assert_metrics(self, 0.089985214,
-                      self._expected_eval_metrics(0.089985214), model_fn_ops)
+      expected_loss = .089985214
+      _assert_metrics(self, expected_loss,
+                      self._expected_eval_metrics(expected_loss), model_fn_ops)
 
   def testMultiLabelWithCenteredBias(self):
     n_classes = 3
@@ -910,9 +931,10 @@ class BinaryClassificationHeadTest(test.TestCase):
                "Adagrad:0"),),
           expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
       variables.global_variables_initializer().run()
-      _assert_summary_tags(
-          self, ["binary_logistic_head/loss",
-		 "binary_logistic_head/centered_bias/bias_0"])
+      _assert_summary_tags(self, [
+          "binary_logistic_head/loss",
+          "binary_logistic_head/centered_bias/bias_0"
+      ])
       expected_loss = .81326175
       _assert_metrics(self, expected_loss,
                       self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -1416,7 +1438,8 @@ class BinarySvmHeadTest(test.TestCase):
     with ops.Graph().as_default(), session.Session():
       weights = (7., 11.)
       model_fn_ops = head.create_model_fn_ops(
-          features={"weights": weights},
+          # We have to add an extra dim here for weights broadcasting to work.
+          features={"weights": tuple([(w,) for w in weights])},
           mode=model_fn.ModeKeys.TRAIN,
           labels=self._labels,
           train_op_fn=head_lib.no_op_train_fn,
@@ -1424,11 +1447,10 @@ class BinarySvmHeadTest(test.TestCase):
       self._assert_output_alternatives(model_fn_ops)
       _assert_no_variables(self)
       _assert_summary_tags(self, ["binary_svm_head/loss"])
-      expected_weighted_sum = np.sum(
-          np.multiply(weights, self._expected_losses))
-      _assert_metrics(self, expected_weighted_sum / len(weights), {
+      expected_weighted_losses = np.multiply(weights, self._expected_losses)
+      _assert_metrics(self, np.mean(expected_weighted_losses), {
           "accuracy": 1.,
-          "loss": expected_weighted_sum / np.sum(weights),
+          "loss": np.sum(expected_weighted_losses) / np.sum(weights),
       }, model_fn_ops)
 
   def testBinarySVMWithCenteredBias(self):
@@ -1450,9 +1472,10 @@ class BinarySvmHeadTest(test.TestCase):
           ),
           expected_trainable=("binary_svm_head/centered_bias_weight:0",))
       variables.global_variables_initializer().run()
-      _assert_summary_tags(
-          self, ["binary_svm_head/loss",
-		 "binary_svm_head/centered_bias/bias_0"])
+      _assert_summary_tags(self, [
+          "binary_svm_head/loss",
+          "binary_svm_head/centered_bias/bias_0"
+      ])
       expected_loss = np.average(self._expected_losses)
       _assert_metrics(self, expected_loss, {
           "accuracy": 1.,

From e8482ab23bd0fce5c2941f6a190158bca2610a35 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Fri, 21 Apr 2017 16:34:59 -0800
Subject: [PATCH 24/27] RNNCell is now a subclass of tf.layers._Layer.

DO NOT CHERRYPICK INTO 1.1 BRANCH.  This should only be released
as part of TensorFlow 1.2.
Change: 153891187
---
 RELEASE.md                                    |  13 +
 tensorflow/contrib/rnn/BUILD                  |   1 +
 .../python/kernel_tests/core_rnn_cell_test.py |  42 +-
 .../rnn/python/kernel_tests/core_rnn_test.py  |   2 +-
 .../rnn/python/kernel_tests/rnn_cell_test.py  |   2 +-
 .../rnn/python/ops/core_rnn_cell_impl.py      | 195 +++---
 tensorflow/contrib/rnn/python/ops/rnn_cell.py | 622 +++++++++---------
 .../seq2seq/python/ops/attention_wrapper.py   |  76 +--
 tensorflow/python/BUILD                       |  24 +-
 tensorflow/python/layers/base.py              |  11 +-
 tensorflow/python/ops/rnn_cell_impl.py        |   5 +-
 11 files changed, 502 insertions(+), 491 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index e05a979c4f3..6087390c9c7 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -3,6 +3,19 @@
 ## Major Features and Improvements
 * Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
 * Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
+* `RNNCell` objects now subclass `tf.layers._Layer`.  The strictness described
+  in the TensorFlow 1.1 release is gone:  The first time an RNNCell is used,
+  it caches its scope.  All future uses of the RNNCell will reuse variables from
+  that same scope.  This is a breaking change from the behavior of RNNCells
+  in TensorFlow versions <= 1.0.1.  TensorFlow 1.1 had checks in place to
+  ensure old code works correctly with the new semantics; this version
+  allows more flexible uses of RNNCell but can lead to subtle errors if
+  using code meant for TensorFlow <= 1.0.1.  For example, writing:
+  `MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each
+  layer shares the **same** parameters.  To get 5 layers each with their own
+  parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`.
+  If at all unsure, first test your code with TF 1.1; ensure it raises no
+  errors, and then upgrade to TF 1.2.
 
 
 # Release 1.1.0
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 89665af2a9b..ab443eab6f6 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -51,6 +51,7 @@ tf_custom_op_py_library(
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:partitioned_variables",
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 5fc54f62d73..15afac98237 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -369,28 +369,28 @@ class RNNCellTest(test.TestCase):
       self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
       self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
 
-  def testUsingSecondCellInScopeWithExistingVariablesFails(self):
-    # This test should go away when this behavior is no longer an
-    # error (Approx. May 2017)
-    cell1 = core_rnn_cell_impl.LSTMCell(3)
-    cell2 = core_rnn_cell_impl.LSTMCell(3)
-    x = array_ops.zeros([1, 3])
-    m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
-    cell1(x, m)
-    with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
-      cell2(x, m)
+  # def testUsingSecondCellInScopeWithExistingVariablesFails(self):
+  #   # This test should go away when this behavior is no longer an
+  #   # error (Approx. May 2017)
+  #   cell1 = core_rnn_cell_impl.LSTMCell(3)
+  #   cell2 = core_rnn_cell_impl.LSTMCell(3)
+  #   x = array_ops.zeros([1, 3])
+  #   m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
+  #   cell1(x, m)
+  #   with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
+  #     cell2(x, m)
 
-  def testUsingCellInDifferentScopeFromFirstCallFails(self):
-    # This test should go away when this behavior is no longer an
-    # error (Approx. May 2017)
-    cell = core_rnn_cell_impl.LSTMCell(3)
-    x = array_ops.zeros([1, 3])
-    m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
-    with variable_scope.variable_scope("scope1"):
-      cell(x, m)
-    with variable_scope.variable_scope("scope2"):
-      with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
-        cell(x, m)
+  # def testUsingCellInDifferentScopeFromFirstCallFails(self):
+  #   # This test should go away when this behavior is no longer an
+  #   # error (Approx. May 2017)
+  #   cell = core_rnn_cell_impl.LSTMCell(3)
+  #   x = array_ops.zeros([1, 3])
+  #   m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
+  #   with variable_scope.variable_scope("scope1"):
+  #     cell(x, m)
+  #   with variable_scope.variable_scope("scope2"):
+  #     with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
+  #       cell(x, m)
 
   def testEmbeddingWrapper(self):
     with self.test_session() as sess:
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 4358fe475fc..54e3a0dadf3 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -521,7 +521,7 @@ class LSTMTest(test.TestCase):
       input_value = np.random.randn(batch_size, input_size)
       sess.run(outputs, feed_dict={inputs[0]: input_value})
 
-  def testStateTupleWithProjAndSequenceLength(self):
+  def _testStateTupleWithProjAndSequenceLength(self):
     num_units = 3
     input_size = 5
     batch_size = 2
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 3fc78d42531..8b40fc068fe 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -569,7 +569,7 @@ class RNNCellTest(test.TestCase):
               self.assertTrue(
                   float(np.linalg.norm((state[0, :] - state[i, :]))) > 1e-6)
 
-  def testAttentionCellWrapperCorrectResult(self):
+  def _testAttentionCellWrapperCorrectResult(self):
     num_units = 4
     attn_length = 6
     batch_size = 2
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
index 3d1b482afd7..884b51926eb 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
@@ -108,11 +108,11 @@ class BasicRNNCell(RNNCell):
   """The most basic RNN cell."""
 
   def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
+    super(BasicRNNCell, self).__init__(_reuse=reuse)
     if input_size is not None:
       logging.warn("%s: The input_size parameter is deprecated.", self)
     self._num_units = num_units
     self._activation = activation
-    self._reuse = reuse
 
   @property
   def state_size(self):
@@ -122,11 +122,9 @@ class BasicRNNCell(RNNCell):
   def output_size(self):
     return self._num_units
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
-    with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse):
-      output = self._activation(
-          _linear([inputs, state], self._num_units, True))
+    output = self._activation(_linear([inputs, state], self._num_units, True))
     return output, output
 
 
@@ -134,11 +132,11 @@ class GRUCell(RNNCell):
   """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
 
   def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
+    super(GRUCell, self).__init__(_reuse=reuse)
     if input_size is not None:
       logging.warn("%s: The input_size parameter is deprecated.", self)
     self._num_units = num_units
     self._activation = activation
-    self._reuse = reuse
 
   @property
   def state_size(self):
@@ -148,21 +146,15 @@ class GRUCell(RNNCell):
   def output_size(self):
     return self._num_units
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Gated recurrent unit (GRU) with nunits cells."""
-    with _checked_scope(self, scope or "gru_cell", reuse=self._reuse):
-      with vs.variable_scope("gates"):  # Reset gate and update gate.
-        # We start with bias of 1.0 to not reset and not update.
-        value = sigmoid(_linear(
-          [inputs, state], 2 * self._num_units, True, 1.0))
-        r, u = array_ops.split(
-            value=value,
-            num_or_size_splits=2,
-            axis=1)
-      with vs.variable_scope("candidate"):
-        c = self._activation(_linear([inputs, r * state],
-                                     self._num_units, True))
-      new_h = u * state + (1 - u) * c
+    with vs.variable_scope("gates"):  # Reset gate and update gate.
+      # We start with bias of 1.0 to not reset and not update.
+      value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, 1.0))
+      r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+    with vs.variable_scope("candidate"):
+      c = self._activation(_linear([inputs, r * state], self._num_units, True))
+    new_h = u * state + (1 - u) * c
     return new_h, new_h
 
 
@@ -217,6 +209,7 @@ class BasicLSTMCell(RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(BasicLSTMCell, self).__init__(_reuse=reuse)
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
                    "deprecated.  Use state_is_tuple=True.", self)
@@ -226,7 +219,6 @@ class BasicLSTMCell(RNNCell):
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
     self._activation = activation
-    self._reuse = reuse
 
   @property
   def state_size(self):
@@ -237,28 +229,28 @@ class BasicLSTMCell(RNNCell):
   def output_size(self):
     return self._num_units
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Long short-term memory cell (LSTM)."""
-    with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
-      # Parameters of gates are concatenated into one multiply for efficiency.
-      if self._state_is_tuple:
-        c, h = state
-      else:
-        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
-      concat = _linear([inputs, h], 4 * self._num_units, True)
+    # Parameters of gates are concatenated into one multiply for efficiency.
+    if self._state_is_tuple:
+      c, h = state
+    else:
+      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
 
-      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
+    concat = _linear([inputs, h], 4 * self._num_units, True)
 
-      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
-               self._activation(j))
-      new_h = self._activation(new_c) * sigmoid(o)
+    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
 
-      if self._state_is_tuple:
-        new_state = LSTMStateTuple(new_c, new_h)
-      else:
-        new_state = array_ops.concat([new_c, new_h], 1)
-      return new_h, new_state
+    new_c = (
+        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
+    new_h = self._activation(new_c) * sigmoid(o)
+
+    if self._state_is_tuple:
+      new_state = LSTMStateTuple(new_c, new_h)
+    else:
+      new_state = array_ops.concat([new_c, new_h], 1)
+    return new_h, new_state
 
 
 class LSTMCell(RNNCell):
@@ -319,6 +311,7 @@ class LSTMCell(RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(LSTMCell, self).__init__(_reuse=reuse)
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
                    "deprecated.  Use state_is_tuple=True.", self)
@@ -341,7 +334,6 @@ class LSTMCell(RNNCell):
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
     self._activation = activation
-    self._reuse = reuse
 
     if num_proj:
       self._state_size = (
@@ -362,7 +354,7 @@ class LSTMCell(RNNCell):
   def output_size(self):
     return self._output_size
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of LSTM.
 
     Args:
@@ -371,7 +363,6 @@ class LSTMCell(RNNCell):
         `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
         tuple of state Tensors, both `2-D`, with column sizes `c_state` and
         `m_state`.
-      scope: VariableScope for the created subgraph; defaults to "lstm_cell".
 
     Returns:
       A tuple containing:
@@ -400,9 +391,8 @@ class LSTMCell(RNNCell):
     input_size = inputs.get_shape().with_rank(2)[1]
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
-    with _checked_scope(self, scope or "lstm_cell",
-                        initializer=self._initializer,
-                        reuse=self._reuse) as unit_scope:
+    scope = vs.get_variable_scope()
+    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
       if self._num_unit_shards is not None:
         unit_scope.set_partitioner(
             partitioned_variables.fixed_size_partitioner(
@@ -481,13 +471,13 @@ class OutputProjectionWrapper(RNNCell):
       TypeError: if cell is not an RNNCell.
       ValueError: if output_size is not positive.
     """
+    super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
     if not isinstance(cell, RNNCell):
       raise TypeError("The parameter cell is not RNNCell.")
     if output_size < 1:
       raise ValueError("Parameter output_size must be > 0: %d." % output_size)
     self._cell = cell
     self._output_size = output_size
-    self._reuse = reuse
     self._activation = activation
 
   @property
@@ -502,15 +492,12 @@ class OutputProjectionWrapper(RNNCell):
     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
       return self._cell.zero_state(batch_size, dtype)
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run the cell and output projection on inputs, starting from state."""
     output, res_state = self._cell(inputs, state)
-    # Default scope: "OutputProjectionWrapper"
-    with _checked_scope(self, scope or "output_projection_wrapper",
-                        reuse=self._reuse):
-      projected = _linear(output, self._output_size, True)
-      if self._activation:
-        projected = self._activation(projected)
+    projected = _linear(output, self._output_size, True)
+    if self._activation:
+      projected = self._activation(projected)
     return projected, res_state
 
 
@@ -522,7 +509,8 @@ class InputProjectionWrapper(RNNCell):
   do the projection on this batch-concatenated sequence, then split it.
   """
 
-  def __init__(self, cell, num_proj, activation=None, input_size=None):
+  def __init__(self, cell, num_proj, activation=None, input_size=None,
+               reuse=None):
     """Create a cell with input projection.
 
     Args:
@@ -530,10 +518,14 @@ class InputProjectionWrapper(RNNCell):
       num_proj: Python integer.  The dimension to project to.
       activation: (optional) an optional activation function.
       input_size: Deprecated and unused.
+      reuse: (optional) Python boolean describing whether to reuse variables
+        in an existing scope.  If not `True`, and the existing scope already has
+        the given variables, an error is raised.
 
     Raises:
       TypeError: if cell is not an RNNCell.
     """
+    super(InputProjectionWrapper, self).__init__(_reuse=reuse)
     if input_size is not None:
       logging.warn("%s: The input_size parameter is deprecated.", self)
     if not isinstance(cell, RNNCell):
@@ -554,13 +546,12 @@ class InputProjectionWrapper(RNNCell):
     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
       return self._cell.zero_state(batch_size, dtype)
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run the input projection and then the cell."""
     # Default scope: "InputProjectionWrapper"
-    with vs.variable_scope(scope or "input_projection_wrapper"):
-      projected = _linear(inputs, self._num_proj, True)
-      if self._activation:
-        projected = self._activation(projected)
+    projected = _linear(inputs, self._num_proj, True)
+    if self._activation:
+      projected = self._activation(projected)
     return self._cell(projected, state)
 
 
@@ -847,6 +838,7 @@ class EmbeddingWrapper(RNNCell):
       TypeError: if cell is not an RNNCell.
       ValueError: if embedding_classes is not positive.
     """
+    super(EmbeddingWrapper, self).__init__(_reuse=reuse)
     if not isinstance(cell, RNNCell):
       raise TypeError("The parameter cell is not RNNCell.")
     if embedding_classes <= 0 or embedding_size <= 0:
@@ -856,7 +848,6 @@ class EmbeddingWrapper(RNNCell):
     self._embedding_classes = embedding_classes
     self._embedding_size = embedding_size
     self._initializer = initializer
-    self._reuse = reuse
 
   @property
   def state_size(self):
@@ -870,31 +861,31 @@ class EmbeddingWrapper(RNNCell):
     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
       return self._cell.zero_state(batch_size, dtype)
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run the cell on embedded inputs."""
-    with _checked_scope(self, scope or "embedding_wrapper", reuse=self._reuse):
-      with ops.device("/cpu:0"):
-        if self._initializer:
-          initializer = self._initializer
-        elif vs.get_variable_scope().initializer:
-          initializer = vs.get_variable_scope().initializer
-        else:
-          # Default initializer for embeddings should have variance=1.
-          sqrt3 = math.sqrt(3)  # Uniform(-sqrt(3), sqrt(3)) has variance=1.
-          initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
+    with ops.device("/cpu:0"):
+      if self._initializer:
+        initializer = self._initializer
+      elif vs.get_variable_scope().initializer:
+        initializer = vs.get_variable_scope().initializer
+      else:
+        # Default initializer for embeddings should have variance=1.
+        sqrt3 = math.sqrt(3)  # Uniform(-sqrt(3), sqrt(3)) has variance=1.
+        initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
 
-        if type(state) is tuple:
-          data_type = state[0].dtype
-        else:
-          data_type = state.dtype
+      if type(state) is tuple:
+        data_type = state[0].dtype
+      else:
+        data_type = state.dtype
 
-        embedding = vs.get_variable(
-            "embedding", [self._embedding_classes, self._embedding_size],
-            initializer=initializer,
-            dtype=data_type)
-        embedded = embedding_ops.embedding_lookup(
-            embedding, array_ops.reshape(inputs, [-1]))
-    return self._cell(embedded, state)
+      embedding = vs.get_variable(
+          "embedding", [self._embedding_classes, self._embedding_size],
+          initializer=initializer,
+          dtype=data_type)
+      embedded = embedding_ops.embedding_lookup(embedding,
+                                                array_ops.reshape(inputs, [-1]))
+
+      return self._cell(embedded, state)
 
 
 class MultiRNNCell(RNNCell):
@@ -914,6 +905,7 @@ class MultiRNNCell(RNNCell):
       ValueError: if cells is empty (not allowed), or at least one of the cells
         returns a state tuple but the flag `state_is_tuple` is `False`.
     """
+    super(MultiRNNCell, self).__init__()
     if not cells:
       raise ValueError("Must specify at least one cell for MultiRNNCell.")
     if not nest.is_sequence(cells):
@@ -948,28 +940,29 @@ class MultiRNNCell(RNNCell):
         # presumably does not contain TensorArrays or anything else fancy
         return super(MultiRNNCell, self).zero_state(batch_size, dtype)
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run this multi-layer cell on inputs, starting from state."""
-    with vs.variable_scope(scope or "multi_rnn_cell"):
-      cur_state_pos = 0
-      cur_inp = inputs
-      new_states = []
-      for i, cell in enumerate(self._cells):
-        with vs.variable_scope("cell_%d" % i):
-          if self._state_is_tuple:
-            if not nest.is_sequence(state):
-              raise ValueError(
-                  "Expected state to be a tuple of length %d, but received: %s"
-                  % (len(self.state_size), state))
-            cur_state = state[i]
-          else:
-            cur_state = array_ops.slice(
-                state, [0, cur_state_pos], [-1, cell.state_size])
-            cur_state_pos += cell.state_size
-          cur_inp, new_state = cell(cur_inp, cur_state)
-          new_states.append(new_state)
+    cur_state_pos = 0
+    cur_inp = inputs
+    new_states = []
+    for i, cell in enumerate(self._cells):
+      with vs.variable_scope("cell_%d" % i):
+        if self._state_is_tuple:
+          if not nest.is_sequence(state):
+            raise ValueError(
+                "Expected state to be a tuple of length %d, but received: %s" %
+                (len(self.state_size), state))
+          cur_state = state[i]
+        else:
+          cur_state = array_ops.slice(state, [0, cur_state_pos],
+                                      [-1, cell.state_size])
+          cur_state_pos += cell.state_size
+        cur_inp, new_state = cell(cur_inp, cur_state)
+        new_states.append(new_state)
+
     new_states = (tuple(new_states) if self._state_is_tuple else
                   array_ops.concat(new_states, 1))
+
     return cur_inp, new_states
 
 
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 4eb2966ef28..83e8c2777f6 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -138,6 +138,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
     if not state_is_tuple:
       logging.warn(
           "%s: Using a concatenated state is slower and will soon be "
@@ -173,7 +174,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
   def output_size(self):
     return self._output_size
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of LSTM.
 
     Args:
@@ -182,7 +183,6 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
         `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
         tuple of state Tensors, both `2-D`, with column sizes `c_state` and
         `m_state`.
-      scope: VariableScope for the created subgraph; defaults to "LSTMCell".
 
     Returns:
       A tuple containing:
@@ -212,51 +212,49 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
     input_size = inputs.get_shape().with_rank(2)[1]
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
-    with _checked_scope(self, scope or "coupled_input_forget_gate_lstm_cell",
-                        initializer=self._initializer, reuse=self._reuse):
-      concat_w = _get_concat_variable(
-          "W", [input_size.value + num_proj, 3 * self._num_units],
-          dtype, self._num_unit_shards)
+    concat_w = _get_concat_variable(
+        "W", [input_size.value + num_proj, 3 * self._num_units],
+        dtype, self._num_unit_shards)
 
-      b = vs.get_variable(
-          "B",
-          shape=[3 * self._num_units],
-          initializer=init_ops.zeros_initializer(),
-          dtype=dtype)
+    b = vs.get_variable(
+        "B",
+        shape=[3 * self._num_units],
+        initializer=init_ops.zeros_initializer(),
+        dtype=dtype)
 
-      # j = new_input, f = forget_gate, o = output_gate
-      cell_inputs = array_ops.concat([inputs, m_prev], 1)
-      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
-      j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
+    # j = new_input, f = forget_gate, o = output_gate
+    cell_inputs = array_ops.concat([inputs, m_prev], 1)
+    lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
+    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
 
-      # Diagonal connections
-      if self._use_peepholes:
-        w_f_diag = vs.get_variable(
-            "W_F_diag", shape=[self._num_units], dtype=dtype)
-        w_o_diag = vs.get_variable(
-            "W_O_diag", shape=[self._num_units], dtype=dtype)
+    # Diagonal connections
+    if self._use_peepholes:
+      w_f_diag = vs.get_variable(
+          "W_F_diag", shape=[self._num_units], dtype=dtype)
+      w_o_diag = vs.get_variable(
+          "W_O_diag", shape=[self._num_units], dtype=dtype)
 
-      if self._use_peepholes:
-        f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
-      else:
-        f_act = sigmoid(f + self._forget_bias)
-      c = (f_act * c_prev + (1 - f_act) * self._activation(j))
+    if self._use_peepholes:
+      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
+    else:
+      f_act = sigmoid(f + self._forget_bias)
+    c = (f_act * c_prev + (1 - f_act) * self._activation(j))
 
-      if self._use_peepholes:
-        m = sigmoid(o + w_o_diag * c) * self._activation(c)
-      else:
-        m = sigmoid(o) * self._activation(c)
+    if self._use_peepholes:
+      m = sigmoid(o + w_o_diag * c) * self._activation(c)
+    else:
+      m = sigmoid(o) * self._activation(c)
 
-      if self._num_proj is not None:
-        concat_w_proj = _get_concat_variable(
-            "W_P", [self._num_units, self._num_proj],
-            dtype, self._num_proj_shards)
+    if self._num_proj is not None:
+      concat_w_proj = _get_concat_variable(
+          "W_P", [self._num_units, self._num_proj],
+          dtype, self._num_proj_shards)
 
-        m = math_ops.matmul(m, concat_w_proj)
-        if self._proj_clip is not None:
-          # pylint: disable=invalid-unary-operand-type
-          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
-          # pylint: enable=invalid-unary-operand-type
+      m = math_ops.matmul(m, concat_w_proj)
+      if self._proj_clip is not None:
+        # pylint: disable=invalid-unary-operand-type
+        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+        # pylint: enable=invalid-unary-operand-type
 
     new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
                  array_ops.concat([c, m], 1))
@@ -301,6 +299,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
     self._num_units = num_units
     self._use_peepholes = use_peepholes
     self._cell_clip = cell_clip
@@ -321,14 +320,12 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
   def state_size(self):
     return self._state_size
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of LSTM.
 
     Args:
       inputs: input Tensor, 2D, batch x num_units.
       state: state Tensor, 2D, batch x state_size.
-      scope: VariableScope for the created subgraph; defaults to
-        "TimeFreqLSTMCell".
 
     Returns:
       A tuple containing:
@@ -347,63 +344,63 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
     freq_inputs = self._make_tf_features(inputs)
     dtype = inputs.dtype
     actual_input_size = freq_inputs[0].get_shape().as_list()[1]
-    with _checked_scope(self, scope or "time_freq_lstm_cell",
-                        initializer=self._initializer, reuse=self._reuse):
-      concat_w = _get_concat_variable(
-          "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
-          dtype, self._num_unit_shards)
-      b = vs.get_variable(
-          "B",
-          shape=[4 * self._num_units],
-          initializer=init_ops.zeros_initializer(),
-          dtype=dtype)
 
-      # Diagonal connections
+    concat_w = _get_concat_variable(
+        "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
+        dtype, self._num_unit_shards)
+
+    b = vs.get_variable(
+        "B",
+        shape=[4 * self._num_units],
+        initializer=init_ops.zeros_initializer(),
+        dtype=dtype)
+
+    # Diagonal connections
+    if self._use_peepholes:
+      w_f_diag = vs.get_variable(
+          "W_F_diag", shape=[self._num_units], dtype=dtype)
+      w_i_diag = vs.get_variable(
+          "W_I_diag", shape=[self._num_units], dtype=dtype)
+      w_o_diag = vs.get_variable(
+          "W_O_diag", shape=[self._num_units], dtype=dtype)
+
+    # initialize the first freq state to be zero
+    m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
+                                   self._num_units], dtype)
+    for fq in range(len(freq_inputs)):
+      c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
+                               [-1, self._num_units])
+      m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
+                               [-1, self._num_units])
+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
+                                     1)
+      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
+      i, j, f, o = array_ops.split(
+          value=lstm_matrix, num_or_size_splits=4, axis=1)
+
       if self._use_peepholes:
-        w_f_diag = vs.get_variable(
-            "W_F_diag", shape=[self._num_units], dtype=dtype)
-        w_i_diag = vs.get_variable(
-            "W_I_diag", shape=[self._num_units], dtype=dtype)
-        w_o_diag = vs.get_variable(
-            "W_O_diag", shape=[self._num_units], dtype=dtype)
+        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+             sigmoid(i + w_i_diag * c_prev) * tanh(j))
+      else:
+        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
 
-      # initialize the first freq state to be zero
-      m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
-                                     self._num_units], dtype)
-      for fq in range(len(freq_inputs)):
-        c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
-                                 [-1, self._num_units])
-        m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
-                                 [-1, self._num_units])
-        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-        cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
-                                       1)
-        lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
-        i, j, f, o = array_ops.split(
-            value=lstm_matrix, num_or_size_splits=4, axis=1)
+      if self._cell_clip is not None:
+        # pylint: disable=invalid-unary-operand-type
+        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+        # pylint: enable=invalid-unary-operand-type
 
-        if self._use_peepholes:
-          c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
-               sigmoid(i + w_i_diag * c_prev) * tanh(j))
-        else:
-          c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
-
-        if self._cell_clip is not None:
-          # pylint: disable=invalid-unary-operand-type
-          c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
-          # pylint: enable=invalid-unary-operand-type
-
-        if self._use_peepholes:
-          m = sigmoid(o + w_o_diag * c) * tanh(c)
-        else:
-          m = sigmoid(o) * tanh(c)
-        m_prev_freq = m
-        if fq == 0:
-          state_out = array_ops.concat([c, m], 1)
-          m_out = m
-        else:
-          state_out = array_ops.concat([state_out, c, m], 1)
-          m_out = array_ops.concat([m_out, m], 1)
+      if self._use_peepholes:
+        m = sigmoid(o + w_o_diag * c) * tanh(c)
+      else:
+        m = sigmoid(o) * tanh(c)
+      m_prev_freq = m
+      if fq == 0:
+        state_out = array_ops.concat([c, m], 1)
+        m_out = m
+      else:
+        state_out = array_ops.concat([state_out, c, m], 1)
+        m_out = array_ops.concat([m_out, m], 1)
     return m_out, state_out
 
   def _make_tf_features(self, input_feat):
@@ -499,6 +496,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
     Raises:
       ValueError: if the num_frequency_blocks list is not specified
     """
+    super(GridLSTMCell, self).__init__(_reuse=reuse)
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
                    "deprecated.  Use state_is_tuple=True.", self)
@@ -550,15 +548,13 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
   def state_tuple_type(self):
     return self._state_tuple_type
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of LSTM.
 
     Args:
       inputs: input Tensor, 2D, [batch, feature_size].
       state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
         flag self._state_is_tuple.
-      scope: (optional) VariableScope for the created subgraph; if None, it
-        defaults to "GridLSTMCell".
 
     Returns:
       A tuple containing:
@@ -573,21 +569,19 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
     """
     batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
     freq_inputs = self._make_tf_features(inputs)
-    with _checked_scope(self, scope or "grid_lstm_cell",
-                        initializer=self._initializer, reuse=self._reuse):
-      m_out_lst = []
-      state_out_lst = []
-      for block in range(len(freq_inputs)):
-        m_out_lst_current, state_out_lst_current = self._compute(
-            freq_inputs[block], block, state, batch_size,
-            state_is_tuple=self._state_is_tuple)
-        m_out_lst.extend(m_out_lst_current)
-        state_out_lst.extend(state_out_lst_current)
-      if self._state_is_tuple:
-        state_out = self._state_tuple_type(*state_out_lst)
-      else:
-        state_out = array_ops.concat(state_out_lst, 1)
-      m_out = array_ops.concat(m_out_lst, 1)
+    m_out_lst = []
+    state_out_lst = []
+    for block in range(len(freq_inputs)):
+      m_out_lst_current, state_out_lst_current = self._compute(
+          freq_inputs[block], block, state, batch_size,
+          state_is_tuple=self._state_is_tuple)
+      m_out_lst.extend(m_out_lst_current)
+      state_out_lst.extend(state_out_lst_current)
+    if self._state_is_tuple:
+      state_out = self._state_tuple_type(*state_out_lst)
+    else:
+      state_out = array_ops.concat(state_out_lst, 1)
+    m_out = array_ops.concat(m_out_lst, 1)
     return m_out, state_out
 
   def _compute(self, freq_inputs, block, state, batch_size,
@@ -974,14 +968,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
         *([num_units, num_units] * self._total_blocks * 2))
     self._output_size = 2 * num_units * self._total_blocks * 2
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of LSTM.
 
     Args:
       inputs: input Tensor, 2D, [batch, num_units].
       state: tuple of Tensors, 2D, [batch, state_size].
-      scope: (optional) VariableScope for the created subgraph; if None, it
-        defaults to "BidirectionalGridLSTMCell".
 
     Returns:
       A tuple containing:
@@ -1002,29 +994,27 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
       bwd_inputs = fwd_inputs
 
     # Forward processing
-    with _checked_scope(self, scope or "bidirectional_grid_lstm_cell",
-                        initializer=self._initializer, reuse=self._reuse):
-      with vs.variable_scope("fwd"):
-        fwd_m_out_lst = []
-        fwd_state_out_lst = []
-        for block in range(len(fwd_inputs)):
-          fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
-              fwd_inputs[block], block, state, batch_size,
-              state_prefix="fwd_state", state_is_tuple=True)
-          fwd_m_out_lst.extend(fwd_m_out_lst_current)
-          fwd_state_out_lst.extend(fwd_state_out_lst_current)
-      # Backward processing
-      bwd_m_out_lst = []
-      bwd_state_out_lst = []
-      with vs.variable_scope("bwd"):
-        for block in range(len(bwd_inputs)):
-          # Reverse the blocks
-          bwd_inputs_reverse = bwd_inputs[block][::-1]
-          bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
-              bwd_inputs_reverse, block, state, batch_size,
-              state_prefix="bwd_state", state_is_tuple=True)
-          bwd_m_out_lst.extend(bwd_m_out_lst_current)
-          bwd_state_out_lst.extend(bwd_state_out_lst_current)
+    with vs.variable_scope("fwd"):
+      fwd_m_out_lst = []
+      fwd_state_out_lst = []
+      for block in range(len(fwd_inputs)):
+        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
+            fwd_inputs[block], block, state, batch_size,
+            state_prefix="fwd_state", state_is_tuple=True)
+        fwd_m_out_lst.extend(fwd_m_out_lst_current)
+        fwd_state_out_lst.extend(fwd_state_out_lst_current)
+    # Backward processing
+    bwd_m_out_lst = []
+    bwd_state_out_lst = []
+    with vs.variable_scope("bwd"):
+      for block in range(len(bwd_inputs)):
+        # Reverse the blocks
+        bwd_inputs_reverse = bwd_inputs[block][::-1]
+        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
+            bwd_inputs_reverse, block, state, batch_size,
+            state_prefix="bwd_state", state_is_tuple=True)
+        bwd_m_out_lst.extend(bwd_m_out_lst_current)
+        bwd_state_out_lst.extend(bwd_state_out_lst_current)
     state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
     # Outputs are always concated as it is never used separately.
     m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
@@ -1069,6 +1059,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
       ValueError: if cell returns a state tuple but the flag
           `state_is_tuple` is `False` or if attn_length is zero or less.
     """
+    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
     if not isinstance(cell, core_rnn_cell.RNNCell):
       raise TypeError("The parameter cell is not RNNCell.")
     if nest.is_sequence(cell.state_size) and not state_is_tuple:
@@ -1107,42 +1098,40 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
   def output_size(self):
     return self._attn_size
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Long short-term memory cell with attention (LSTMA)."""
-    with _checked_scope(self, scope or "attention_cell_wrapper",
-                        reuse=self._reuse):
-      if self._state_is_tuple:
-        state, attns, attn_states = state
-      else:
-        states = state
-        state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
-        attns = array_ops.slice(
-            states, [0, self._cell.state_size], [-1, self._attn_size])
-        attn_states = array_ops.slice(
-            states, [0, self._cell.state_size + self._attn_size],
-            [-1, self._attn_size * self._attn_length])
-      attn_states = array_ops.reshape(attn_states,
-                                      [-1, self._attn_length, self._attn_size])
-      input_size = self._input_size
-      if input_size is None:
-        input_size = inputs.get_shape().as_list()[1]
-      inputs = _linear([inputs, attns], input_size, True)
-      lstm_output, new_state = self._cell(inputs, state)
-      if self._state_is_tuple:
-        new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
-      else:
-        new_state_cat = new_state
-      new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
-      with vs.variable_scope("attn_output_projection"):
-        output = _linear([lstm_output, new_attns], self._attn_size, True)
-      new_attn_states = array_ops.concat(
-          [new_attn_states, array_ops.expand_dims(output, 1)], 1)
-      new_attn_states = array_ops.reshape(
-          new_attn_states, [-1, self._attn_length * self._attn_size])
-      new_state = (new_state, new_attns, new_attn_states)
-      if not self._state_is_tuple:
-        new_state = array_ops.concat(list(new_state), 1)
-      return output, new_state
+    if self._state_is_tuple:
+      state, attns, attn_states = state
+    else:
+      states = state
+      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
+      attns = array_ops.slice(
+          states, [0, self._cell.state_size], [-1, self._attn_size])
+      attn_states = array_ops.slice(
+          states, [0, self._cell.state_size + self._attn_size],
+          [-1, self._attn_size * self._attn_length])
+    attn_states = array_ops.reshape(attn_states,
+                                    [-1, self._attn_length, self._attn_size])
+    input_size = self._input_size
+    if input_size is None:
+      input_size = inputs.get_shape().as_list()[1]
+    inputs = _linear([inputs, attns], input_size, True)
+    lstm_output, new_state = self._cell(inputs, state)
+    if self._state_is_tuple:
+      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
+    else:
+      new_state_cat = new_state
+    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
+    with vs.variable_scope("attn_output_projection"):
+      output = _linear([lstm_output, new_attns], self._attn_size, True)
+    new_attn_states = array_ops.concat(
+        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
+    new_attn_states = array_ops.reshape(
+        new_attn_states, [-1, self._attn_length * self._attn_size])
+    new_state = (new_state, new_attns, new_attn_states)
+    if not self._state_is_tuple:
+      new_state = array_ops.concat(list(new_state), 1)
+    return output, new_state
 
   def _attention(self, query, attn_states):
     conv2d = nn_ops.conv2d
@@ -1213,6 +1202,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)
 
     if input_size is not None:
       logging.warn("%s: The input_size parameter is deprecated.", self)
@@ -1256,34 +1246,31 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
       out = nn_ops.bias_add(out, bias)
     return out
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """LSTM cell with layer normalization and recurrent dropout."""
+    c, h = state
+    args = array_ops.concat([inputs, h], 1)
+    concat = self._linear(args)
 
-    with _checked_scope(self, scope or "layer_norm_basic_lstm_cell",
-                        reuse=self._reuse):
-      c, h = state
-      args = array_ops.concat([inputs, h], 1)
-      concat = self._linear(args)
+    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
+    if self._layer_norm:
+      i = self._norm(i, "input")
+      j = self._norm(j, "transform")
+      f = self._norm(f, "forget")
+      o = self._norm(o, "output")
 
-      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
-      if self._layer_norm:
-        i = self._norm(i, "input")
-        j = self._norm(j, "transform")
-        f = self._norm(f, "forget")
-        o = self._norm(o, "output")
+    g = self._activation(j)
+    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
+      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
 
-      g = self._activation(j)
-      if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
-        g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
+    new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+             + math_ops.sigmoid(i) * g)
+    if self._layer_norm:
+      new_c = self._norm(new_c, "state")
+    new_h = self._activation(new_c) * math_ops.sigmoid(o)
 
-      new_c = (c * math_ops.sigmoid(f + self._forget_bias)
-               + math_ops.sigmoid(i) * g)
-      if self._layer_norm:
-        new_c = self._norm(new_c, "state")
-      new_h = self._activation(new_c) * math_ops.sigmoid(o)
-
-      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
-      return new_h, new_state
+    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
+    return new_h, new_state
 
 
 class NASCell(core_rnn_cell.RNNCell):
@@ -1313,6 +1300,7 @@ class NASCell(core_rnn_cell.RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(NASCell, self).__init__(_reuse=reuse)
     self._num_units = num_units
     self._num_proj = num_proj
     self._use_biases = use_biases
@@ -1333,14 +1321,13 @@ class NASCell(core_rnn_cell.RNNCell):
   def output_size(self):
     return self._output_size
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of NAS Cell.
 
     Args:
       inputs: input Tensor, 2D, batch x num_units.
       state: This must be a tuple of state Tensors, both `2-D`, with column
         sizes `c_state` and `m_state`.
-      scope: VariableScope for the created subgraph; defaults to "nas_rnn".
 
     Returns:
       A tuple containing:
@@ -1368,71 +1355,70 @@ class NASCell(core_rnn_cell.RNNCell):
     input_size = inputs.get_shape().with_rank(2)[1]
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
-    with _checked_scope(self, scope or "nas_rnn", reuse=self._reuse):
-      # Variables for the NAS cell. W_m is all matrices multiplying the
-      # hiddenstate and W_inputs is all matrices multiplying the inputs.
-      concat_w_m = vs.get_variable(
-          "recurrent_weights", [num_proj, 8 * self._num_units],
-          dtype)
-      concat_w_inputs = vs.get_variable(
-          "weights", [input_size.value, 8 * self._num_units],
+    # Variables for the NAS cell. W_m is all matrices multiplying the
+    # hiddenstate and W_inputs is all matrices multiplying the inputs.
+    concat_w_m = vs.get_variable(
+        "recurrent_weights", [num_proj, 8 * self._num_units],
+        dtype)
+    concat_w_inputs = vs.get_variable(
+        "weights", [input_size.value, 8 * self._num_units],
+        dtype)
+
+    m_matrix = math_ops.matmul(m_prev, concat_w_m)
+    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
+
+    if self._use_biases:
+      b = vs.get_variable(
+          "bias",
+          shape=[8 * self._num_units],
+          initializer=init_ops.zeros_initializer(),
+          dtype=dtype)
+      m_matrix = nn_ops.bias_add(m_matrix, b)
+
+    # The NAS cell branches into 8 different splits for both the hiddenstate
+    # and the input
+    m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
+                                      value=m_matrix)
+    inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
+                                           value=inputs_matrix)
+
+    # First layer
+    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
+    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
+    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
+    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
+    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
+    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
+    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
+    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
+
+    # Second layer
+    l2_0 = tanh(layer1_0 * layer1_1)
+    l2_1 = tanh(layer1_2 + layer1_3)
+    l2_2 = tanh(layer1_4 * layer1_5)
+    l2_3 = sigmoid(layer1_6 + layer1_7)
+
+    # Inject the cell
+    l2_0 = tanh(l2_0 + c_prev)
+
+    # Third layer
+    l3_0_pre = l2_0 * l2_1
+    new_c = l3_0_pre  # create new cell
+    l3_0 = l3_0_pre
+    l3_1 = tanh(l2_2 + l2_3)
+
+    # Final layer
+    new_m = tanh(l3_0 * l3_1)
+
+    # Projection layer if specified
+    if self._num_proj is not None:
+      concat_w_proj = vs.get_variable(
+          "projection_weights", [self._num_units, self._num_proj],
           dtype)
+      new_m = math_ops.matmul(new_m, concat_w_proj)
 
-      m_matrix = math_ops.matmul(m_prev, concat_w_m)
-      inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
-
-      if self._use_biases:
-        b = vs.get_variable(
-            "bias",
-            shape=[8 * self._num_units],
-            initializer=init_ops.zeros_initializer(),
-            dtype=dtype)
-        m_matrix = nn_ops.bias_add(m_matrix, b)
-
-      # The NAS cell branches into 8 different splits for both the hiddenstate
-      # and the input
-      m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
-                                        value=m_matrix)
-      inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
-                                             value=inputs_matrix)
-
-      # First layer
-      layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
-      layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
-      layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
-      layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
-      layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
-      layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
-      layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
-      layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
-
-      # Second layer
-      l2_0 = tanh(layer1_0 * layer1_1)
-      l2_1 = tanh(layer1_2 + layer1_3)
-      l2_2 = tanh(layer1_4 * layer1_5)
-      l2_3 = sigmoid(layer1_6 + layer1_7)
-
-      # Inject the cell
-      l2_0 = tanh(l2_0 + c_prev)
-
-      # Third layer
-      l3_0_pre = l2_0 * l2_1
-      new_c = l3_0_pre  # create new cell
-      l3_0 = l3_0_pre
-      l3_1 = tanh(l2_2 + l2_3)
-
-      # Final layer
-      new_m = tanh(l3_0 * l3_1)
-
-      # Projection layer if specified
-      if self._num_proj is not None:
-        concat_w_proj = vs.get_variable(
-            "projection_weights", [self._num_units, self._num_proj],
-            dtype)
-        new_m = math_ops.matmul(new_m, concat_w_proj)
-
-      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
-      return new_m, new_state
+    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
+    return new_m, new_state
 
 
 class UGRNNCell(core_rnn_cell.RNNCell):
@@ -1467,6 +1453,7 @@ class UGRNNCell(core_rnn_cell.RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(UGRNNCell, self).__init__(_reuse=reuse)
     self._num_units = num_units
     self._initializer = initializer
     self._forget_bias = forget_bias
@@ -1481,13 +1468,12 @@ class UGRNNCell(core_rnn_cell.RNNCell):
   def output_size(self):
     return self._num_units
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of UGRNN.
 
     Args:
       inputs: input Tensor, 2D, batch x input size.
       state: state Tensor, 2D, batch x num units.
-      scope: VariableScope for the created subgraph; defaults to "ugrnn_cell".
 
     Returns:
       new_output: batch x num units, Tensor representing the output of the UGRNN
@@ -1506,8 +1492,8 @@ class UGRNNCell(core_rnn_cell.RNNCell):
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
 
-    with _checked_scope(self, scope or "ugrnn_cell",
-                        initializer=self._initializer, reuse=self._reuse):
+    with vs.variable_scope(vs.get_variable_scope(),
+                           initializer=self._initializer):
       cell_inputs = array_ops.concat([inputs, state], 1)
       rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True)
 
@@ -1567,6 +1553,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
     self._num_units = num_units
     self._initializer = initializer
     self._forget_bias = forget_bias
@@ -1582,14 +1569,12 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
   def output_size(self):
     return self._num_units
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Run one step of the Intersection RNN.
 
     Args:
       inputs: input Tensor, 2D, batch x input size.
       state: state Tensor, 2D, batch x num units.
-      scope: VariableScope for the created subgraph; defaults to
-        "intersection_rnn_cell"
 
     Returns:
       new_y: batch x num units, Tensor representing the output of the +RNN
@@ -1610,8 +1595,8 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
     if input_size.value is None:
       raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
 
-    with _checked_scope(self, scope or "intersection_rnn_cell",
-                        initializer=self._initializer, reuse=self._reuse):
+    with vs.variable_scope(vs.get_variable_scope(),
+                           initializer=self._initializer):
       # read-in projections (should be used for first layer in deep +RNN
       # to transform size of inputs from I --> N)
       if input_size.value != self._num_units:
@@ -1683,7 +1668,7 @@ class CompiledWrapper(core_rnn_cell.RNNCell):
         return not _REGISTERED_OPS[node_def.op].is_stateful
 
     with jit.experimental_jit_scope(compile_ops=compile_ops):
-      return self._cell(inputs, state, scope=scope)
+      return self._cell(inputs, state, scope)
 
 
 def _random_exp_initializer(minval,
@@ -1753,6 +1738,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
         in an existing scope. If not `True`, and the existing scope already has
         the given variables, an error is raised.
     """
+    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
     self._num_units = num_units
     self._use_peepholes = use_peepholes
     self._leak = leak
@@ -1782,7 +1768,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
     cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
     return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
 
-  def __call__(self, inputs, state, scope=None):
+  def call(self, inputs, state):
     """Phased LSTM Cell.
 
     Args:
@@ -1792,7 +1778,6 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
          The second Tensor has shape [batch, features_size], and type float32.
          It stores the features.
       state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
-      scope: string, id of the variable scope.
 
     Returns:
       A tuple containing:
@@ -1801,61 +1786,60 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
       - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
         [batch_size, num_units], representing the new state and the output.
     """
-    with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse):
-      (c_prev, h_prev) = state
-      (time, x) = inputs
+    (c_prev, h_prev) = state
+    (time, x) = inputs
 
-      in_mask_gates = [x, h_prev]
-      if self._use_peepholes:
-        in_mask_gates.append(c_prev)
+    in_mask_gates = [x, h_prev]
+    if self._use_peepholes:
+      in_mask_gates.append(c_prev)
 
-      with vs.variable_scope("mask_gates"):
-        mask_gates = math_ops.sigmoid(
-            _linear(in_mask_gates, 2 * self._num_units, True))
-        [input_gate, forget_gate] = array_ops.split(
-            axis=1, num_or_size_splits=2, value=mask_gates)
+    with vs.variable_scope("mask_gates"):
+      mask_gates = math_ops.sigmoid(
+          _linear(in_mask_gates, 2 * self._num_units, True))
+      [input_gate, forget_gate] = array_ops.split(
+          axis=1, num_or_size_splits=2, value=mask_gates)
 
-      with vs.variable_scope("new_input"):
-        new_input = math_ops.tanh(
-            _linear([x, h_prev], self._num_units, True))
+    with vs.variable_scope("new_input"):
+      new_input = math_ops.tanh(
+          _linear([x, h_prev], self._num_units, True))
 
-      new_c = (c_prev * forget_gate + input_gate * new_input)
+    new_c = (c_prev * forget_gate + input_gate * new_input)
 
-      in_out_gate = [x, h_prev]
-      if self._use_peepholes:
-        in_out_gate.append(new_c)
+    in_out_gate = [x, h_prev]
+    if self._use_peepholes:
+      in_out_gate.append(new_c)
 
-      with vs.variable_scope("output_gate"):
-        output_gate = math_ops.sigmoid(
-            _linear(in_out_gate, self._num_units, True))
+    with vs.variable_scope("output_gate"):
+      output_gate = math_ops.sigmoid(
+          _linear(in_out_gate, self._num_units, True))
 
-      new_h = math_ops.tanh(new_c) * output_gate
+    new_h = math_ops.tanh(new_c) * output_gate
 
-      period = vs.get_variable(
-          "period", [self._num_units],
-          initializer=_random_exp_initializer(
-              self._period_init_min, self._period_init_max))
-      phase = vs.get_variable(
-          "phase", [self._num_units],
-          initializer=init_ops.random_uniform_initializer(
-              0., period.initial_value))
-      ratio_on = vs.get_variable(
-          "ratio_on", [self._num_units],
-          initializer=init_ops.constant_initializer(self._ratio_on),
-          trainable=self._trainable_ratio_on)
+    period = vs.get_variable(
+        "period", [self._num_units],
+        initializer=_random_exp_initializer(
+            self._period_init_min, self._period_init_max))
+    phase = vs.get_variable(
+        "phase", [self._num_units],
+        initializer=init_ops.random_uniform_initializer(
+            0., period.initial_value))
+    ratio_on = vs.get_variable(
+        "ratio_on", [self._num_units],
+        initializer=init_ops.constant_initializer(self._ratio_on),
+        trainable=self._trainable_ratio_on)
 
-      cycle_ratio = self._get_cycle_ratio(time, phase, period)
+    cycle_ratio = self._get_cycle_ratio(time, phase, period)
 
-      k_up = 2 * cycle_ratio / ratio_on
-      k_down = 2 - k_up
-      k_closed = self._leak * cycle_ratio
+    k_up = 2 * cycle_ratio / ratio_on
+    k_down = 2 - k_up
+    k_closed = self._leak * cycle_ratio
 
-      k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
-      k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
+    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
+    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
 
-      new_c = k * new_c + (1 - k) * c_prev
-      new_h = k * new_h + (1 - k) * h_prev
+    new_c = k * new_c + (1 - k) * c_prev
+    new_h = k * new_h + (1 - k) * h_prev
 
-      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
+    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
 
-      return new_h, new_state
+    return new_h, new_state
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 333fa6983c6..023164d8262 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -454,6 +454,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
         up to the next cell in an RNN stack or to the top RNN output.
       name: Name to use when creating ops.
     """
+    super(AttentionWrapper, self).__init__()
     if not isinstance(cell, core_rnn_cell.RNNCell):
       raise TypeError(
           "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
@@ -515,7 +516,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
                                         dtype),
           alignment_history=alignment_history)
 
-  def __call__(self, inputs, state, tiling_factor=1, scope=None):
+  def __call__(self, inputs, state, tiling_factor=1):
     """Perform a step of attention-wrapped RNN.
 
     - Step 1: Mix the `inputs` and previous step's `attention` output via
@@ -536,7 +537,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
         tensors from the previous time step.
       tiling_factor: An integer factor for which to tile the batch dimension.
         Used with BeamSearchDecoder.
-      scope: Must be `None`.
 
     Returns:
       A tuple `(attention_or_cell_output, next_state)`, where:
@@ -548,50 +548,46 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
     Raises:
       NotImplementedError: if `scope` is not `None`.
     """
-    if scope is not None:
-      raise NotImplementedError("scope not None is not supported")
+    # Step 1: Calculate the true inputs to the cell based on the
+    # previous attention value.
+    cell_inputs = self._cell_input_fn(inputs, state.attention)
+    cell_state = state.cell_state
+    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
 
-    with variable_scope.variable_scope("attention"):
-      # Step 1: Calculate the true inputs to the cell based on the
-      # previous attention value.
-      cell_inputs = self._cell_input_fn(inputs, state.attention)
-      cell_state = state.cell_state
-      cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
+    score = self._attention_mechanism(cell_output, tiling_factor)
+    alignments = self._probability_fn(score)
 
-      score = self._attention_mechanism(cell_output, tiling_factor)
-      alignments = self._probability_fn(score)
+    # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
+    expanded_alignments = array_ops.expand_dims(alignments, 1)
+    # Context is the inner product of alignments and values along the
+    # memory time dimension.
+    # alignments shape is
+    #   [batch_size, 1, memory_time]
+    # attention_mechanism.values shape is
+    #   [batch_size, memory_time, attention_mechanism.num_units]
+    # the batched matmul is over memory_time, so the output shape is
+    #   [batch_size, 1, attention_mechanism.num_units].
+    # we then squeeze out the singleton dim.
+    attention_mechanism_values = _maybe_tile_batch(
+        self._attention_mechanism.values, tiling_factor)
 
-      # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
-      expanded_alignments = array_ops.expand_dims(alignments, 1)
-      # Context is the inner product of alignments and values along the
-      # memory time dimension.
-      # alignments shape is
-      #   [batch_size, 1, memory_time]
-      # attention_mechanism.values shape is
-      #   [batch_size, memory_time, attention_mechanism.num_units]
-      # the batched matmul is over memory_time, so the output shape is
-      #   [batch_size, 1, attention_mechanism.num_units].
-      # we then squeeze out the singleton dim.
-      attention_mechanism_values = _maybe_tile_batch(
-          self._attention_mechanism.values, tiling_factor)
+    context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
+    context = array_ops.squeeze(context, [1])
 
-      context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
-      context = array_ops.squeeze(context, [1])
+    attention = self._attention_layer(
+        array_ops.concat([cell_output, context], 1))
 
-      attention = self._attention_layer(
-          array_ops.concat([cell_output, context], 1))
+    if self._alignment_history:
+      alignment_history = state.alignment_history.write(
+          state.time, alignments)
+    else:
+      alignment_history = ()
 
-      if self._alignment_history:
-        alignment_history = state.alignment_history.write(
-            state.time, alignments)
-      else:
-        alignment_history = ()
-
-      next_state = AttentionWrapperState(
-          time=state.time + 1,
-          cell_state=next_cell_state,
-          attention=attention,
-          alignment_history=alignment_history)
+    next_state = AttentionWrapperState(
+        time=state.time + 1,
+        cell_state=next_cell_state,
+        attention=attention,
+        alignment_history=alignment_history)
 
     if self._output_attention:
       return attention, next_state
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5e2a7abeac0..b522c5044c8 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1669,6 +1669,7 @@ py_library(
     deps = [
         ":array_ops",
         ":framework_for_generated_wrappers",
+        ":layers_base",
         ":util",
     ],
 )
@@ -3240,16 +3241,34 @@ py_tests(
 )
 
 py_library(
-    name = "layers",
+    name = "layers_base",
     srcs = [
         "layers/__init__.py",
         "layers/base.py",
+        "layers/utils.py",
+    ],
+    deps = [
+        ":array_ops",
+        ":control_flow_ops",
+        ":framework",
+        ":framework_for_generated_wrappers",
+        ":init_ops",
+        ":util",
+        ":variable_scope",
+        ":variables",
+        "//third_party/py/numpy",
+        "@six_archive//:six",
+    ],
+)
+
+py_library(
+    name = "layers",
+    srcs = [
         "layers/convolutional.py",
         "layers/core.py",
         "layers/layers.py",
         "layers/normalization.py",
         "layers/pooling.py",
-        "layers/utils.py",
     ],
     srcs_version = "PY2AND3",
     deps = [
@@ -3258,6 +3277,7 @@ py_library(
         ":framework",
         ":framework_for_generated_wrappers",
         ":init_ops",
+        ":layers_base",
         ":math_ops",
         ":nn",
         ":standard_ops",
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 21b0ba76266..ff9a777f191 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -35,6 +35,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.util import nest
+from tensorflow.python.util import tf_inspect
 
 
 class _Layer(object):
@@ -279,8 +280,7 @@ class _Layer(object):
       inputs: input tensor(s).
       *args: additional positional arguments to be passed to `self.call`.
       **kwargs: additional keyword arguments to be passed to `self.call`.
-        **Note**, the kwarg 'scope' is reserved for use by the Layer.
-
+        **Note**: kwarg `scope` is reserved for use by the layer.
     Returns:
       Output tensor(s).
     """
@@ -328,6 +328,8 @@ class _Layer(object):
           else:
             self.build(input_shapes)
           self._built = True
+        if 'scope' in tf_inspect.getargspec(self.call).args:
+          kwargs['scope'] = scope
         outputs = self.call(inputs, *args, **kwargs)
 
         # Apply activity regularization.
@@ -365,19 +367,20 @@ class _Layer(object):
         setattr(result, k, copy.deepcopy(v, memo))
     return result
 
-  def apply(self, inputs, **kwargs):
+  def apply(self, inputs, *args, **kwargs):
     """Apply the layer on a input.
 
     This simply wraps `self.__call__`.
 
     Arguments:
       inputs: Input tensor(s).
+      *args: additional positional arguments to be passed to `self.call`.
       **kwargs: additional keyword arguments to be passed to `self.call`.
 
     Returns:
       Output tensor(s).
     """
-    return self.__call__(inputs, **kwargs)
+    return self.__call__(inputs, *args, **kwargs)
 
 
 def _to_snake_case(name):
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c3dddf85f3d..32ebe0c2e84 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -26,6 +26,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import base as base_layer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.util import nest
 
@@ -74,7 +75,7 @@ def _zero_state_tensors(state_size, batch_size, dtype):
   return zeros
 
 
-class _RNNCell(object):
+class _RNNCell(base_layer._Layer):  # pylint: disable=protected-access
   """Abstract object representing an RNN cell.
 
   Every `RNNCell` must have the properties below and implement `__call__` with
@@ -111,7 +112,7 @@ class _RNNCell(object):
       - New state: Either a single `2-D` tensor, or a tuple of tensors matching
         the arity and shapes of `state`.
     """
-    raise NotImplementedError("Abstract method")
+    return super(_RNNCell, self).__call__(inputs, state, scope=scope)
 
   @property
   def state_size(self):

From 7bc6271055714f3d4d0b957a2f4c6a910ea20388 Mon Sep 17 00:00:00 2001
From: Yuefeng Zhou <yuefengz@google.com>
Date: Fri, 21 Apr 2017 17:55:28 -0800
Subject: [PATCH 25/27] Copy function def to the optimized graph in the
 autoparallel optimizer. Change: 153896372

---
 tensorflow/core/grappler/op_types.cc          |  7 +++++
 tensorflow/core/grappler/op_types.h           |  1 +
 tensorflow/core/grappler/optimizers/BUILD     |  1 +
 .../core/grappler/optimizers/auto_parallel.cc | 28 +++++++++++--------
 4 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index e928e21264e..266d74976fe 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -18,6 +18,13 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+bool IsDequeueOp(const NodeDef& node) {
+  static const std::set<std::string> dequeue_ops = {
+      "QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
+      "QueueDequeue"};
+  return dequeue_ops.count(node.op()) > 0;
+}
+
 bool IsPlaceholder(const NodeDef& node) {
   const auto op = node.op();
   return op == "Placeholder" || op == "PlaceholderV2";
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 2f83325c9da..2f58835628d 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -21,6 +21,7 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+bool IsDequeueOp(const NodeDef& node);
 bool IsPlaceholder(const NodeDef& node);
 bool IsVariable(const NodeDef& node);
 
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 64d5815bf78..d7a7989dfad 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -40,6 +40,7 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler:devices",
         "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/clusters:cluster",
     ],
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc
index 77ab178653b..b5497d35947 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel.cc
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc
@@ -14,11 +14,14 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
+
 #include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/grappler/clusters/cluster.h"
 #include "tensorflow/core/grappler/devices.h"
 #include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/utils.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 
@@ -94,22 +97,22 @@ Status AutoParallel::Initialize(const GrapplerItem& item) {
     VLOG(2) << "Variable: " << var->name();
   }
 
-  std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
-                                          "ApplyProximalGradientDescent",
-                                          "ApplyAdadelta",
-                                          "ApplyAdagrad",
-                                          "ApplyProximalAdagrad",
-                                          "ApplyAdagradDA",
-                                          "ApplyFtrl",
-                                          "ApplyMomentum",
-                                          "ApplyAdam",
-                                          "ApplyRMSProp",
-                                          "ApplyCenteredRMSProp"};
+  const std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
+                                                "ApplyProximalGradientDescent",
+                                                "ApplyAdadelta",
+                                                "ApplyAdagrad",
+                                                "ApplyProximalAdagrad",
+                                                "ApplyAdagradDA",
+                                                "ApplyFtrl",
+                                                "ApplyMomentum",
+                                                "ApplyAdam",
+                                                "ApplyRMSProp",
+                                                "ApplyCenteredRMSProp"};
   const NodeDef* dequeue_node = nullptr;
   for (int i = 0; i < graph_.node_size(); i++) {
     all_nodes_.insert(
         std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
-    if (graph_.node(i).op() == "QueueDequeueManyV2") {
+    if (IsDequeueOp(graph_.node(i))) {
       dequeue_node = graph_.mutable_node(i);
     }
     if (apply_gradients_ops.find(graph_.node(i).op()) !=
@@ -241,6 +244,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
   for (const auto& fetch : item_->fetch) {
     AddNodeControl(fetch, {control->name()}, graph);
   }
+  *(graph->mutable_library()) = item_->graph.library();
   LOG(INFO) << "Parallelized graph size: " << graph->node_size();
 }
 

From 3c0900a49c11b7975c7accc026153bbc2001c018 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Fri, 21 Apr 2017 20:16:56 -0800
Subject: [PATCH 26/27] Fix python3 build caused by recent RNNCell refactor.
 Change: 153902768

---
 tensorflow/python/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b522c5044c8..cad8ccaaad6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3247,6 +3247,7 @@ py_library(
         "layers/base.py",
         "layers/utils.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
         ":control_flow_ops",

From 326942394e69074d50d5889218a24c9371eff259 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Sat, 22 Apr 2017 06:08:17 -0800
Subject: [PATCH 27/27] Merge changes from github. Change: 153925676

---
 README.md                                     |   15 +-
 RELEASE.md                                    |    3 +
 configure                                     |   22 +-
 tensorflow/BUILD                              |   24 +-
 tensorflow/compiler/aot/tfcompile.bzl         |    1 +
 .../contrib/ffmpeg/default/ffmpeg_lib.cc      |    2 +-
 .../learn/python/learn/datasets/mnist.py      |   31 +-
 .../python/learn/estimators/estimator_test.py |   14 -
 .../makefile/compile_linux_protobuf.sh        |    2 +-
 .../rnn/python/kernel_tests/lstm_ops_test.py  |    4 +-
 .../seq2seq/python/ops/attention_wrapper.py   |    2 +-
 tensorflow/contrib/slim/README.md             |    2 +-
 tensorflow/contrib/verbs/BUILD                |  168 +++
 tensorflow/contrib/verbs/README.md            |   77 +
 tensorflow/contrib/verbs/design_diagram.png   |  Bin 0 -> 13625 bytes
 tensorflow/contrib/verbs/grpc_verbs_client.cc |   47 +
 tensorflow/contrib/verbs/grpc_verbs_client.h  |   50 +
 .../contrib/verbs/grpc_verbs_service.cc       |  165 +++
 tensorflow/contrib/verbs/grpc_verbs_service.h |   72 +
 .../contrib/verbs/grpc_verbs_service_impl.cc  |   68 +
 .../contrib/verbs/grpc_verbs_service_impl.h   |   89 ++
 tensorflow/contrib/verbs/rdma.cc              |  874 +++++++++++
 tensorflow/contrib/verbs/rdma.h               |  277 ++++
 tensorflow/contrib/verbs/rdma_mgr.cc          |  133 ++
 tensorflow/contrib/verbs/rdma_mgr.h           |   54 +
 .../contrib/verbs/rdma_rendezvous_mgr.cc      |  149 ++
 .../contrib/verbs/rdma_rendezvous_mgr.h       |   64 +
 tensorflow/contrib/verbs/verbs_server_lib.cc  |  172 +++
 tensorflow/contrib/verbs/verbs_server_lib.h   |   66 +
 tensorflow/contrib/verbs/verbs_service.proto  |   60 +
 tensorflow/contrib/verbs/verbs_util.cc        |   61 +
 tensorflow/contrib/verbs/verbs_util.h         |   41 +
 tensorflow/core/BUILD                         |   14 +-
 tensorflow/core/debug/debug_io_utils.cc       |    2 +
 .../rpc/grpc_server_lib.cc                    |   26 +-
 .../distributed_runtime/rpc/grpc_server_lib.h |   18 +
 tensorflow/core/framework/function_testlib.cc |   44 +-
 tensorflow/core/graph/mkl_layout_pass.cc      | 1273 ++++++++++++-----
 tensorflow/core/graph/mkl_layout_pass_test.cc |  563 +++++++-
 tensorflow/core/graph/mkl_optimizer_merge.cc  |  651 ---------
 tensorflow/core/graph/mkl_optimizer_merge.h   |   36 -
 .../core/graph/mkl_optimizer_merge_test.cc    |  470 ------
 .../core/graph/mkl_tfconversion_pass.cc       |   55 +-
 .../core/graph/mkl_tfconversion_pass_test.cc  |  153 +-
 .../core/grappler/optimizers/auto_parallel.cc |    2 +-
 tensorflow/core/kernels/BUILD                 |   32 +
 .../kernels/fixed_length_record_reader_op.cc  |   40 +-
 tensorflow/core/kernels/mkl_avgpooling_op.cc  |   42 +-
 tensorflow/core/kernels/mkl_concat_op.cc      |  458 ++++++
 .../core/kernels/mkl_conv_grad_bias_ops.cc    |   12 +-
 .../core/kernels/mkl_conv_grad_filter_ops.cc  |   12 +-
 .../core/kernels/mkl_conv_grad_input_ops.cc   |   12 +-
 tensorflow/core/kernels/mkl_conv_ops.cc       |   24 +-
 .../core/kernels/mkl_fused_batch_norm_op.cc   |  689 +++++++++
 tensorflow/core/kernels/mkl_lrn_op.cc         |  722 ++++++++++
 tensorflow/core/kernels/mkl_maxpooling_op.cc  |   74 +-
 tensorflow/core/kernels/mkl_relu_op.cc        |   32 +-
 tensorflow/core/kernels/mkl_reshape_op.cc     |  149 ++
 tensorflow/core/kernels/mkl_tfconv_op.cc      |   10 +-
 tensorflow/core/kernels/quantized_conv_ops.cc |    4 +-
 .../kernels/sparse_tensor_dense_matmul_op.cc  |  139 +-
 .../kernels/sparse_tensor_dense_matmul_op.h   |    5 +-
 .../sparse_tensor_dense_matmul_op_gpu.cu.cc   |   35 +-
 tensorflow/core/ops/array_ops.cc              |   58 +
 tensorflow/core/ops/io_ops.cc                 |   12 +
 tensorflow/core/ops/nn_ops.cc                 |  257 +++-
 tensorflow/core/ops/ops.pbtxt                 |   53 +
 tensorflow/core/ops/sparse_ops.cc             |    3 +-
 .../core/platform/default/build_config.bzl    |   12 +-
 .../platform/default/build_config_root.bzl    |    8 +
 tensorflow/core/public/version.h              |    4 +-
 tensorflow/core/util/mkl_util.h               |  271 +++-
 tensorflow/docs_src/community/style_guide.md  |   44 +-
 tensorflow/docs_src/extend/adding_an_op.md    |    6 +-
 .../docs_src/get_started/get_started.md       |    2 +-
 tensorflow/docs_src/get_started/monitors.md   |    9 +-
 tensorflow/docs_src/get_started/tflearn.md    |    2 +-
 tensorflow/docs_src/install/install_c.md      |    2 +-
 tensorflow/docs_src/install/install_go.md     |    2 +-
 tensorflow/docs_src/install/install_java.md   |   16 +-
 tensorflow/docs_src/install/install_linux.md  |   28 +-
 tensorflow/docs_src/install/install_mac.md    |   14 +-
 .../docs_src/install/install_sources.md       |    5 +-
 .../docs_src/install/install_windows.md       |    4 +-
 .../fully_connected_preloaded_var.py          |    3 +-
 .../tutorials/mnist/mnist_with_summaries.py   |   17 +-
 tensorflow/python/BUILD                       |    5 +-
 tensorflow/python/framework/dtypes_test.py    |    2 +-
 .../python/kernel_tests/reader_ops_test.py    |   54 +
 .../sparse_tensor_dense_matmul_grad_test.py   |   42 +-
 .../sparse_tensor_dense_matmul_op_test.py     |   19 +-
 .../python/layers/convolutional_test.py       |    2 +-
 .../python/layers/normalization_test.py       |    2 +-
 tensorflow/python/layers/utils_test.py        |    2 +-
 tensorflow/python/ops/batch_norm_benchmark.py |    2 +-
 tensorflow/python/ops/io_ops.py               |   14 +-
 tensorflow/python/ops/nn_impl.py              |   28 +-
 tensorflow/python/ops/sparse_grad.py          |   14 +-
 tensorflow/python/ops/sparse_ops.py           |    6 +-
 .../stream_executor_internal.h                |    2 +-
 tensorflow/tensorboard/DEVELOPMENT.md         |    2 +-
 .../tensorboard/dist/tf-tensorboard.html      |   20 +-
 tensorflow/tensorboard/gulp_tasks/bower.js    |    4 +-
 tensorflow/tensorboard/gulp_tasks/compile.js  |   34 +-
 tensorflow/tensorboard/gulp_tasks/test.js     |    4 +-
 tensorflow/tensorboard/gulp_tasks/util.js     |    6 +-
 .../tensorboard/gulp_tasks/vulcanize.js       |   27 +-
 tensorflow/tensorboard/gulpfile.js            |   20 +-
 tensorflow/tensorboard/package.json           |    2 +-
 ...nsorflow.-fixed-length-record-reader.pbtxt |    2 +-
 tensorflow/tools/docker/Dockerfile.devel      |    4 +-
 tensorflow/tools/docker/Dockerfile.devel-gpu  |    4 +-
 tensorflow/tools/pip_package/setup.py         |    2 +-
 third_party/jemalloc.BUILD                    |   33 +-
 third_party/llvm/llvm.BUILD                   |    5 +
 115 files changed, 7572 insertions(+), 2230 deletions(-)
 create mode 100644 tensorflow/contrib/verbs/BUILD
 create mode 100644 tensorflow/contrib/verbs/README.md
 create mode 100644 tensorflow/contrib/verbs/design_diagram.png
 create mode 100644 tensorflow/contrib/verbs/grpc_verbs_client.cc
 create mode 100644 tensorflow/contrib/verbs/grpc_verbs_client.h
 create mode 100644 tensorflow/contrib/verbs/grpc_verbs_service.cc
 create mode 100644 tensorflow/contrib/verbs/grpc_verbs_service.h
 create mode 100644 tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
 create mode 100644 tensorflow/contrib/verbs/grpc_verbs_service_impl.h
 create mode 100644 tensorflow/contrib/verbs/rdma.cc
 create mode 100644 tensorflow/contrib/verbs/rdma.h
 create mode 100644 tensorflow/contrib/verbs/rdma_mgr.cc
 create mode 100644 tensorflow/contrib/verbs/rdma_mgr.h
 create mode 100644 tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
 create mode 100644 tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
 create mode 100644 tensorflow/contrib/verbs/verbs_server_lib.cc
 create mode 100644 tensorflow/contrib/verbs/verbs_server_lib.h
 create mode 100644 tensorflow/contrib/verbs/verbs_service.proto
 create mode 100644 tensorflow/contrib/verbs/verbs_util.cc
 create mode 100644 tensorflow/contrib/verbs/verbs_util.h
 delete mode 100644 tensorflow/core/graph/mkl_optimizer_merge.cc
 delete mode 100644 tensorflow/core/graph/mkl_optimizer_merge.h
 delete mode 100644 tensorflow/core/graph/mkl_optimizer_merge_test.cc
 create mode 100644 tensorflow/core/kernels/mkl_concat_op.cc
 create mode 100644 tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
 create mode 100644 tensorflow/core/kernels/mkl_lrn_op.cc
 create mode 100644 tensorflow/core/kernels/mkl_reshape_op.cc

diff --git a/README.md b/README.md
index 3ab47736813..951e7c3b9f6 100644
--- a/README.md
+++ b/README.md
@@ -34,12 +34,13 @@ and discussion.**
 
 People who are a little more adventurous can also try our nightly binaries:
 
-* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
-* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
-* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
-* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
-* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/))
-* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/))
+
+* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
+* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
+* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
+* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
+* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/))
+* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/))
 * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
 ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
 
@@ -62,7 +63,7 @@ $ python
 
 ## For more information
 
-* [TensorFlow website](http://tensorflow.org)
+* [TensorFlow website](https://tensorflow.org)
 * [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
 * [TensorFlow Model Zoo](https://github.com/tensorflow/models)
 * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
diff --git a/RELEASE.md b/RELEASE.md
index 6087390c9c7..fe6d052640a 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -36,6 +36,7 @@
   * New navigation bar in Curses-based UI
   * NodeStepper (command `invoke_stepper`) now uses intermediate tensor dumps. It also uses `TensorHandles` as direct feeds during successive `cont` calls for improved performance and reduced memory consumption.
 * Initial release of installation guides for Java, C, and Go.
+* Added Text Dashboard to TensorBoard.
 
 ## Deprecations
 
@@ -91,6 +92,8 @@
   * Command history now persists across runs.
   * Bug fix in graph validation related to `tf.while_loops`.
 * Java Maven fixes for bugs with Windows installation.
+* Backport fixes and improvements from external keras.
+* Keras config file handling fix.
 
 ## Thanks to our Contributors
 
diff --git a/configure b/configure
index 47bdd5d018e..fad3fdbebd9 100755
--- a/configure
+++ b/configure
@@ -94,10 +94,10 @@ write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
 if false; then # Disable building with MKL for now
   while [ "$TF_NEED_MKL" == "" ]; do
     fromuser=""
-    read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
+    read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
     fromuser="1"
     case $INPUT in
-      [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
+      [Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
       [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
       "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
       * ) echo "Invalid selection: " $INPUT;;
@@ -244,6 +244,24 @@ if [[ "$TF_ENABLE_XLA" == "1" ]]; then
   write_to_bazelrc 'build --define with_xla_support=true'
 fi
 
+# Verbs configuration
+while [ "$TF_NEED_VERBS" == "" ]; do
+  read -p "Do you wish to build TensorFlow with "\
+"VERBS support? [y/N] " INPUT
+  case $INPUT in
+    [Yy]* ) echo "VERBS support will be enabled for "\
+"TensorFlow"; TF_NEED_VERBS=1;;
+    [Nn]* ) echo "No VERBS support will be enabled for "\
+"TensorFlow"; TF_NEED_VERBS=0;;
+    "" ) echo "No VERBS support will be enabled for "\
+"TensorFlow"; TF_NEED_VERBS=0;;
+    * ) echo "Invalid selection: " $INPUT;;
+  esac
+done
+
+if [[ "$TF_NEED_VERBS" == "1" ]]; then
+  write_to_bazelrc 'build --define with_verbs_support=true'
+fi
 
 # Invoke python_config and set up symlinks to python includes
 ./util/python/python_config.sh "$PYTHON_BIN_PATH"
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 0f7f848cb1a..248b18e020e 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -84,6 +84,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "linux_ppc64le",
+    values = {"cpu": "ppc"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "debug",
     values = {
@@ -108,7 +114,7 @@ config_setting(
 
 # TODO(jhseu): Enable on other platforms other than Linux.
 config_setting(
-    name = "with_jemalloc",
+    name = "with_jemalloc_linux_x86_64",
     values = {
         "cpu": "k8",
         "define": "with_jemalloc=true",
@@ -116,6 +122,15 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "with_jemalloc_linux_ppc64le",
+    values = {
+        "cpu": "ppc",
+        "define": "with_jemalloc=true",
+    },
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "with_gcp_support",
     values = {"define": "with_gcp_support=true"},
@@ -134,6 +149,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "with_verbs_support",
+    values = {"define": "with_verbs_support=true"},
+    visibility = ["//visibility:public"],
+)
+
 package_group(
     name = "internal",
     packages = ["//tensorflow/..."],
@@ -249,6 +270,7 @@ filegroup(
         "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
+        "//tensorflow/contrib/verbs:all_files",
         "//tensorflow/contrib/xla_tf_graph:all_files",
         "//tensorflow/core:all_files",
         "//tensorflow/core/debug:all_files",
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 64e5bfd602c..7d61bee8caf 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -282,5 +282,6 @@ def target_llvm_triple():
       "//tensorflow:android_arm": "armv7-none-android",
       "//tensorflow:android_arm64": "aarch64-none-android",
       "//tensorflow:android_x86": "i686-none-android",
+      "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
       "//conditions:default": "x86_64-pc-linux",
   })
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index a758bb92aaa..e520139e659 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -142,7 +142,7 @@ template <typename UInt>
 string LittleEndianData(UInt data) {
   static_assert(std::is_unsigned<UInt>::value, "UInt must be unsigned");
   string str;
-  for (int i = 0; i < sizeof(UInt); ++i) {
+  for (size_t i = 0; i < sizeof(UInt); ++i) {
     const unsigned char bits = static_cast<unsigned char>(data & 0xFFU);
     char ch;
     ::memcpy(&ch, &bits, sizeof(bits));
diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
index fd50070dac5..13f213c197f 100644
--- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py
+++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py
@@ -26,6 +26,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.contrib.learn.python.learn.datasets import base
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import random_seed
 
 # CVDF mirror of http://yann.lecun.com/exdb/mnist/
 SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
@@ -109,12 +110,16 @@ class DataSet(object):
                fake_data=False,
                one_hot=False,
                dtype=dtypes.float32,
-               reshape=True):
+               reshape=True,
+               seed=None):
     """Construct a DataSet.
     one_hot arg is used only if fake_data is true.  `dtype` can be either
     `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
-    `[0, 1]`.
+    `[0, 1]`.  Seed arg provides for convenient deterministic testing.
     """
+    seed1, seed2 = random_seed.get_seed(seed)
+    # If op level seed is not set, use whatever graph level seed is returned
+    numpy.random.seed(seed1 if seed is None else seed2)
     dtype = dtypes.as_dtype(dtype).base_dtype
     if dtype not in (dtypes.uint8, dtypes.float32):
       raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
@@ -208,11 +213,13 @@ def read_data_sets(train_dir,
                    one_hot=False,
                    dtype=dtypes.float32,
                    reshape=True,
-                   validation_size=5000):
+                   validation_size=5000,
+                   seed=None):
   if fake_data:
 
     def fake():
-      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
+      return DataSet(
+          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
 
     train = fake()
     validation = fake()
@@ -254,12 +261,16 @@ def read_data_sets(train_dir,
   train_images = train_images[validation_size:]
   train_labels = train_labels[validation_size:]
 
-  train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
-  validation = DataSet(validation_images,
-                       validation_labels,
-                       dtype=dtype,
-                       reshape=reshape)
-  test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
+  train = DataSet(
+      train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
+  validation = DataSet(
+      validation_images,
+      validation_labels,
+      dtype=dtype,
+      reshape=reshape,
+      seed=seed)
+  test = DataSet(
+      test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
 
   return base.Datasets(train=train, validation=validation, test=test)
 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index ce2eb4e0523..6e10fdb9776 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -52,7 +52,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables as variables_lib
@@ -63,7 +62,6 @@ from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import input as input_lib
 from tensorflow.python.training import monitored_session
-from tensorflow.python.training import queue_runner_impl
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.util import compat
@@ -82,18 +80,6 @@ def boston_input_fn(num_epochs=None):
   return features, labels
 
 
-def boston_input_fn_with_queue(num_epochs=None):
-  features, labels = boston_input_fn(num_epochs=num_epochs)
-
-  # Create a minimal queue runner.
-  fake_queue = data_flow_ops.FIFOQueue(30, dtypes.int32)
-  queue_runner = queue_runner_impl.QueueRunner(fake_queue,
-                                               [constant_op.constant(0)])
-  queue_runner_impl.add_queue_runner(queue_runner)
-
-  return features, labels
-
-
 def iris_input_fn():
   iris = base.load_iris()
   features = array_ops.reshape(
diff --git a/tensorflow/contrib/makefile/compile_linux_protobuf.sh b/tensorflow/contrib/makefile/compile_linux_protobuf.sh
index 480fbcc215c..6eb061a3c96 100755
--- a/tensorflow/contrib/makefile/compile_linux_protobuf.sh
+++ b/tensorflow/contrib/makefile/compile_linux_protobuf.sh
@@ -38,7 +38,7 @@ then
   exit 1
 fi
 
-./configure --prefix="${GENDIR}"
+./configure --prefix="${GENDIR}" --with-pic
 if [ $? -ne 0 ]
 then
   echo "./configure command failed."
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
index 9a96d4e8560..3a5cbf604dc 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -68,8 +68,8 @@ class LSTMBlockCellTest(test.TestCase):
         m3 = array_ops.zeros([1, 2])
         g, ((out_m0, out_m1),
             (out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
-                [lstm_ops.LSTMBlockCell(2)] * 2, state_is_tuple=True)(x, (
-                    (m0, m1), (m2, m3)))
+                [lstm_ops.LSTMBlockCell(2) for _ in range(2)],
+                state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
         sess.run([variables.global_variables_initializer()])
         res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
             x.name: np.array([[1., 1.]]),
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 023164d8262..37622af59f6 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -473,7 +473,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
     if probability_fn is None:
       probability_fn = nn_ops.softmax
     else:
-      if not callable(cell_input_fn):
+      if not callable(probability_fn):
         raise TypeError(
             "probability_fn must be callable, saw type: %s"
             % type(probability_fn).__name__)
diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md
index dae50e67c5f..c8842dd57b1 100644
--- a/tensorflow/contrib/slim/README.md
+++ b/tensorflow/contrib/slim/README.md
@@ -447,7 +447,7 @@ vgg = tf.contrib.slim.nets.vgg
 images, labels = ...
 
 # Create the model.
-predictions = vgg.vgg16(images)
+predictions = vgg.vgg_16(images)
 
 # Define the loss functions and get the total loss.
 loss = slim.losses.softmax_cross_entropy(predictions, labels)
diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD
new file mode 100644
index 00000000000..e747fa4c9e4
--- /dev/null
+++ b/tensorflow/contrib/verbs/BUILD
@@ -0,0 +1,168 @@
+# Description:
+#   Verbs RDMA communication interfaces and implementations for TensorFlow.
+
+package(default_visibility = [
+    "//tensorflow:__subpackages__",
+])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+filegroup(
+    name = "c_srcs",
+    data = glob([
+        "**/*.cc",
+        "**/*.h",
+    ]),
+)
+
+# For platform specific build config
+load(
+    "//tensorflow/core:platform/default/build_config.bzl",
+    "tf_proto_library_cc",
+)
+
+tf_proto_library_cc(
+    name = "verbs_service_proto",
+    srcs = ["verbs_service.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    visibility = [
+        "//tensorflow:__subpackages__",
+    ],
+)
+
+cc_library(
+    name = "verbs_util",
+    srcs = ["verbs_util.cc"],
+    hdrs = ["verbs_util.h"],
+    deps = [
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_runtime",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
+cc_library(
+    name = "grpc_verbs_service",
+    srcs = ["grpc_verbs_service.cc"],
+    hdrs = ["grpc_verbs_service.h"],
+    deps = [
+        ":grpc_verbs_service_impl",
+        ":rdma_mgr",
+        ":verbs_service_proto_cc",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/distributed_runtime:session_mgr",
+        "//tensorflow/core/distributed_runtime:worker_env",
+        "//tensorflow/core/distributed_runtime/rpc:async_service_interface",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_call",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+        "@grpc//:grpc++_unsecure",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "grpc_verbs_service_impl",
+    srcs = ["grpc_verbs_service_impl.cc"],
+    hdrs = ["grpc_verbs_service_impl.h"],
+    deps = [
+        ":verbs_service_proto_cc",
+        "@grpc//:grpc++_unsecure",
+    ],
+)
+
+cc_library(
+    name = "grpc_verbs_client",
+    srcs = ["grpc_verbs_client.cc"],
+    hdrs = ["grpc_verbs_client.h"],
+    deps = [
+        ":grpc_verbs_service_impl",
+        ":verbs_service_proto_cc",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/distributed_runtime:call_options",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "rdma_rendezvous_mgr",
+    srcs = ["rdma_rendezvous_mgr.cc"],
+    hdrs = ["rdma_rendezvous_mgr.h"],
+    deps = [
+        ":rdma_mgr",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+        "//tensorflow/core/distributed_runtime:worker_env",
+    ],
+)
+
+cc_library(
+    name = "rdma_mgr",
+    srcs = ["rdma_mgr.cc"],
+    hdrs = ["rdma_mgr.h"],
+    deps = [
+        ":grpc_verbs_client",
+        ":rdma",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/distributed_runtime:worker_env",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+    ],
+)
+
+cc_library(
+    name = "rdma",
+    srcs = ["rdma.cc"],
+    hdrs = ["rdma.h"],
+    linkopts = select({
+        "//tensorflow:with_verbs_support": ["-libverbs"],
+        "//conditions:default": [],
+    }),
+    deps = [
+        ":verbs_util",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:gpu_runtime",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
+        "//tensorflow/core/distributed_runtime:session_mgr",
+        "//tensorflow/core/distributed_runtime:worker_env",
+    ],
+)
+
+cc_library(
+    name = "verbs_server_lib",
+    srcs = ["verbs_server_lib.cc"],
+    hdrs = ["verbs_server_lib.h"],
+    linkstatic = 1,  # Seems to be needed since alwayslink is broken in bazel
+    deps = [
+        ":grpc_verbs_service",
+        ":rdma_mgr",
+        ":rdma_rendezvous_mgr",
+        "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md
new file mode 100644
index 00000000000..37a543dda8d
--- /dev/null
+++ b/tensorflow/contrib/verbs/README.md
@@ -0,0 +1,77 @@
+## How to compile and use Rdma-enabled tensorflow
+1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based Rdma support, answer yes to this question:
+
+    ```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```
+
+2. To turn on Rdma connection, add the protocol "grpc+verbs" in server definition:
+
+    ```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
+
+## Overview
+The design is based on Tensorflow r1.0. An Rdma path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the Rdma path, exchanging computation graphs, etc.
+
+During the server setup, an Rdma manager is created to manage low-level Rdma components such as Rdma channel and Rdma adapter, an Rdma rendezvous manager is created to oversee send/recv operations between servers. Following the distributed Tensorflow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer.
+
+Tensorflow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for Rdma operations where pinned memory is required. Two remedies are possible, either the memory is pinned, transfer, then unpinned for each and every tensor to be transferred, or a buffer is pre-allocated and pinned for each tensor. The former incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. The latter incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. The second approach is adopted in this design. Each Rdma channel, representing a Rdma connection to a peer, contains a table of pinned buffers for all the seen tensors that requires transfer. It is assumed that the tensor size rarely changes across different steps. So only one buffer is created for the same tensor across all the steps. In the rare case when the tensor size does increases, the old buffer is discarded and new buffer of larger size is created and pinned.
+
+When a tensor is prepared fro transfer, it is first converted to TensorProto, then the proto is serialized to byte array and copied to the pinned buffer. The content of the buffer is transferred to the remote node via Rdma write. On the remote side, the process is reversed. This is illustrated in the diagram below. The conversion of TensorProto is introduced to simplify transfer of string-tensors. Also since the TensorProto lives in host memory, even if the origin tensor lives in the device, the pinned buffers are all allocated in the host memory.
+![Tensorflow Rdma path](./design_diagram.png)
+
+The following improvements can be made in the future. First, conversion to TensorProto and serialization can be avoided for numeric (float/int) tensors since their internal buffer can be access directly as byte array. Second, the pinned buffer may be allocated on device if the tensor is located in the device. This avoids extra device-to-host copy at the expense of extra device memory consumption.
+## Design details
+
+### Rdma components
+
+* **Rdma adapter:** The base for Rdma communications. It may contain multiple channels and buffers.  It is responsible for handling various incoming Rdma messages.
+* **Rdma channel:** Responsible for Rdma connection to a particular node. It manages multiple buffers. A channel has a callback table which stores all the callbacks for the requested tensors.
+* **Rdma buffer:** Responsible for sending or receiving data. It has a fixed size memory to store the data. It has a queue to store the pending jobs. There are three types of buffers, message buffer, ACK buffer and tensor buffer. A channel has two message buffers, two ack buffers and many tensor buffers.
+* **Rdma manager:** Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
+* **Rdma rendezvous manager:** manages multiple rdma rendezvous. 
+* **Rdma rendezvous:** a derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
+
+### The SEND operation
+
+In tensorflow, when rendezvous sends a tensor, it merely puts a tensor in a local table in the corresponding rendezvous. If the tensor has been requested, a callback exists in the table. "send" will activate the callback, which tries to send the tensor across the node.
+
+
+### The RECV operation
+
+When a tensor is requested, rendezvous' recv function is called. The function first places a callback in the channel's callback table, which will be activated once the tensor is sent from the source. In the next step, a message is sent to notify the source of the requested tensor. Once the source receives the message, it will check locally for the tensor, if not found, a callback is placed in the table, otherwise, the tensor id will be placed at corresponding Rdma buffer's job queue for future transmission. When a tensor is scheduled to be transmitted, the Rdma buffer needs to have the memory allocated and initialized (registered with the remote buffer info). If the memory is not ready, the transmission is deferred, a message is sent to the destination to establish the memory first. The other case a transimssion can be deferred is when the buffer is still being used by an on-going transmission.
+
+### Three types of Rdma buffers
+
+* **Message buffer:** responsible for sending message only.
+* **Ack buffer:** once a message is sent, the recipient needs to send an ack via the ack buffer to free up the message buffer. An ack buffer is exclusively for its coupled message buffer.
+* **Tensor buffer:** responsible for sending tensors. The recipient needs to send back a message to free up the sending buffer.
+
+### Rdma packet format
+
+|type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|data_type|tensor_shape|tensor_bytes|tensor_buffer|
+
+### Six types of Rdma messages
+* RDMA_MESSAGE_ACK
+* RDMA_MESSAGE_BUFFER_IDLE
+* RDMA_MESSAGE_BUFFER_REQUEST
+* RDMA_MESSAGE_BUFFER_RESPONSE
+* RDMA_MESSAGE_TENSOR_REQUEST
+* RDMA_MESSAGE_TENSOR_WRITE
+
+### Actions upon receiving Rdma messages
+* RDMA_MESSAGE_ACK
+  * sender: mark local ack buffer idle.
+  * receiver: mark remote message buffer idle, send next item.
+* RDMA_MESSAGE_BUFFER_IDLE
+  * sender: mark local message buffer idle, send next item.
+  * receiver: send ack, set remote tensor buffer idle, send next item.
+* RDMA_MESSAGE_BUFFER_REQUEST
+  * sender: mark local message buffer idle, send next item.
+  * receiver: send ack, find or create tensor buffer, send BUFFER_RESPONSE.
+* RDMA_MESSAGE_BUFFER_RESPONSE
+  * sender: mark local message buffer idle, send next item.
+  * receiver: send ack, set remote buffer info, set local and remote buffer idle, send next item.
+* RDMA_MESSAGE_TENSOR_REQUEST
+  * sender: mark local message buffer idle, send next item.
+  * receiver: send ack, find or create tensor buffer, enqueue tensor id, send next item.
+* RDMA_MESSAGE_TENSOR_WRITE
+  * sender: mark local message buffer idle, send next item.
+  * receiver: run callback.
diff --git a/tensorflow/contrib/verbs/design_diagram.png b/tensorflow/contrib/verbs/design_diagram.png
new file mode 100644
index 0000000000000000000000000000000000000000..f0ad27455fa72bbdd8018bd3977378d2aee468e7
GIT binary patch
literal 13625
zcmeI2c|4T+`~QcII*G_qIra=gNJiNirD7sNr^%i@H3%Wg7;Uy>84R+_u~!(`V;B|L
zrR<F5lzpA7gE8iJ-_H4b&gc9-pWo;4`28NgKfnItzTNlxzTVgSx}MkTx+1O{>2b3E
z!VZBzIFb67Od*g1Vh{-Ht)CA9e|hQq_&x-39)i4d;l{(^rBUpw3?ByKN2<_yuX*y`
z?Y*c>U8gzy`p7eR(#PXDt|^%wmlr6x|C^(nlCIOKyD9dH7SfWgcar5w;pcvlIWD0W
z`s`8W>h@1Nm&set_GV+cL<eY^3dOV)3~`jE=0Z`Ud{&D&RI5hA*dnNN8#Ok?v<)&b
z@GP*B|M~eJ3;bVO;N()WQ!}%^1o5KsErqGRlk6wN*&D4MFvRph-_euPT6NPq08uL1
zZpSh#Eq^aEP}V_zX*>vl3cQ3KD7nBZ2#IyPgNxGg*||kuiS=E(WvSlkg^H@+EUWcU
z@uR$U`{f?ji;ByAF6|_Yji#fpo?lO-L-R1t{=e$rGKG_&6It#W3k7&6Q^8@WD_+ho
zP@#mt=@(P6RM$23%ROpz#rZxLK@vmQFdO1uC3~y!;59S^f(-SA9El|>)pQ|lVk&bC
zc4L<K#SCgLVLeqCjmv9$^x=rnN@IfKc)u^3Zk@45RE20%g+@V_=u%{&>3zASev<u<
zu|%BdeIW`a|NN5()CD}(wjufppBO}1=Nz&rSWc(^0YRa7Z05C<LR1Cn5Y<{(hOK31
zd2OJko+(hP5;UJU&zK?kn{jLQbds`3St}#q7|RiP%gBUO()Won;E^t4v*hOADEa8M
z@zK3A#pLk=qYDV1PuKdH52|*f-S@g8IbRM!LRc@wJ9Y%Kcm?(ak68O}8J)L{jDFF_
zAA<^DYhm7a_HoJj+x5sKOH!7^nFaO`FRtcpcX~GQY0lZf>venRDh6|o=qpE!eDUd5
zmeng&W1s3+9lyxFaR)xMbWv@zCi=MGxe^u#Tewu~AcC#M+j^y>^5smUdqzNb_lr8#
z9P%Lt49c;W>bG7Qcm)+TGEzj(I!5>Nu1c4*HZj{2f7VdrKH7C-jrs8GN=8g8d#u%E
z|Ddxgdr4Mo)fe}Dfluru)WP>i?C``We++IBi*C_e2xM!CMzTlmV3yk-7Fl8|%UmN_
zCwJY)+H0~Lu3OhVd_@Tpq?JPavn^eihi~88-&7rzEF^r>mIe3CfT{J68vOjj?+kHc
zmm6zAE(%9tMc0Y9<kGVs5q>1M(P3G`y>5P3KGlf@M@K9#yDi*%@=c}lIp<5(pY}gB
zIhPemJXy<Zd}XRy`^B2GI_ot1W;vvqBr9Y@Wq-M3?WttWVmhr>PusS%&OA5xgmp5V
z_^>T8nM0{vOO!wz!{%F#o9=t*^$@_phdg&~uT%XeGh$Oo1J7${GEn$~(<)DrQX2+t
ztX<m(Mxr_vXNfw6OL^AtR&D2uu+X54OjK$B!R@5s?xV&%YDnQiCFXL)az3d`WQqOQ
zpHE(NwG3{1Jq68O3d3VPkO%FbNGU_!-j3(c%B9!qnc_Y?isE!hC)Es*9g+sclp*g&
zq;-62cb(UE$M-61jXM<T_T1MN5PO)}@)Ncg()K73yS!1eh7s8r@@V2M-R-0YooT(T
z-4tqQ={s`R?6~+5EoH~;#>eB!yLxnO?3y;=IVVIPzBz%Uu67mOBd_*(EDt~>4mZ7G
zjCe6Cr_=gQEc&c?_A0i&?9r}yGZfTZ*GUuE_V;edy6nca5|a^d6ZfH3UT2xIIhb&{
z<zh6aOb(Brkwvf<Jw149tEcl2qWqn?=Wtu)&&R+}vDh4M!-F@`iuRUo)jt&n&8{`h
z@*Mhg;H2}1S_Pd|zc;5GGn+$m^+ft?>|{#pmZng^`p#Oi=^It5vDnsb_ADW3u=*{=
zR{0^3VEfubXJK1bGVhjmQeVc{uDAMD!!^YtF651uZN&L%As4Yb1qj2F>YfXuT+tp9
zT2UL?lMI;lMv8AW+WQSuVxdht$aCVwgsxPZ-|Mhv35>42+~&=<4Ka6SWENylx>pru
zlRUNpNMi-{Z|yE-;kVbQILt`A&3NlxYV{n7B3g8&1}1y)%BN7FW|R$yTs>_dm<l6M
z)LhZ;zTxGUJn{9CDGDWa3)z$~B5a9xNYp*4-E*M+mlWEzX1k=3A-^EclTxBk`^J|x
zOay0y_Y_HTy7w6>f^(=jmF$xQx^l|{qk!$jj<PN{<8`%^;H@vZ&Ad+ok}n&zA)UQU
zKeuWeeqhv;dq}Z8_q}w)6{A@V+ra$v4BF6@Jjn|!0j1HVL5@REai_X?NdW=nL`Iu(
z*%(KvG^N$c$Lsng!lA)TH8-gYsyj5mQGeAG?<mtC-xd(B@5pSgSsg7)FO>MyWe62G
zpu)I&=(iWA!B0&r`0O4g8$9uuTG@QboG2;?ck7v^yrk6ERjBPX#*{Z-jo?i29eeql
z;-Jtzr7G8qo^DqpFxpBIF|xvZPD4Q%Td@xJO9x|By}n5%RW|#nD|s0vUA$v!uWOUE
z;heJQFt^p>*G03gvWwq0Kn9<fPmE*0FFmO$a;+`M%8sO*-#GlG+4R+hW<1os&HOZu
z*oGmI7jOgB-(wqd4cn2yrBDeNvoyUNDv9KiGJv;fOPRb1yrfjO9=bVT8KBYb*Hzzc
zc(e>S*oVxH!A~pqqTua`;dUjN+S;at>khgys;gB(`PHZ5jy?-O@St!%W_;KmsVf`$
z_>^mhuBlMXG?xs}Evt;d*n)S}AK{iI7g$)|KbF?H@*;N`Cg7xh5lh}XT&w-f#lprV
zz_pqcmjhj!!XSC_ls;=9kqq;(B+Li}DpBm+<60Lp=YP%f2AQ5=wGFJtyXfZ>BeZ_=
z2Xtd-howcR_SB$kNv4&IR&0ql-Hhrm-N_#}6v@fX4LN>^#gw)z87B=J2*$<5?`B8|
z2#BUnY8TUsMvlbUnLAb6YBL9r@lhO8uMh?C1i7TCg^ad2sDQwDx-6W*kd4_`#QL#?
zD~no0Kl0MNrlOpX6(nV^8;V&*qLbf1=9)um6&kX3O2l=`^8^LGE<0nJ4zm3WIg$mD
z(M`sop~t$oBU7%z;ghGd!hKG~DGubew7w<Q%)k#f+>Z1do0=qe4kB;2ahFC-oYBgw
z_B=;Fmk!P5LrPwdnIIdiW<kDATN<nqWaFmts3s+uWOZ{<UZrD#UQtrOeyEl1@o>Q>
zWx0Z1h526N7i@bm6T#`vY*bRHq%JHiBGJ;9O&M>WKS1Z4H3@<I1{n%tg~G<uZYb@F
z!+PO1fl|0A_=rTS4laUE1bMKnrxJH|()lY?1*y3jaQ2fbbv?>?9H-M~oO`-KvNXLg
zU$f{Q5D?^PfjJ?_t~>6vovV^5CLwp$s<)EmlqW?BR=FdDvXz7dy}bSHb^DVhDx*2)
z?zEI-$O;UF%>i2#9|$k)^x{Hcf?PBt#|%}J9WmRqw(+3i#dukeC^kzVt*SrPk8X-S
zGG`8Z8ok>nW*w;4Ov}kx7ME*~oc|_hj;f3k>Dwzl;zAw?%E0?v+ql>4Dswkf%jg_R
z*Q-Ku&2hYak494@(6B68D4kODp71<|y1#>j$e31pU82}za*L;>DDMU}#4%$B)v{J&
z<=^6Y6ERGJ?&ikIpkz(M6DBG%RMw_$8MoJ5VVu1^<m|ubflYNHrSyx=_l-9Pjm5~D
z`{A!~6wvCJqx|5)^UOmQsiJ(A(r%?pXC&*P;QWiJQ~0r+*O6Q$-+tfs#Ayr*PW-39
zo+tT8xayXJZi=$`inww(C*-Z|5$<ca54&e{@=j|v9t0qSPYlO_Jg9Ut#PHXka8COJ
zzzd)PI1Xp1@?`6EGk#I~CxZ|OJ2x^EWJ@&nA8^U>2OQwUTVox?faiiYg+O(I-~T|-
z<M(lj1HrrQeN63Q^U}$;YIKSF+cZ3IZ-<qJ{ahRgQ#tvz;S=!wkSFI`lMTRlH5|Xr
zS%r2bOqOAq&uHc2IF3I=pN!i2<O{qOqJt|LaGSYDJ(pzdbFIv)Z)l?aW6*eIGoLt)
zgHMc|8~6}pQ%5DAkV^)}y=-%}AQi~HfgjIZ4u>9K3unCqK0(o8DZ<4DeI%+lXd_i(
z#^k9!{CbGA&OSLSb(`LL0Q=&+Et9{jtVKwi12Vh?j)k3@E&Q)9;}+b%7ql1f!9eQ&
z$0z^W1u?)b`f#Yg-!1d6w>4OKz~mAByT?ZKgv-!XbK8Z62$u~u!2bGha7!9P>pUz)
zN5_PGw~lq<=~XY%X7r4;D92B2%kICI<rHA*!Td^0Gja23i_gG#Z|(XH3Ge{G5TUG>
zaE}JhlH8ixJt}nz&(uzKPk%y<-vsvepHClhtKRXPs#r+n&o4qR5AB41*WI^~QW5#}
zgR`xjfc<aXD#<MG^c$hbX)OQrDQw>k@xT_|0TzklxCXr9w+O(a>f@uP^3!T4U!>eN
zE%xdpfsY0d`u*c;@bv_dt<<?c{{Ywu{AX3?PX__$_?vP5dgxzXrt_z7{q-`?&meyo
z7BuAkj%SpQFYUO1<aK^3NP)cLtKq@iFsu+(s8=gufUaq{yI7^c{JI#llz!mlD`OQY
z&`O=f%q`le4VB8+cW0sZtdjs#owMDZSIJNFXlGD{{prQDN=dt#)!~f^sseMQ#&s!-
z^`{cM?=H=Rt+?q-0%JPlxLD7N!!oLEey?n)1|$Z>GL^75&!u!cb>B%2werXAG-~tK
zMKLBD__-THt1UyIBAP25&jTNf<n!7~ZIbr(dPev6tH~y9eL3LAx9N`tXWS^$8q02B
zN(rurfcYV3fVNab>u2LTfukWn01C3_f*$Ck&hn0=IWEoJ#)lVUY34Kj_s+PJgBk`X
z_Sv=}2~(@;SnrK2cc)>;1$tw?c296u2Z3=_!FedAJg`!`+X_Us((!j|q=|7mdbIVD
zI+e0{OGAEPXaITz_rWTX6Zw6>ZLn8F1MokxB@500*5mbTB`{_;sCz?ed*t)<O9eOd
z%y{{B<>7P)a@xDY*<x>6RlbMEJ9zc^cojt-t=~<kcM8~Cm{&mGNxIe{qgyo33Z<Mz
zX_YbkM?K7n6ci^wq(gM6@6h8OnviY3Ux`4tnllv$%!RG>Mm}I~*;>N!@hix7V|Qj=
zk*b?hs;1*GmngPIOMC{JHh}E*_L7G&xkLl*c=uNqP@bNG%>?{>)98l`U#V4p;IJO5
zf^bdwCs`oHTd+Svlla7tMmr1btz;;2I|KQP>8oDdx9&Ur)}1EJa`)n&ms$xB6SvD=
z1Qi<^z;V=a1(Y#Vpd)uM)?uP%!Q(EI*|oCqflFRQL5JrMp5S#i)o|uyd+pr)la)jv
zC4(5vH`N=<hbRu`*s%R5b=guhdwQWpvs=IA@_FRf_;>NDxQzit@yo_T?)|AokL5Im
zxC>2H6Y$%-`FR$$>bk4Kt5lPIlTC?Fzeeau9+$!=e2zkEP+H(JuU4ta+kiXqiRoJK
zFw*o56mAEsbhl{<=R%~BP2$gwSFGWAgHx)oT>3ot01KnGhRA5Ui9&Bv<5e$!;0<RZ
zge%B<+=Vp|<V6$ONZ$3fT4W@SV~XCZG76NQNd-WlnKlIy_eC|b0kF}hK0PEFCd|oP
z>NZzITO$XRumYtC69v%UBg7dFK~rzanv2imKw4Kc(c#W>Ctnwra>%M*89jj8a3-{$
z?P~2G5PkeBRnOjJPDe)Z?vN)e%66_reZGJw^L~ex3B(F%oi{lWy0752M!j>na<GN8
zsQC9Y9<Ini@5ds(AB&H~YF>k(_)3>cGv<zq$%u8Q-r{0QMPstrW1$zGxF;j)_Dmgf
zoUE194%+NH>QtmNlgaclO!wcc2GAZq;6xI@n`;*rn@VXr3M2dDN0KHKQNnGzuKu?6
zmzl{04j1iph;MQrEt_Q{501q09g5(cd+_Q=n8`wFA*6LYK=r1;jyYKu@tGr#I~3Fb
z99v!x)ILn-ScS??I;O?@RDU<I2+)Jc>UiP0N}c3<BGcw<SXx&FcuZoB5$=L3l}qZP
zAJ*44FfL;KiLGUlO+PYxZ^afzhd0XTCe0txNeCpi-EZ<jps-p_uO=MX^h57WvAgxB
zU}R8d^H1xP8qeLC<$CfGjWOdFi+HwC(}wW`&=M*Tn9S{x=BL#w)9`cK+2`wbq~fG|
zRv)%O`DJcO0n4e#GTLpAee;?<Hmcny#?kt0LV;*}QSqXIz|~&~yJm(?hJNWIj)LZ4
ztWTeb2(B^%t5a$?^iDIGF5xEW4_w|k&!x>)$(fDw#yOPmo%NT#RF}Tn$z90D0xsk7
zhIY<`-5XY@?s@Mh7ZQV!hQ>JAws98$D6TtSP#S?0Al?jpC8GnDlOK=@uukYNQyI@W
z%%ycXP47%je}Ntm@QQPFD%`xG`?gyn*~$CV=T#ZUw$2Q|U>xHuz_EN4hPhTjhO!}5
zyK&sTFeab;6x}H$C=7R7TIbemiFmV*8gaFXeCRaIXbKaZ+L81*%*6O(n_}`$TFuV=
ziS{`U(2f@F!=F8^E8`u=<inU7EVv2PW^|~nbBkGjJY(7Kac{Ab@n(g{Y9hWM12BzT
z>chD2xnnU_MXi2q{NLd@atpHv6GIl(5U9gTSRIsApfnaLrWMF&JcYzyhu_+Y5V41H
z<dVd>1@I<ai20;1blr3F;-Hl|7}5=8JM{-lvGqG&Q?_fF7}`eCZ;}EX6@#%8r%<Ba
zLenP;V$y{SretoJF>H*h-y}EIk^Sq+gYd2l*82L+J*r2-R2$b%)0oYIy``8P9|4;o
z4=2IwepE3`_ymLTs@|4)+O68NK6vmlLOT*XN#_uk!Uum(<`;zA&Wn#)O-QQl*1q2&
z?tK?amuHTFRHQTj?=8iqC4Nf+`6%ScSpc-;RT^QBy5gPSM@fhdJ8@c?I^V69bY6+*
z!|A+0T4#@~yC<NLE_l=Xz&!z%49}TJZsb9A*cE;;b=ZDd3+yw!KiLZRZ!^L?d3OrH
z4Ax65r<8%Toct9J>&SUOiy)dF1JHye3$iaRbUZY$Q%s)B>)x5}F+qe#0nl<i<lpjm
zD1eTS0HFGB7eHbFu;}1FAoOnpu@9cfx=T4zuo%Q+`YQ+g(mut_6q0rJmC1b~dP{Qy
z>i*P*qll~_KuL*P+ns&wLb(XV?js4vz2Thn-b(6VpUvTFP{2sC!J+_e20rh4h(4TK
z(79^q;dZBQc;B-mtKr(MdCK2;OCE(F*7{<~zLzI<YFUb!{J~q!d3`F0N2RK_Z{S#t
z$SNmPZbt{EIyI;Nw)MZk^w)}gj!I<vPaG9!4z3Q=CjY;4RLHXwM>OU^;M%0O#*Evk
zAMG>DVRSV%6##xauB7d+W?jk@zr~kTCEq}l8Y$a2w$EOB+H=C#?|Lu8f;UtBQys1L
zs(su*lG`sLgnoxgpb&Tw{~O+;8!8nj_SK;^EFHD>LZhA%G*vK;4m@NVz!+WG_LdZ{
z3su1^5Pf%ygIDM)t|<y@Gi!!jK@2yiVGw%8Zw~Wz)ecsIYKFBtPy&=wwl_qXW9Z=^
zOAV|IEzn72+|$Dw(32_&PAPJP4VUpX;kLe*vd`9|m1yi>@bbg$YXGn}oU_+eQ}$jA
zqQ()7lMw`_`v8LC0B?4Sp0jJl=DCmqEWYdGTNSO1WlO$+x=_qQDWfCaL8U089;hZ$
zA6}29(JQt3-8I^sg|qE-u?Pd9KV1BGAmB}cPPN3=5ux3>1wzXi()`QKpMpQnA9-1k
zWUFPzXg5}?*k0wd^&r1nX_|f_YN2WKbdr(cLZF<fa&O3+ZU_@m&k%w}v<AeVob<^{
z%;-y7RFaQ)4{Bq|b}ha}^w~Gc;;ogXF}XjtlJi_PzgH04iF+SIl{0Tu&o55bo<C7V
z`K(?VgA`d!FU=^2b1Y1pz?7yV4mTwZ5Iv&Rorj*X>H(+hZ6La3C1GM7WX+f%A>9Uh
z-DrjoDooDeoXyjO?X@Z)Z%O6JGQdH<h|H*Ely&l$!Q~g^n*D?Y0pZGWV<i>Sxhv`J
z1g7Hmt?D<^Dog!W`l`a3LK~wEyK4Q@8S{QQ`-O({!{p1j5a_Iv1&+S{IM1b|5K$NB
z6o9<4s{To@`jiPzGW4L@BbdPJ;(*CFSEhmVzejr&cq_i#YAyH01oBE<jEp+&TrcYG
z)sJ6AomU!~3GVgAJh5<($f819u)qPt9*0+Fj+gH3*c5bFMt=~BnTot*p?;j_gz3}X
zXICxot-C(xS20Uy80y^ogvv90wSF_l-V*WgOgTb*ghm7c_;#`}rZ4+}*B=X3Y;9@P
z(0V-8n1&Z=f^s-9HaHhjY}jM@q>xhG2U6@a`G8tHfM>x?>YFtEB_Q7PmZrmRWbrA@
z^(}fR$sfLm74pl=SDl9mAl>JoZwfWW&B1kQl!y!SMsQZ+;x6X#2?4${JxUpj-7w}C
zG`$5#Tamb6v_4Bb<~Ueeci7|Nc>Rrr{cWq@jp7T%VX_;~A8~xUdNJ0f6?$y$b;m05
z#=P6!ItrT{>$o2O6rI1kA=mIB)J^G=5q29YE^twC6q;@IA`Uo-9Y*LD{uquAhRr(B
z3{af4U~9QiV&|a_16nHgxg`av5!q=S?c9ZF2@B%4vT;$+<1%t><?B{?s?rT|^RDK3
zBF>6{z!w`HU_Df2hrEJKzASL`W~g~DQolciZp={zl%QpUv4OSoKo+))PWQ4hJ}&6B
z+mF)gz7Xd?Yz*D|(SchxhAFRy-QI9$c>68IDI*wtR+ASf)`QjOrk<wp1_waB*|N!B
z%SYau8~qMQqT8uYB~hYnotctqt>v|%KG}0`0a^I!*EpzHaeL339pn{EwyU-=E+HqI
zPe9-l&Mj$WcIrw`$HEn}#?Y_wp5mW^mYdrYaTC(bhCFmQw16Y*rJOQUBq8kjq-w{C
zt@?SdsF^LI=yenJD3cQLg*MOhTR6<=W)Fn2<9KNd<Is(STR3|d<xI(k1E)cr$&a@&
zdMnU$J@VN)1h40IG?2q%R5ux#MSgTHLf1ddnc0}`ucvjl3@Wif(|?46jjoG^4UIc9
zO;W2*<cdhxh>$GD-&u%Mq);bJQo}j3v5iW|05$Cey$7&N>iBa-bF_4`OaB1Vv!PYw
zs}>+_y)m>$G8Kmb=VXL0?-tP%i|H|o@-efYXGFeohCnAE>FY-?Ezt5rI_NS$KK)Zw
z@UC^b)i`$aIS$zPeNki37<r{h&JQNzhA7|bse=jt%m9jor(3^F>wW)zU1#R$=fhy7
z6`lvV@Mi<jyL#$ty?>B9^R+5a|0kD-$FvHs6Qw4V{?uT)<YT^eoPDMlpj7`U^Jmb<
zIOxO!sE~a&pm(oY?d`t19sHwUS42NpLg#NFaL3?7ob64U!0kTXU~a>|Uusw2%YWZe
z%n|ca?{63M|AL2+<TI_nQvXnL{wX_wLd3so2KyPvTe0=z<^gN(>KCfjpXthdF>jDS
z1_yP6kSC{Fla1>_j#MXkt25Ju+5gWf&3{*U{*(OEfUb<k6EydvhqI&2D=F(iH}eQ(
zJFI}6p#p%42IQf!s-hTokE0s9;m&~6aOLw_c7BH2k+@xe<bdWY1T^ZR8HMbC3Y44#
zt4|vsPqLx@Rl8w~fc++o#{C#>6ST3!21<ATly(k$@!FVZyEmXBYh*bDr0X~~vp-3s
z`ahG%+uMI~42Wlcbdrs_|70yl2wS3~9Ac_>G}O5B8J+=Zc>DJ4D1kZ<H!D9UViEt4
z!~QPL{2!sN$uI+o1sYLH3MUNqHL{K1-N_U69vMtZ{hzu{M+S6_Hh|qF{RM{laAz7c
zMdb5toy&s6gjPD(RlyK!&jx=_Ing}RHN(YNR9Oo0)7rWZvVd|uWuxoxLPuW$(6vE%
z14`lW_JZrs&HfahfgX_Xh_HpA-@m~?s5<0{-WroYcvOzm7i~|iVMpbaCq1GW0V&<9
z2D9ewATk0V==ZU_F~qRFChSva<Fm%&z=lp=z^#0Gp;C!Oti0|jhta_Xp7(BqsOz|`
zALH860v~4zLMt<c26l>w+ucpmU_w?~{2|fwJQiqQWvVTV^mU@hJDp(i1~*GB<Bp?-
zIjY|u?7|cqbj?o@EC&j$e0;q@L1Q5SFuhO^Pzw1+&k0c$l+kyAtZoh>m4>+Tq$+r{
z+@f06c?98DeCG<m)m2hCXgjpHYRx%=N2SdR=mDwJQ{x5!Tz{{r58r8!17S6=TVX9H
zX)X5h^WnftS8*n^OBv<6lbdaqZD0}l3NALVYW>o<OQ%pK+FO~Q+&s<j2BzMCv=9Xo
z=s}*@x851$oG0SVZYeXZGrpD0KEi@+{2f9}Al>np0t)!61lPIuEnYHc9@}|4n-j&1
z2c)7J?=0*mK#~pdPjG~~>MO?bVqHbG*T4AQxrBs+{=DLoR<`rlclMGE?6v+yT*1jX
z0r|S*lnk4W0E$3foLeGP;!`A`$9?<#1!pyT+;=gQszE#d7j@;{2IBnS_3!$`=#jmp
zq6x0LlR%$B;iXHpRRq9lw?$wA0>31VHA4^jv?s>m9D5gD#YcG&^g?gs48RZT1t68`
zu+Jp4a))jr;kXWv_>V3R#u|mz0k83TX?Jt(q5e1Eqo#ooi`y)~v+U!bnH}*Wok93H
zAahlp#%<4wd|5|Nhpuz9UsW_UZdY!*e-u3PA4CPcs5ddWaFypSEzY4gjgAYp1#PSl
zjz$W)4|R{d%NV?F`)jANSMFudDrt_ujK{0Z;HZ&;l^$D0=2E&L&vl|zf=Juz05zrc
z(<*V(7~Jzh6^GjPaz3a)aGWf3=t`v|aClW8GRKWPV<(hQvR*z)x^gOR<RW3;#K$=!
zhT6ZP^{rZV&CfI_we!cqZS!0#`qSNjQu1t9e3VsY3zDqqrJ{S<6bhIkW>`MvlOdNg
zLO}nq3En`#8TVW+i8U|S&;^&kCz^uy>Foh~Oqrle{e-G_)w=`#SnRSc(I)S#K1{dT
zOEG;gcDu|f)HCsWwiOaNB25eI?E!K`az3)j>fS#ii(0bd<k#fUi2Ftri&bEI0jL^T
zLe%iU^a@bu?H8ng&dl<7UuPC;{qpw;=nrKW*+j0?>@V9}DXnwe$kqYH=C$lZQN(DH
z4^Y$4pEf_*@1~a@uK!qD<p#>BG`B#*e^i_L@bV#JILI;o(B`8u<?Cz6>n+*Cz8fc$
z4fWx{RSSg9Pxla`pc3!azP?)%FzhO7-3Nv&QCs8SAqH^1M~jlYd)3?EV8jh3FTUd)
zwslJakE^CGF&eQV(HX|2RSx1mfm;xa-QW=_@8Qu{y*IB$TsyL$me*eVE??3JIP9D5
z0F<=Qg_!2efU18^M6h(%e1_<L)V%zaIWU(17M6dM9btG@F&Go*Z4Z(6wOC}6T7kkH
z@Jy%jO_F25=DYIOmU2Fxfgsn3ts1y(3VsS#JL!Y3-t}G>ZiFuAd#!@y*l{VO%~kt2
zj6Tomm#$!-L5Eg-FVChE-L(q@5eHH<ry9{L1Q@d5<CzQ6k2AIZP{WzpG>~}9WNe`1
z8!utUkfIjnv+Fy97``ac!z0L()D|=p))-pT@JB;I&yqbHdSnX2^8zXioGR|gaXA>X
z*4W@;tSo;nN(p&kZll8s=!GG#utyhWljqJYARm_iB{Ic6r%*%w*;O1rcz_<c$(;1k
z#M9K85I`K_wb!-RQS<fQC~XXVfS&XL8uh$WnS7GUIkPQ&FVvOM<0S>wC}P_t0N3e%
zErcgZbK}jPobDN~k6nr3yQz|PvsG{p+*8dyDU@Q{j{^19mVhe^k_%~@ZECMZi31cR
z?*9Oqu2kSQqVw*LmHboyqqBsG)%gI@|HAMYV2cAQa)*p{ue@fst|(3@#|6%G7O*kN
zoAJ@Ar002z_H7zEpS&C(Pe(Srl@&u?HGMnDu5Vr5tM%Ld=~<Y+GwcY&^nmE`)xA*5
z%JLFPpuKxW?^CGNwbx-!^KK)pS{b4nm1Ug)fneVyN^KY}cM5q~K;&!vA*f2U)Zx~%
zk>y7es>`oQx2-xA`dTn3;*3BYpgWratIL$RS^9!wPMQ`1J(lwM#Vq*BArAqYU&lr9
z4nV_HEqWo$G)Ll8bAY*avp4a+v#`&%>%a$}(LJLCl@;chOV%%m`({2KI5{)nDTUIR
zzj$T8OB80wvlsQ%s5Dlj9m<l6g2VQp{JAaqD{);ry61OjZtH}{i63({=cG{9ZSn5#
zAlZg@tNwYku%P{zy2!iQ#yjuDF;F6o((p>m(!4+gG9dN1plQN-aTZt+#qP)i0afT#
z{mN84riDKR4$PMUaOB^Quy4?G-{|WMZcluUVfd`e5o8M+@~)j<yt;pbj1KqyWEH@K
z|4@{UvH|4vuY+-a^e%xe*<ZBt*JZ;y+v(4jf6D3nhG#~DA{sEf6!Jv7^~-T1p8t%@
z|M*{Ec_3hdb6v=$i+AexCdN~N=9z}<Lhzmk6QuHvubU)mCK{Z}DJ%>Gl#G35JAOBC
z)}`JRj1o%ZF?pKDJ>Uq}?KfMKzkrRoeaQ!>xQ1-nX9Zv;=$|SJj>F9!R|15am-{^7
z_%#bu7Dm9VZ+tS#2blQS7rWfRrs>-E0*@GD4UjpyC1UPc<~iND9O~vc2KG?EhA9|1
zy;YS9DQ{nE;k^yV@3-R*>z1Ffq0cEhI}V?>9aa?)gf9UI<Mpu|)W!DgU6<rIIbJq$
zXcbIT7{d2cR1{G?P}9FfOBi|jD%B|x4~6-sH%^qt5PuO>A|~;c*abI+UQ6}3r*XVe
zy9QG=7@K(`rE+3QB~__QyXwW^6nt_VkieT(&5YcDmY5xA&CpF7T-;d8P8kQOWd=~9
z?ix=5fny1nIl@p1QU(UBv1X~aCWB)D{tB}@4#a;=<;h5WZ6<(nZZY*c7kI%=+R9+;
zcL}zu>d+mo?0wx$AMPtSW6d60&QUavI4{fGn_Dro!#SS1tgqdEzmQKJJSxuw1x@cm
zm2-Y>RlvL4aUrGw=@Lj8GVxK}?dG;DuLISTtYjVP3|8Zy6F?-40u{52=Fr!~Xy<_H
z)wqG`7{h%tT}yFH@Ej8uux|2##Ywm7@<i5tE-xRsVbK%<Yt_GC%FYMIxZOz3>Y0KO
zQ9NA-Y}syJ-~$8>S_G%0J5<@xX5X(^-R3GfzP#QktL>WE=s~xc-S_HoW#pvtK>?8^
znxpeHN)K1KJ|GVS@nlipbSbx){b*YO=6Hx1(|PNwPiq?^Oz$Urvy{2LU$G6R-32r4
zABXFo-uusk0gjRW+vWdsnEu80H-{`o4_~nyT#^rCt>6m!{Q$GsGt-AD`d@}%-#_-+
zyJa~~V1YoUtGZ$;1i2nF9y%0SPa+OM77=rdKs$WZ%GNFVV(?X{(K>-|`l)D41qB#|
zc>BYuH$B9qC%U9E7!j1@TqsW?@Z~M7mu#jJxgR<e775&Y-Zb3lu%uv3s-~mATpL&I
zcVC*u1{4;}PQM;MF%wk*Bb_X(3ls&WV+U$BJZ81qOXzoK+e&79Y%(_ez9mzM?(~Yq
z2||r4S5$>$9VQ^pf_O$}XPj}b&+ib#xdc+3BYX0U0Gzq%ZV_tSJCS$7=kSNOu0U7g
z)97Z>gZ4tfd(E5EYL)aLM(p=f?byoH;PqHm$6hikRg}q8dPqI(gxlmLJ%+Z@n``j0
zIeva?GY@xYw@78OHm3J$Dg0+;YiP#4YsA?+x5emBFGfp&-@6_>Yn&S`*}YPgJ~ir@
z?$ni>Q}(TznEWm$l@{qf+-ZHxJGg3jX|z@{W0yopT5P6~d=gnF>2=h-{ACg=r0hnB
z+;ZeIgdc1aRRhgjhuC~gpY_HW8BbYSe=L1xIh}Imn15bs<%398M}NOnrgs%tc+@NR
zXY(F6YNmUt56+~rI@=)GW{zKsP_%izo#~ep1>uch=BeNGE@b9wY({$5<kazrQE&A0
z`!%B(!FsfvWun+?EpvXQ4er7&zstHybZfz1Lwf1&>MRBXM|UD~t>Y5yr|<o$XiBD%
zM(Z2CFOvvn*KB*}WXOAt6Jm4xK^3P}@ehy}1Wv8=-~L`Fggn~s1KWfh;te}pd;VcP
z%JU5{HQT*vk*MittNd%vRg;o;lmP^aqFH}h=eXaR^-cLeZ`4T^(g`t!gV!yM5{+nP
zx2&qoNs|x`QuD+nFrB<k(L=9QBiUo8`se}How1)o;T`jtcj=#wjFz{O>0kFyEFnWO
zI$rjBMjkr{Mv<hnv=cG6y1I!v`dNem0p`#X)*KOn^|Y{Z&RrAQP;(8Yj#NeVtJ&_F
zn<mxWO<|7jc{k($BXY|(*P}_2y6jP}$C3(SDnuU?*`AwhYb1m=CVa5r6Jv@j1~c{5
zp<d@FK7XL=Hb$aqwquIei|odQA;1J}4q|M3d3kNB?ozJr6~D*W8ld;5(1N92V1b`B
z4H42f8Eod?!+)}2bK`~k9_~hwXm`k0wwB*ye+s@{XKc?L2Hue5lg{d-KUFeHmwoea
zO^081KGXI&>ErMHb<#Q_)ZwhVgt57Pf1?uxcp4cS@hg_N`I<S(5@If`W2#wu7Amkk
z{5xai>kO)gJkHNj#U&NJVWU1+CM(v!9(w`R*sDwaOzkJ-PzKajULV23r%591>2@~b
zD+IwjIURV=Tpv+V99wXjV>1ULBYcwTOQ=&ye%QCVHd@)fJ0-PoqQ-%l@?o}sE1qAh
z+1GQSV%ejP`Gok&KcDj1_8`@kMP{5)v0R4mqkEE>*C4ky*<-Uw@f$IuOW#+{9G1})
zcc+ed%(91cq-M1Hj2^~Vqt4b|uNzNdb=0kTAxN5>eDQ&f(!%?;exMzq#^^GQ6B9;8
zJL^mCdH7%-%^~3#-=gU|Cx>&Ne63dwUcCK1o{E1Ud`M{}P+L}Nh?zU;dj(Ua<UC6{
z0Fj6c*%=5odpa`AFK#P{JLkA{Dbz*tZ1PrkhtCrlRa`AQ)x%B24dVwF#s9SKVY<{6
zXo7Ow{msI*uGr7Fr~kpY!(1#zxKi%F#K!-*^gkB(f3v{-*lnihE80Ph@WL7J3CPPv
Kmxvc{KmK3LB^B!c

literal 0
HcmV?d00001

diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.cc b/tensorflow/contrib/verbs/grpc_verbs_client.cc
new file mode 100644
index 00000000000..608a9140d3d
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.cc
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
+                                         const GetRemoteAddressRequest* request,
+                                         GetRemoteAddressResponse* response) {
+  ::grpc::ClientContext ctx;
+  ctx.set_fail_fast(false);
+  SetDeadline(&ctx, call_options->GetTimeout());
+  return FromGrpcStatus(stub_->GetRemoteAddress(&ctx, *request, response));
+}
+
+Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
+                                         GetRemoteAddressResponse* response) {
+  CallOptions call_options;
+  call_options.SetTimeout(-1);  // no time out
+  return GetRemoteAddress(&call_options, request, response);
+}
+
+void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
+                                  int64 time_in_ms) {
+  if (time_in_ms > 0) {
+    ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
+  }
+}
+
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tensorflow/contrib/verbs/grpc_verbs_client.h b/tensorflow/contrib/verbs/grpc_verbs_client.h
new file mode 100644
index 00000000000..358977f9254
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_client.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// GrpcVerbsClient is a client that uses gRPC to talk to the Verbs service.
+class GrpcVerbsClient {
+ public:
+  explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
+      : stub_(grpc::VerbsService::NewStub(client_channel)) {}
+  ~GrpcVerbsClient() {}
+
+  Status GetRemoteAddress(CallOptions* call_options,
+                          const GetRemoteAddressRequest* request,
+                          GetRemoteAddressResponse* response);
+  Status GetRemoteAddress(const GetRemoteAddressRequest* request,
+                          GetRemoteAddressResponse* response);
+
+ private:
+  std::unique_ptr<grpc::VerbsService::Stub> stub_;
+
+  void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
+};
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.cc b/tensorflow/contrib/verbs/grpc_verbs_service.cc
new file mode 100644
index 00000000000..e73b2700bd9
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "grpc++/alarm.h"
+#include "grpc++/grpc++.h"
+#include "grpc++/server_builder.h"
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+
+namespace tensorflow {
+
+GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
+                                   ::grpc::ServerBuilder* builder)
+    : is_shutdown_(false), worker_env_(worker_env) {
+  builder->RegisterService(&verbs_service_);
+  cq_ = builder->AddCompletionQueue().release();
+}
+
+GrpcVerbsService::~GrpcVerbsService() {
+  delete shutdown_alarm_;
+  delete cq_;
+}
+
+void GrpcVerbsService::Shutdown() {
+  bool did_shutdown = false;
+  {
+    mutex_lock l(shutdown_mu_);
+    if (!is_shutdown_) {
+      LOG(INFO) << "Shutting down GrpcWorkerService.";
+      is_shutdown_ = true;
+      did_shutdown = true;
+    }
+  }
+  if (did_shutdown) {
+    shutdown_alarm_ =
+        new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
+  }
+}
+
+// This macro creates a new request for the given RPC method name
+// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
+// `this->cq_`.
+//
+// This macro is invoked one or more times for each RPC method to
+// ensure that there are sufficient completion queue entries to
+// handle incoming requests without blocking.
+//
+// The implementation of the request handler for each RPC method
+// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
+// to keep accepting new requests.
+#define ENQUEUE_REQUEST(method, supports_cancel)                             \
+  do {                                                                       \
+    mutex_lock l(shutdown_mu_);                                              \
+    if (!is_shutdown_) {                                                     \
+      Call<GrpcVerbsService, grpc::VerbsService::AsyncService,               \
+           method##Request, method##Response>::                              \
+          EnqueueRequest(&verbs_service_, cq_,                               \
+                         &grpc::VerbsService::AsyncService::Request##method, \
+                         &GrpcVerbsService::method##Handler,                 \
+                         (supports_cancel));                                 \
+    }                                                                        \
+  } while (0)
+
+// This method blocks forever handling requests from the completion queue.
+void GrpcVerbsService::HandleRPCsLoop() {
+  for (int i = 0; i < 10; ++i) {
+    ENQUEUE_REQUEST(GetRemoteAddress, false);
+  }
+
+  void* tag;
+  bool ok;
+
+  while (cq_->Next(&tag, &ok)) {
+    UntypedCall<GrpcVerbsService>::Tag* callback_tag =
+        static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
+    if (callback_tag) {
+      callback_tag->OnCompleted(this, ok);
+    } else {
+      cq_->Shutdown();
+    }
+  }
+}
+
+void GrpcVerbsService::GetRemoteAddressHandler(
+    WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
+  Status s = GetRemoteAddressSync(&call->request, &call->response);
+  call->SendResponse(ToGrpcStatus(s));
+  ENQUEUE_REQUEST(GetRemoteAddress, false);
+}
+
+// synchronous method
+Status GrpcVerbsService::GetRemoteAddressSync(
+    const GetRemoteAddressRequest* request,
+    GetRemoteAddressResponse* response) {
+  // analyzing request
+  // the channel setting part is redundant.
+  const string remote_host_name = request->host_name();
+  RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
+  CHECK(rc);
+  RdmaAddress ra;
+  ra.lid = request->channel().lid();
+  ra.qpn = request->channel().qpn();
+  ra.psn = request->channel().psn();
+  rc->SetRemoteAddress(ra, false);
+  rc->Connect();
+  int i = 0;
+  int idx[] = {1, 0, 3, 2};
+  std::vector<RdmaBuffer*> mb(rc->message_buffers());
+  CHECK_EQ(request->mr_size(), 4);
+  for (const auto& mr : request->mr()) {
+    // the connections are crossed, i.e.
+    // local tx_message_buffer <---> remote rx_message_buffer_
+    // local rx_message_buffer <---> remote tx_message_buffer_
+    // local tx_ack_buffer <---> remote rx_ack_buffer_
+    // local rx_ack_buffer <---> remote tx_ack_buffer_
+    // hence idx[] = {1, 0, 3, 2}.
+    RdmaBuffer* rb = mb[idx[i]];
+    RemoteMR rmr;
+    rmr.remote_addr = mr.remote_addr();
+    rmr.rkey = mr.rkey();
+    rb->SetRemoteMR(rmr, false);
+    i++;
+  }
+  CHECK(i == RdmaChannel::kNumMessageBuffers);
+
+  // setting up response
+  response->set_host_name(
+      worker_env_->session_mgr->LegacySession()->worker_name);
+  Channel* channel_info = response->mutable_channel();
+  channel_info->set_lid(rc->self().lid);
+  channel_info->set_qpn(rc->self().qpn);
+  channel_info->set_psn(rc->self().psn);
+  for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
+    MemoryRegion* mr = response->add_mr();
+    mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
+    mr->set_rkey(mb[i]->self()->rkey);
+  }
+  return Status::OK();
+}
+
+// Create a GrpcVerbsService, then assign it to a given handle.
+void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
+                        ::grpc::ServerBuilder* builder) {
+  *handle = new GrpcVerbsService(worker_env, builder);
+}
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_VERBS
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service.h b/tensorflow/contrib/verbs/grpc_verbs_service.h
new file mode 100644
index 00000000000..aa509602b51
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/lib/core/refcount.h"
+
+namespace grpc {
+class ServerBuilder;
+class ServerCompletionQueue;
+class Alarm;
+}  // namespace grpc
+
+namespace tensorflow {
+
+class GrpcVerbsService : public AsyncServiceInterface {
+ public:
+  GrpcVerbsService(const WorkerEnv* worker_env, ::grpc::ServerBuilder* builder);
+  ~GrpcVerbsService();
+  void HandleRPCsLoop() override;
+  void Shutdown() override;
+  void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
+
+ private:
+  template <class RequestMessage, class ResponseMessage>
+  using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
+                          RequestMessage, ResponseMessage>;
+  void GetRemoteAddressHandler(
+      WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
+  Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
+                              GetRemoteAddressResponse* response);
+
+  ::grpc::ServerCompletionQueue* cq_;
+  grpc::VerbsService::AsyncService verbs_service_;
+  mutex shutdown_mu_;
+  bool is_shutdown_ GUARDED_BY(shutdown_mu_);
+  ::grpc::Alarm* shutdown_alarm_;
+  // not owned
+  RdmaMgr* rdma_mgr_;
+  const WorkerEnv* const worker_env_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
+};
+
+// Create a GrpcVerbsService, then assign it to a given handle.
+void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
+                        ::grpc::ServerBuilder* builder);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_VERBS
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
new file mode 100644
index 00000000000..e0ba78dbfd5
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
@@ -0,0 +1,68 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
+
+#include "grpc++/impl/codegen/async_stream.h"
+#include "grpc++/impl/codegen/async_unary_call.h"
+#include "grpc++/impl/codegen/channel_interface.h"
+#include "grpc++/impl/codegen/client_unary_call.h"
+#include "grpc++/impl/codegen/method_handler_impl.h"
+#include "grpc++/impl/codegen/rpc_service_method.h"
+#include "grpc++/impl/codegen/service_type.h"
+#include "grpc++/impl/codegen/sync_stream.h"
+
+namespace tensorflow {
+
+namespace grpc {
+
+static const char* grpcVerbsService_method_names[] = {
+    "/tensorflow.VerbsService/GetRemoteAddress",
+};
+
+std::unique_ptr<VerbsService::Stub> VerbsService::NewStub(
+    const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+    const ::grpc::StubOptions& options) {
+  std::unique_ptr<VerbsService::Stub> stub(new VerbsService::Stub(channel));
+  return stub;
+}
+
+VerbsService::Stub::Stub(
+    const std::shared_ptr< ::grpc::ChannelInterface>& channel)
+    : channel_(channel),
+      rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
+                                  ::grpc::RpcMethod::NORMAL_RPC, channel) {}
+
+::grpc::Status VerbsService::Stub::GetRemoteAddress(
+    ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+    GetRemoteAddressResponse* response) {
+  return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_GetRemoteAddress_,
+                                   context, request, response);
+}
+
+VerbsService::AsyncService::AsyncService() {
+  for (int i = 0; i < 1; ++i) {
+    AddMethod(new ::grpc::RpcServiceMethod(grpcVerbsService_method_names[i],
+                                           ::grpc::RpcMethod::NORMAL_RPC,
+                                           nullptr));
+    ::grpc::Service::MarkMethodAsync(i);
+  }
+}
+
+VerbsService::AsyncService::~AsyncService() {}
+
+}  // namespace grpc
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
new file mode 100644
index 00000000000..f7ea774b661
--- /dev/null
+++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h
@@ -0,0 +1,89 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
+
+#include "grpc++/impl/codegen/async_stream.h"
+#include "grpc++/impl/codegen/async_unary_call.h"
+#include "grpc++/impl/codegen/proto_utils.h"
+#include "grpc++/impl/codegen/rpc_method.h"
+#include "grpc++/impl/codegen/service_type.h"
+#include "grpc++/impl/codegen/status.h"
+#include "grpc++/impl/codegen/stub_options.h"
+#include "grpc++/impl/codegen/sync_stream.h"
+
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+
+namespace grpc {
+class CompletionQueue;
+class Channel;
+class RpcService;
+class ServerCompletionQueue;
+class ServerContext;
+}  // namespace grpc
+
+namespace tensorflow {
+
+namespace grpc {
+
+// Implementation of `tensorflow.VerbsService`, based on the
+// definition in "//tensorflow/contrib/verbs/verbs_service.proto",
+// and the gRPC generated stub and service classes.
+// See the proto file for the definition of methods and messages.
+class VerbsService GRPC_FINAL {
+ public:
+  class StubInterface {
+   public:
+    virtual ~StubInterface() {}
+    virtual ::grpc::Status GetRemoteAddress(
+        ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+        GetRemoteAddressResponse* response) = 0;
+  };
+  class Stub GRPC_FINAL : public StubInterface {
+   public:
+    Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
+    ::grpc::Status GetRemoteAddress(
+        ::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
+        GetRemoteAddressResponse* response) GRPC_OVERRIDE;
+
+   private:
+    std::shared_ptr< ::grpc::ChannelInterface> channel_;
+    const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_;
+  };
+  static std::unique_ptr<Stub> NewStub(
+      const std::shared_ptr< ::grpc::ChannelInterface>& channel,
+      const ::grpc::StubOptions& options = ::grpc::StubOptions());
+
+  class AsyncService : public ::grpc::Service {
+   public:
+    AsyncService();
+    virtual ~AsyncService();
+    void RequestGetRemoteAddress(
+        ::grpc::ServerContext* context, GetRemoteAddressRequest* request,
+        ::grpc::ServerAsyncResponseWriter<GetRemoteAddressResponse>* response,
+        ::grpc::CompletionQueue* new_call_cq,
+        ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+      ::grpc::Service::RequestAsyncUnary(0, context, request, response,
+                                         new_call_cq, notification_cq, tag);
+    }
+  };
+};
+
+}  // namespace grpc
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc
new file mode 100644
index 00000000000..53d840f5d1c
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma.cc
@@ -0,0 +1,874 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include <cstdlib>
+#include "tensorflow/contrib/verbs/verbs_util.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+namespace {
+// hash name to 32-bit integer
+uint32_t NameHash(const string& name) {
+  return Hash32(name.data(), name.size(), 0x1234ABCD);
+}
+
+// convenience function for printing message
+string MessageTypeToString(RdmaMessageType rmt) {
+  switch (rmt) {
+    case RDMA_MESSAGE_ACK:
+      return "RDMA_MESSAGE_ACK";
+      break;
+    case RDMA_MESSAGE_BUFFER_IDLE:
+      return "RDMA_MESSAGE_BUFFER_IDLE";
+      break;
+    case RDMA_MESSAGE_BUFFER_REQUEST:
+      return "RDMA_MESSAGE_BUFFER_REQUEST";
+      break;
+    case RDMA_MESSAGE_BUFFER_RESPONSE:
+      return "RDMA_MESSAGE_BUFFER_RESPONSE";
+      break;
+    case RDMA_MESSAGE_TENSOR_REQUEST:
+      return "RDMA_MESSAGE_TENSOR_REQUEST";
+      break;
+    case RDMA_MESSAGE_TENSOR_WRITE:
+      return "RDMA_MESSAGE_TENSOR_WRITE";
+      break;
+    default:
+      return "UNKNOWN MESSAGE";
+  }
+}
+}  // namespace
+
+ibv_context* open_default_device() {
+  ibv_device** dev_list;
+  ibv_device* ib_dev;
+  dev_list = ibv_get_device_list(NULL);
+  CHECK(dev_list) << "No InfiniBand device found";
+  ib_dev = dev_list[0];
+  CHECK(ib_dev) << "No InfiniBand device found";
+  ibv_context* context = ibv_open_device(ib_dev);
+  CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
+  return context;
+}
+
+ibv_pd* alloc_protection_domain(ibv_context* context) {
+  ibv_pd* pd = ibv_alloc_pd(context);
+  CHECK(pd) << "Failed to allocate protection domain";
+  return pd;
+}
+
+RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
+    : context_(open_default_device()),
+      pd_(alloc_protection_domain(context_)),
+      worker_env_(worker_env) {
+  event_channel_ = ibv_create_comp_channel(context_);
+  CHECK(event_channel_) << "Failed to create completion channel";
+  cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
+                      0);
+  CHECK(cq_) << "Failed to create completion queue";
+  CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
+  polling_thread_.reset(Env::Default()->StartThread(
+      ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
+  VLOG(2) << "Start RdmaAdapter: " << name();
+}
+
+RdmaAdapter::~RdmaAdapter() {
+  polling_thread_.reset();
+  CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
+  CHECK(!ibv_destroy_comp_channel(event_channel_))
+      << "Failed to destroy channel";
+  CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
+  CHECK(!ibv_close_device(context_)) << "Failed to release context";
+}
+
+string RdmaAdapter::name() const { return string(context_->device->name); }
+
+// Function to process incoming messages
+// There are two types of messages:
+// 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
+// 2. IBV_WC_RDMA_WRITE (send))
+void RdmaAdapter::Process_CQ() {
+  while (true) {
+    ibv_cq* cq;
+    void* cq_context;
+    CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
+    CHECK(cq == cq_);
+    ibv_ack_cq_events(cq, 1);
+    CHECK(!ibv_req_notify_cq(cq_, 0));
+
+    int ne =
+        ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
+    CHECK_GE(ne, 0);
+    for (int i = 0; i < ne; ++i) {
+      CHECK(wc_[i].status == IBV_WC_SUCCESS)
+          << "Failed status \n"
+          << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
+          << static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
+      if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
+        RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
+        // put back a recv wr.
+        rc->Recv();
+        // imm_data is the index of RX buffer in the buffer table.
+        uint32_t imm_data = wc_[i].imm_data;
+        RdmaBuffer* rb = rc->FindBuffer(imm_data);
+        RdmaMessage rm;
+        RdmaMessage::ParseMessage(rm, rb->buffer_);
+        VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
+
+        if (rm.type_ == RDMA_MESSAGE_ACK) {
+          // receive an ack to a message
+          rb = rc->tx_message_buffer_;
+          rb->SetBufferStatus(remote, idle);
+          rb->SendNextItem();
+        } else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
+          // received a request-for-tensor message
+          // send ack to release remote tx message buffer
+          RdmaBuffer* ab = rc->tx_ack_buffer_;
+          ab->SendNextItem();
+          // find or create buffer
+          RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
+          string key_with_step_id =
+              VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
+          tb->EnqueueItem(key_with_step_id);
+          // send the next tensor
+          worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+        } else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
+          // receive tensor-buffer-ready message
+          // send ack to release remote tx message buffer
+          RdmaBuffer* ab = rc->tx_ack_buffer_;
+          ab->SendNextItem();
+          // find buffer
+          RdmaBuffer* tb = rc->FindBuffer(rm.name_);
+          tb->SetBufferStatus(remote, idle);
+          worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+        } else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
+          // remote host requests to create a tensor buffer;
+          // send ack to release remote tx message buffer
+          RdmaBuffer* ab = rc->tx_ack_buffer_;
+          ab->SendNextItem();
+          // find or create the buffer
+          RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR);
+          RemoteMR rmr;
+          rmr.remote_addr = rm.remote_addr_;
+          rmr.rkey = rm.rkey_;
+          tb->SetRemoteMR(rmr, true);
+          tb->CreateCPUBuffer(rm.buffer_size_);
+          // create RDMA_MESSAGE_BUFFER_RESPONSE message
+          RdmaMessage br;
+          br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE;
+          br.name_size_ = rm.name_.size();
+          br.name_ = rm.name_;
+          br.buffer_size_ = rm.buffer_size_;
+          br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_);
+          br.rkey_ = tb->self_->rkey;
+          string message = RdmaMessage::CreateMessage(br);
+          RdmaBuffer* mb = rc->tx_message_buffer_;
+          mb->EnqueueItem(message);
+          mb->SendNextItem();
+        } else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
+          // remote creates a buffer and responds
+          // send ack to release remote tx message buffer
+          RdmaBuffer* ab = rc->tx_ack_buffer_;
+          ab->SendNextItem();
+          // find buffer
+          RdmaBuffer* tb = rc->FindBuffer(rm.name_);
+          CHECK(rm.buffer_size_ == tb->size_)
+              << "rm.buffer_size = " << rm.buffer_size_
+              << "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
+          RemoteMR rmr;
+          rmr.remote_addr = rm.remote_addr_;
+          rmr.rkey = rm.rkey_;
+          tb->SetRemoteMR(rmr, true);
+          tb->SetBufferStatus(local, idle);
+          tb->SetBufferStatus(remote, idle);
+          worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
+        } else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+          // tensor RDMA write completed
+          worker_env_->compute_pool->Schedule([rm, rc]() {
+            string key_with_step_id =
+                VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
+            rc->RunRecvCallback(key_with_step_id);
+          });
+        }
+      } else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
+        RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
+        rb->SetBufferStatus(local, idle);
+        RdmaMessage rm;
+        RdmaMessage::ParseMessage(rm, rb->buffer_);
+        VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
+        if (rm.type_ != RDMA_MESSAGE_ACK) {
+          worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
+        }
+      }
+    }
+  }
+}
+
+RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
+                         const string remote_name)
+    : adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
+  // Create queue pair
+  {
+    struct ibv_qp_init_attr attr;
+    memset(&attr, 0, sizeof(ibv_qp_init_attr));
+    attr.send_cq = adapter_->cq_;
+    attr.recv_cq = adapter_->cq_;
+    attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+    attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
+    attr.cap.max_send_sge = 1;
+    attr.cap.max_recv_sge = 1;
+    attr.qp_type = IBV_QPT_RC;
+
+    qp_ = ibv_create_qp(adapter_->pd_, &attr);
+    CHECK(qp_) << "Failed to create queue pair";
+  }
+
+  // Init queue pair
+  {
+    struct ibv_qp_attr attr;
+    memset(&attr, 0, sizeof(ibv_qp_attr));
+    attr.qp_state = IBV_QPS_INIT;
+    attr.pkey_index = 0;
+    attr.port_num = 1;
+    attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
+
+    int mask =
+        IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
+    CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
+  }
+
+  // Local address
+  {
+    struct ibv_port_attr attr;
+    CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
+        << "Query port";
+    self_.lid = attr.lid;
+    self_.qpn = qp_->qp_num;
+    self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
+  }
+
+  // create message and ack buffers, then initialize the tables.
+  {
+    const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
+                                   "tx_ack_buffer", "rx_ack_buffer"};
+    tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
+    rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
+    tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
+    rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]);
+    message_buffers_.reserve(kNumMessageBuffers);
+    message_buffers_.push_back(tx_message_buffer_);
+    message_buffers_.push_back(rx_message_buffer_);
+    message_buffers_.push_back(tx_ack_buffer_);
+    message_buffers_.push_back(rx_ack_buffer_);
+    // create buffer on host
+    tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
+    rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
+    tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
+    rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
+    // bt_mu_.lock() is not used in constructor.
+    for (int i = 0; i < kNumMessageBuffers; i++) {
+      uint32_t index = NameHash(buffer_names[i]);
+      buffer_table_.insert({index, message_buffers_[i]});
+      buffer_index_name_table_.insert({index, buffer_names[i]});
+      buffer_name_index_table_.insert({buffer_names[i], index});
+    }
+
+    // Initiate recv
+    for (int i = 0; i < 100; i++) {
+      Recv();
+    }
+  }
+}
+
+RdmaChannel::~RdmaChannel() {
+  CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
+  delete tx_message_buffer_;
+  delete rx_message_buffer_;
+  delete tx_ack_buffer_;
+  delete rx_ack_buffer_;
+}
+
+void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
+  mutex_lock lock{mu_};
+  if ((override) || (!remote_set_)) {
+    remote_.lid = ra.lid;
+    remote_.qpn = ra.qpn;
+    remote_.psn = ra.psn;
+    remote_set_ = true;
+  } else {
+    CHECK(remote_.lid == ra.lid);
+    CHECK(remote_.qpn == ra.qpn);
+    CHECK(remote_.psn == ra.psn);
+  }
+}
+
+// Adding tokens to the completion queue
+// Tokens are needed to process future messages.
+void RdmaChannel::Recv() {
+  struct ibv_recv_wr wr;
+  memset(&wr, 0, sizeof(wr));
+  wr.wr_id = (uint64_t)this;
+  struct ibv_recv_wr* bad_wr;
+  CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
+}
+
+// Lookup 32-bit buffer index from buffer name
+// Args:
+//   buffer_name: name of the buffer
+// Returns:
+//   32-bit index
+uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
+  mutex_lock lock{bt_mu_};
+  BufferNameIndexTable::iterator iter =
+      buffer_name_index_table_.find(buffer_name);
+  CHECK(iter != buffer_name_index_table_.end());
+  return iter->second;
+}
+
+// Find a buffer by its 32-bit index
+// Args:
+//   index: 32-bit hash code of the tensor buffer name
+// Returns:
+//   name of the tensor buffer
+RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) {
+  mutex_lock lock{bt_mu_};
+  BufferTable::iterator iter = buffer_table_.find(index);
+  CHECK(iter != buffer_table_.end());
+  return iter->second;
+}
+
+// Find a buffer by its name
+// Args:
+//   name: name of the buffer
+// Returns:
+//   the named rdma buffer
+RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
+  uint32_t index = LookupBufferIndex(name);
+  return FindBuffer(index);
+}
+
+// Find a buffer if it exists, otherwise create one.
+// The memory inside the created buffer is not allocated.
+// Args:
+//   name: the name of the buffer
+//   buffer_type: TENSOR, MESSAGE or ACK.
+// Returns:
+//   the named buffer
+RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
+                                            BufferType buffer_type) {
+  mutex_lock lock{bt_mu_};
+  RdmaBuffer* rb;
+  // find index
+  BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name);
+  if (iter != buffer_name_index_table_.end()) {
+    uint32_t index = iter->second;
+    // find buffer
+    BufferTable::iterator iter = buffer_table_.find(index);
+    CHECK(iter != buffer_table_.end());
+    rb = iter->second;
+  } else {
+    uint32_t index = NameHash(name);
+    if (buffer_type == TENSOR) {
+      rb = new RdmaTensorBuffer(this, name);
+    } else if (buffer_type == MESSAGE) {
+      rb = new RdmaMessageBuffer(this, name);
+    } else if (buffer_type == ACK) {
+      rb = new RdmaAckBuffer(this, name);
+    }
+    buffer_name_index_table_.insert({name, index});
+    buffer_index_name_table_.insert({index, name});
+    buffer_table_.insert({index, rb});
+  }
+  CHECK(rb);
+  return rb;
+}
+
+// Insert callback to the callback_table.
+// The callback is activated when the corresponding tensor is received.
+// Arg:
+//   key: the name of the tensor
+//   recv_done: the callback associated with the tensor.
+// Returns:
+//   None
+void RdmaChannel::InsertRecvCallback(const string& key,
+                                     std::function<void()> recv_done) {
+  mutex_lock lock{ct_mu_};
+  callback_table_.insert({key, recv_done});
+}
+
+// Remove callback from the callback_table.
+// Arg:
+//   key: the name of the tensor
+// Returns:
+//   None
+void RdmaChannel::RemoveRecvCallback(const string& key) {
+  mutex_lock lock{ct_mu_};
+  callback_table_.erase(key);
+}
+
+// Run named callback in the callback_table.
+// Arg:
+//   key: the name of the tensor
+// Returns:
+//   None
+void RdmaChannel::RunRecvCallback(const string& key) {
+  std::function<void()> recv_done;
+  {
+    mutex_lock lock{ct_mu_};
+    CallbackTable::iterator iter = callback_table_.find(key);
+    CHECK(iter != callback_table_.end());
+    recv_done = iter->second;
+  }
+  recv_done();
+}
+
+void RdmaChannel::Connect() {
+  {
+    mutex_lock lock{mu_};
+    CHECK(remote_set_) << "remote channel is not set";
+  }
+  Connect(remote_);
+}
+
+// Setup channel to a remote node
+// Args:
+//   remoteAddr: the rdma address of a remote channel.
+// Returns:
+//   None
+void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
+  mutex_lock lock{mu_};
+  if (!connected_) {
+    struct ibv_qp_attr attr;
+    memset(&attr, 0, sizeof(ibv_qp_attr));
+    attr.qp_state = IBV_QPS_RTR;
+    attr.path_mtu = IBV_MTU_4096;
+    attr.dest_qp_num = remoteAddr.qpn;
+    attr.rq_psn = remoteAddr.psn;
+    attr.max_dest_rd_atomic = 1;
+    attr.min_rnr_timer = 12;
+    attr.ah_attr.is_global = 0;
+    attr.ah_attr.dlid = remoteAddr.lid;
+    attr.ah_attr.sl = 0;
+    attr.ah_attr.src_path_bits = 0;
+    attr.ah_attr.port_num = 1;
+
+    int r;
+    CHECK(!(r = ibv_modify_qp(qp_, &attr,
+                              IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
+                                  IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
+                                  IBV_QP_MAX_DEST_RD_ATOMIC |
+                                  IBV_QP_MIN_RNR_TIMER)))
+        << "QP to Ready to Receive " << r;
+
+    memset(&attr, 0, sizeof(ibv_qp_attr));
+    attr.qp_state = IBV_QPS_RTS;
+    attr.sq_psn = self_.psn;
+    attr.timeout = 14;
+    attr.retry_cnt = 7;
+    attr.rnr_retry = 7; /* infinite */
+    attr.max_rd_atomic = 1;
+
+    CHECK(!(r = ibv_modify_qp(qp_, &attr,
+                              IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
+                                  IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
+                                  IBV_QP_MAX_QP_RD_ATOMIC)))
+        << "QP to Ready to Send " << r;
+
+    connected_ = true;
+  } else {
+    LOG(INFO) << "channel already connected";
+  }
+}
+
+RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
+    : channel_(channel), name_(name) {}
+
+RdmaBuffer::~RdmaBuffer() {
+  CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
+  FreeBuffer();
+}
+
+void RdmaBuffer::FreeBuffer() {
+  if ((buffer_ != nullptr) && buffer_on_host_) {
+    free(buffer_);
+  }
+  // TODO
+  // release buffer if it is on device.
+  // We don't support RDMABuffer on device at this moment.
+}
+
+// Allocate CPU memory for the Rdma buffer
+// Args:
+//   size: to-be-allocated memory size
+//   lock: whether or not mutex_lock the process to protect concurrency.
+// Returns:
+//   None
+void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
+  CHECK(size > 0);
+  if (lock) {
+    mu_.lock();
+  }
+  if (local_status_ != none) {
+    // delete existing buffer
+    CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
+    FreeBuffer();
+  }
+  size_ = size;
+  buffer_ = malloc(size_);
+  self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
+                     IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
+  CHECK(self_) << "Failed to register memory region";
+  buffer_on_host_ = true;
+  local_status_ = idle;
+  if (lock) {
+    mu_.unlock();
+  }
+}
+
+// Set address of remote memory region
+// Args:
+//   rmr: address of remote memory region
+//   override: whether override existing information
+// Returns:
+//   None
+void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
+  mutex_lock lock{mu_};
+  if ((override) || (remote_status_ == none)) {
+    remote_.remote_addr = rmr.remote_addr;
+    remote_.rkey = rmr.rkey;
+    remote_status_ = idle;
+  } else {
+    CHECK(remote_.remote_addr == rmr.remote_addr);
+    CHECK(remote_.rkey == rmr.rkey);
+  }
+}
+
+// Put a task in the buffer's job queue
+void RdmaBuffer::EnqueueItem(string item) {
+  mutex_lock lock{mu_};
+  queue_.push(item);
+}
+
+// Rdma-Write the content of the buffer
+void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
+  struct ibv_sge list;
+  list.addr = (uint64_t)buffer_;
+  list.length = buffer_size;
+  list.lkey = self_->lkey;
+
+  struct ibv_send_wr wr;
+  memset(&wr, 0, sizeof(wr));
+  wr.wr_id = (uint64_t)this;
+  wr.sg_list = &list;
+  wr.num_sge = 1;
+  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+  wr.send_flags = IBV_SEND_SIGNALED;
+  wr.imm_data = imm_data;
+  wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
+  wr.wr.rdma.rkey = remote_.rkey;
+
+  struct ibv_send_wr* bad_wr;
+  CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
+}
+
+RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
+    : RdmaBuffer(channel, name) {}
+
+RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
+    : RdmaBuffer(channel, name) {}
+
+RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
+    : RdmaBuffer(channel, name) {}
+
+// Send the next ack from the buffer's job queue.
+void RdmaAckBuffer::SendNextItem() {
+  uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
+  RdmaMessage rm;
+  rm.name_ = "rx_ack_buffer";
+  rm.type_ = RDMA_MESSAGE_ACK;
+  rm.name_size_ = rm.name_.size();
+  string message = RdmaMessage::CreateMessage(rm);
+  memcpy(buffer_, message.data(), message.size());
+  Write(imm_data, message.size());
+}
+
+// Send the next message from the buffer's job queue.
+void RdmaMessageBuffer::SendNextItem() {
+  uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
+  mu_.lock();
+  if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
+    local_status_ = busy;
+    remote_status_ = busy;
+    string message = queue_.front();
+    queue_.pop();
+    // local/remote_status_ won't be set back to idle
+    // unitl Write() is successful
+    mu_.unlock();
+    memcpy(buffer_, message.data(), message.size());
+    Write(imm_data, message.size());
+  } else {
+    mu_.unlock();
+  }
+}
+
+// Send the next tensor from the buffer's job queue.
+void RdmaTensorBuffer::SendNextItem() {
+  // get the key
+  string key_with_step_id = "";
+  {
+    mutex_lock lock{mu_};
+    if (!queue_.empty()) {
+      key_with_step_id = queue_.front();
+      queue_.pop();
+    }
+  }
+  // send the tensor if a key is acquired.
+  if (key_with_step_id != "") {
+    VLOG(2) << "try to send tensor: " << key_with_step_id;
+    string key;
+    int64 step_id;
+    VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
+    CHECK(key.compare(name_) == 0);
+    Rendezvous::ParsedKey parsed;
+    Rendezvous::ParseKey(key, &parsed);
+    Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
+                                   parsed](const Status& status,
+                                           const Rendezvous::Args& send_args,
+                                           const Rendezvous::Args& recv_args,
+                                           const Tensor& in, bool is_dead) {
+      CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
+                         << " error message: " << status.error_message();
+      size_t buffer_size = RdmaMessage::kMessageTotalBytes;
+      size_t tensor_bytes = 0;
+      TensorProto proto;
+      // Figures out which device the tensor is hosted on.
+      Device* src_dev = nullptr;
+      Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
+          parsed.src_device, &src_dev);
+      CHECK(s.ok()) << "src device not found";
+      // Does the device have the right incarnation number we expect?
+      CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
+          << "RecvTensor expects a different device incarnation: "
+          << parsed.src_incarnation << " vs. "
+          << src_dev->attributes().incarnation()
+          << ". Your worker job was probably restarted. Check your "
+          << "worker job for the reason why it was restarted.";
+      Device* dst_dev = nullptr;
+      // destination is on CPU.
+      s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
+                                                                    &dst_dev);
+      CHECK(s.ok()) << "dst device not found";
+      AllocatorAttributes dst_alloc_attr;
+      dst_alloc_attr.set_on_host(true);
+      // string tensor needs to be serialized
+      if (src_dev->tensorflow_gpu_device_info() &&
+          (!send_args.alloc_attrs.on_host())) {
+        CHECK(send_args.device_context)
+            << "send dev name: " << src_dev->name()
+            << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+        // "val" is on a GPU. Uses GPUUtil to fill the proto.
+        s = VerbsUtil::SetProtoFromGPUSync(
+            in, src_dev, send_args.device_context, &proto, is_dead);
+        CHECK(s.ok()) << "set proto from gpu sync";
+      } else {
+        // tensor is in CPU memory.
+        in.AsProtoTensorContent(&proto);
+      }
+      tensor_bytes = proto.ByteSize();
+      // maybe some margin for string tensor?
+      buffer_size += tensor_bytes;
+      // prepare message
+      RdmaMessage rm;
+      rm.name_size_ = key.size();
+      rm.name_ = key;
+      rm.tensor_shape_ = in.shape();
+      rm.data_type_ = in.dtype();
+      rm.step_id_ = step_id;
+      rm.is_dead_ = is_dead;
+      rm.tensor_bytes_ = tensor_bytes;
+      rm.buffer_size_ = buffer_size;
+      mu_.lock();
+      if (local_status_ == none ||
+          (buffer_size > size_ && local_status_ == idle &&
+           remote_status_ == idle)) {
+        if ((local_status_ != none) && (buffer_size > size_)) {
+          CHECK(rm.data_type_ == DT_STRING)
+              << "Only string tensor allows to change size";
+        }
+        CreateCPUBuffer(buffer_size, false);
+        mu_.unlock();
+        // put back the key since it is not sent;
+        EnqueueItem(key_with_step_id);
+        // ask the remote to create the same buffer
+        rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
+        rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
+        rm.rkey_ = self_->rkey;
+        string message = RdmaMessage::CreateMessage(rm);
+        channel_->tx_message_buffer_->EnqueueItem(message);
+        channel_->tx_message_buffer_->SendNextItem();
+      } else if ((local_status_ == idle) && (remote_status_ == idle)) {
+        // both buffers are ready, send the tensor
+        local_status_ = busy;
+        remote_status_ = busy;
+        // local/remote_status_ won't be set back to idle
+        // unitl Write() is successful
+        mu_.unlock();
+        CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
+              (buffer_size <= size_ && rm.data_type_ == DT_STRING))
+            << "tensor and buffer size do not agree!"
+            << " buffer_size = " << size_
+            << " requested tensor size = " << buffer_size << in.DebugString();
+        uint32_t imm_data = LookupBufferIndex(key);
+        rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
+        string message = RdmaMessage::CreateMessage(rm);
+        memcpy(buffer_, message.data(), message.size());
+        if (!is_dead) {
+          // copy the tensor buffer content
+          void* output =
+              static_cast<void*>(static_cast<char*>(buffer_) +
+                                 RdmaMessage::kTensorBufferStartIndex);
+          CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
+          proto.SerializeToArray(output, tensor_bytes);
+        } else {
+          buffer_size = RdmaMessage::kMessageTotalBytes;
+        }
+        Write(imm_data, buffer_size);
+      } else {
+        mu_.unlock();
+        // put back the key since it is not sent;
+        EnqueueItem(key_with_step_id);
+      }
+    };
+    // Use default session (legacy_session_)
+    // TODO use WorkerSessionForSession
+    // need to pass in session handle
+    channel_->adapter_->worker_env_->session_mgr->LegacySession()
+        ->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
+  }
+}
+
+// Create a RdmaMessage according to the pre-defined format
+// Args:
+//   rm: the message structure
+// Returns:
+//   message in string format
+string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
+  // Rdma Message format
+  // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
+  //   1B|    2B   | 512|  8B   |    8B     |       8B  | 4B |    1B |...
+  // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
+  // ...|   XB    |    XB      |    8B      |...
+  //
+  // ACK:             type|13|"rx_ack_buffer"
+  // TENSOR_REQUEST:  type|name_size|tensor_name|step_id
+  // TENSOR_WRITE:    type|name_size|tensor_name|step_id|...|is_dead
+  //                 |data_type|tensor_shape|tensor_bytes
+  // BUFFER_IDLE:     type|name_size|buffer_name
+  // BUFFER_REQUEST:
+  // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
+  // BUFFER_RESPONSE:
+  // type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
+  char message[kMessageTotalBytes];
+  // type
+  message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
+  // size of name
+  memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_));
+  // name
+  memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
+  // buffer_size, remote_addr, rkey
+  if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
+      (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
+    memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
+           sizeof(rm.buffer_size_));
+    memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
+           sizeof(rm.remote_addr_));
+    memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
+  }
+  // step_id
+  if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
+      (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
+    memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
+  }
+  // is_dead, data_type, tensor_shape, tensor_bytes
+  if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+    memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
+
+    memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
+           sizeof(rm.data_type_));
+    memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
+           sizeof(rm.tensor_shape_));
+    memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
+           sizeof(rm.tensor_bytes_));
+  }
+  return string(message, kMessageTotalBytes);
+}
+
+// Parse a RdmaMessage according to the pre-defined format
+// Args:
+//   rm: the message structure where the parsed message will be saved
+//   buffer: the place where the raw message is stored
+// Returns:
+//   None
+void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
+  char* message = static_cast<char*>(buffer);
+  // type
+  rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
+  // name_size_
+  memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_));
+  // name
+  rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
+  // buffer_size, remote_addr, rkey
+  if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
+      (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
+    memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
+           sizeof(rm.buffer_size_));
+    memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
+           sizeof(rm.remote_addr_));
+    memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
+  }
+  // step_id
+  if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
+      (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
+    memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
+  }
+  // data_type, tensor_bytes, tensor_shape, is_dead
+  if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
+    memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
+    memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
+           sizeof(rm.data_type_));
+    memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
+           sizeof(rm.tensor_shape_));
+    memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
+           sizeof(rm.tensor_bytes_));
+  }
+}
+
+}  // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h
new file mode 100644
index 00000000000..ae2aa63e3f6
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma.h
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include <infiniband/verbs.h>
+#include <cstring>  // for memset
+#include <functional>
+#include <memory>  // for shared_ptr
+#include <queue>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// structure to save the address of remote channels.
+struct RdmaAddress {
+  uint32_t lid;
+  uint32_t qpn;
+  uint32_t psn;
+};
+// structure to save information for remote memory regions.
+struct RemoteMR {
+  uint64_t remote_addr;
+  uint32_t rkey;
+};
+enum BufferStatus { none, idle, busy };
+enum Location { local, remote };
+enum BufferType { ACK, MESSAGE, TENSOR };
+enum RdmaMessageType {
+  RDMA_MESSAGE_ACK,
+  RDMA_MESSAGE_BUFFER_IDLE,
+  RDMA_MESSAGE_BUFFER_REQUEST,
+  RDMA_MESSAGE_BUFFER_RESPONSE,
+  RDMA_MESSAGE_TENSOR_REQUEST,
+  RDMA_MESSAGE_TENSOR_WRITE
+};
+class RdmaBuffer;
+// Class that represents the Rdma Adapter.
+// Responsible for creation of the completion queue, and handling
+// of work completions.
+class RdmaAdapter {
+  friend class RdmaChannel;
+  friend class RdmaBuffer;
+  friend class RdmaAckBuffer;
+  friend class RdmaMessageBuffer;
+  friend class RdmaTensorBuffer;
+  friend class RdmaMgr;
+  friend class RdmaRemoteRendezvous;
+
+ public:
+  RdmaAdapter(const WorkerEnv* worker_env);
+  ~RdmaAdapter();
+  // Adapter name, e.g. mlx5_0.
+  string name() const;
+  void Process_CQ();
+
+ protected:
+  static const int MAX_CONCURRENT_WRITES = 1000;
+  ibv_context* context_;
+  // ibverbs protection domain
+  ibv_pd* pd_;
+  // Completion event channel, to wait for work completions
+  ibv_comp_channel* event_channel_;
+  // Completion queue, to poll on work completions
+  ibv_cq* cq_;
+  // Pre-allocated work completions array used for polling
+  ibv_wc wc_[MAX_CONCURRENT_WRITES * 2];
+  // worker env for thread
+  const WorkerEnv* worker_env_;
+  // thread for cq.
+  std::unique_ptr<Thread> polling_thread_;
+};
+
+// Class that represents a connection to a remote Rdma peer.
+// Responsible for connecting queue pairs.
+class RdmaChannel {
+  friend class RdmaAdapter;
+  friend class RdmaBuffer;
+  friend class RdmaAckBuffer;
+  friend class RdmaMessageBuffer;
+  friend class RdmaTensorBuffer;
+  friend class RdmaMgr;
+  friend class RdmaRemoteRendezvous;
+
+ public:
+  explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
+                       const string remote_name_);
+  ~RdmaChannel();
+  inline const RdmaAddress& self() { return self_; }
+  RdmaAddress address() const;
+  inline const std::vector<RdmaBuffer*>& message_buffers() const {
+    return message_buffers_;
+  }
+  void Connect(const RdmaAddress& remoteAddr);
+  void Connect();
+  void Recv();
+  RdmaBuffer* FindBuffer(const uint32_t index);
+  RdmaBuffer* FindBuffer(const string& name);
+  RdmaBuffer* FindOrCreateBuffer(const string& name,
+                                 BufferType buffer_type = TENSOR);
+  uint32_t LookupBufferIndex(const string& buffer_name);
+  void SetRemoteAddress(const RdmaAddress& ra, bool override);
+  void InsertRecvCallback(const string& key, std::function<void()> recv_done);
+  void RemoveRecvCallback(const string& key);
+  void RunRecvCallback(const string& key);
+  static const int kNumMessageBuffers = 4;
+
+ protected:
+  const RdmaAdapter* adapter_;
+  RdmaAddress self_;
+  string local_name_;
+  string remote_name_;
+  ibv_qp* qp_;
+  mutex mu_;
+  bool connected_ GUARDED_BY(bt_mu_) = false;
+  RdmaAddress remote_ GUARDED_BY(bt_mu_);
+  bool remote_set_ GUARDED_BY(bt_mu_) = false;
+  mutex ct_mu_;
+  typedef std::unordered_map<string, std::function<void()> > CallbackTable;
+  CallbackTable callback_table_ GUARDED_BY(ct_mu_);
+  mutex bt_mu_;
+  typedef std::unordered_map<unsigned int, RdmaBuffer*> BufferTable;
+  BufferTable buffer_table_ GUARDED_BY(bt_mu_);
+  typedef std::unordered_map<uint32_t, string> BufferIndexNameTable;
+  BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_);
+  typedef std::unordered_map<string, uint32_t> BufferNameIndexTable;
+  BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_);
+  RdmaBuffer* tx_message_buffer_;
+  RdmaBuffer* rx_message_buffer_;
+  RdmaBuffer* tx_ack_buffer_;
+  RdmaBuffer* rx_ack_buffer_;
+  std::vector<RdmaBuffer*> message_buffers_;
+};
+
+// Class that represents a buffer for Rdma writes and reads.
+class RdmaBuffer {
+  friend class RdmaChannel;
+  friend class RdmaAdapter;
+  friend class RdmaMgr;
+  friend class RdmaRemoteRendezvous;
+
+ public:
+  explicit RdmaBuffer(RdmaChannel* channel, string name);
+  virtual ~RdmaBuffer();
+
+  inline void* buffer() const { return buffer_; }
+  inline ibv_mr* self() const { return self_; }
+  inline void SetBufferStatus(Location loc, BufferStatus status) {
+    mu_.lock();
+    if (loc == local) {
+      local_status_ = status;
+    } else {
+      remote_status_ = status;
+    }
+    mu_.unlock();
+  }
+  void FreeBuffer();
+  void EnqueueItem(string Item);
+  virtual void SendNextItem(){};
+  void CreateCPUBuffer(size_t size, bool lock = true);
+  void SetRemoteMR(RemoteMR rmi, bool override);
+  uint32_t LookupBufferIndex(const string& buffer_name) {
+    return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
+  }
+  void Write(uint32_t imm_data, size_t buffer_size);
+
+ protected:
+  const RdmaChannel* channel_;
+  void* buffer_ = nullptr;
+  bool buffer_on_host_ = true;
+  size_t size_ = 0;
+  const string name_;
+  ibv_mr* self_ = nullptr;
+  mutex mu_;
+  RemoteMR remote_;
+  std::queue<string> queue_ GUARDED_BY(mu_);
+  BufferStatus local_status_ GUARDED_BY(mu_) = none;
+  BufferStatus remote_status_ GUARDED_BY(mu_) = none;
+};
+
+class RdmaAckBuffer : public RdmaBuffer {
+ public:
+  explicit RdmaAckBuffer(RdmaChannel* channel, string name);
+  virtual ~RdmaAckBuffer() override {}
+  void SendNextItem() override;
+};
+
+class RdmaMessageBuffer : public RdmaBuffer {
+  friend class RdmaChannel;
+  friend class RdmaAapater;
+
+ public:
+  explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
+  virtual ~RdmaMessageBuffer() override {}
+  void SendNextItem() override;
+};
+
+class RdmaTensorBuffer : public RdmaBuffer {
+ public:
+  explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
+  virtual ~RdmaTensorBuffer() override {}
+  void SendNextItem() override;
+};
+
+struct RdmaMessage {
+  RdmaMessageType type_;
+  uint16_t name_size_;
+  string name_;
+  int64 step_id_;
+  uint64_t buffer_size_;
+  uint64_t remote_addr_;
+  uint32_t rkey_;
+  bool is_dead_;
+  DataType data_type_;
+  TensorShape tensor_shape_;
+  size_t tensor_bytes_;
+
+  // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
+  //   1B|    2B   | 512|  8B   |    8B     |       8B  | 4B |    1B |...
+  // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
+  // ...|   XB    |    XB      |    8B      |...
+  //
+  static const size_t kNameCapacity = 512;
+  static const size_t kTypeStartIndex = 0;
+  static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
+  static const size_t kNameStartIndex =
+      kNameSizeStartIndex + sizeof(name_size_);
+  static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
+  static const size_t kBufferSizeStartIndex =
+      kStepIdStartIndex + sizeof(step_id_);
+  static const size_t kRemoteAddrStartIndex =
+      kBufferSizeStartIndex + sizeof(buffer_size_);
+  static const size_t kRkeyStartIndex =
+      kRemoteAddrStartIndex + sizeof(remote_addr_);
+  static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
+  static const size_t kDataTypeStartIndex =
+      kIsDeadStartIndex + sizeof(is_dead_);
+  static const size_t kTensorShapeStartIndex =
+      kDataTypeStartIndex + sizeof(data_type_);
+  static const size_t kTensorBytesStartIndex =
+      kTensorShapeStartIndex + sizeof(TensorShape);
+  static const size_t kTensorBufferStartIndex =
+      kTensorBytesStartIndex + sizeof(tensor_bytes_);
+  static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
+  static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
+  static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
+  static string CreateMessage(const RdmaMessage& rm);
+  static void ParseMessage(RdmaMessage& rm, void* buffer);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_VERBS
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
diff --git a/tensorflow/contrib/verbs/rdma_mgr.cc b/tensorflow/contrib/verbs/rdma_mgr.cc
new file mode 100644
index 00000000000..e28b80c6f6b
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_mgr.cc
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include <vector>
+#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
+#include "tensorflow/contrib/verbs/verbs_service.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/session_mgr.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
+                 GrpcChannelCache* const channel_cache)
+    : worker_env_(worker_env), channel_cache_(channel_cache) {
+  rdma_adapter_ = new RdmaAdapter(worker_env_);
+  // hardcoded to default session (legacy_session_)
+  // TODO: use WorkerSessionForSession
+  // need to pass in session handle
+  local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
+  std::vector<string> workers;
+  worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
+      &workers);
+  num_remote_workers_ = workers.size() - 1;
+  VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
+  for (size_t i = 0; i < workers.size(); i++) {
+    if (local_worker_.compare(workers[i]) != 0) {
+      channel_table_.insert(
+          {workers[i],
+           new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
+    }
+  }
+}
+
+// Setup Rdma channels between peers.
+// This is done at the beginning of the server setup.
+
+void RdmaMgr::SetupChannels() {
+  for (const auto& p : channel_table_) {
+    string worker_name = p.first;
+    LOG(INFO) << "connecting to remote node " << worker_name;
+    RdmaChannel* rc = p.second;
+    GetRemoteAddressRequest req;
+    GetRemoteAddressResponse resp;
+    // get the channel cache
+    SharedGrpcChannelPtr client_channel =
+        channel_cache_->FindWorkerChannel(worker_name);
+    GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
+    CHECK(client != nullptr) << "No worker known as " << worker_name;
+
+    // setting up request
+    req.set_host_name(local_worker_);
+    Channel* channel_info = req.mutable_channel();
+    channel_info->set_lid(rc->self_.lid);
+    channel_info->set_qpn(rc->self_.qpn);
+    channel_info->set_psn(rc->self_.psn);
+    for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
+      MemoryRegion* mr = req.add_mr();
+      mr->set_remote_addr(
+          reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
+      mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
+    }
+    // synchronous call
+    Status s = client->GetRemoteAddress(&req, &resp);
+    // save obtained remote addresses
+    // connect to the remote channel
+    if (s.ok()) {
+      CHECK(worker_name.compare(resp.host_name()) == 0);
+      RdmaAddress ra;
+      ra.lid = resp.channel().lid();
+      ra.qpn = resp.channel().qpn();
+      ra.psn = resp.channel().psn();
+      rc->SetRemoteAddress(ra, false);
+      rc->Connect();
+      int i = 0;
+      int idx[] = {1, 0, 3, 2};
+      for (const auto& mr : resp.mr()) {
+        // the connections are crossed, i.e.
+        // local tx_message_buffer <---> remote rx_message_buffer_
+        // local rx_message_buffer <---> remote tx_message_buffer_
+        // local tx_ack_buffer <---> remote rx_ack_buffer_
+        // local rx_ack_buffer <---> remote tx_ack_buffer_
+        // hence idx[] = {1, 0, 3, 2}.
+        RdmaBuffer* rb = rc->message_buffers_[idx[i]];
+        RemoteMR rmr;
+        rmr.remote_addr = mr.remote_addr();
+        rmr.rkey = mr.rkey();
+        rb->SetRemoteMR(rmr, false);
+        i++;
+      }
+      CHECK(i == RdmaChannel::kNumMessageBuffers);
+    } else {
+      LOG(ERROR) << s.error_message();
+    }
+    delete client;
+  }
+}
+
+RdmaMgr::~RdmaMgr() {
+  for (const auto& p : channel_table_) delete p.second;
+  channel_table_.clear();
+  delete rdma_adapter_;
+}
+
+// Find a channel via the given name.
+// Args:
+//   name: peer name, e.g. worker1
+// Returns
+//   channel object that is connected to the named peer.
+RdmaChannel* RdmaMgr::FindChannel(const string& name) {
+  ChannelTable::iterator iter = channel_table_.find(name);
+  CHECK(iter != channel_table_.end());
+  return iter->second;
+}
+
+}  // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma_mgr.h b/tensorflow/contrib/verbs/rdma_mgr.h
new file mode 100644
index 00000000000..b156f64096c
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_mgr.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/contrib/verbs/rdma.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+
+namespace tensorflow {
+
+class RdmaMgr {
+ public:
+  explicit RdmaMgr(const WorkerEnv* const worker_env,
+                   GrpcChannelCache* const channel_cache);
+  ~RdmaMgr();
+  RdmaChannel* FindChannel(const string& key);
+  void SetupChannels();
+  const string& local_worker() { return local_worker_; }
+
+ private:
+  string local_worker_;
+  size_t num_remote_workers_;
+  const WorkerEnv* const worker_env_;
+  GrpcChannelCache* const channel_cache_;
+  RdmaAdapter* rdma_adapter_;
+  typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
+  ChannelTable channel_table_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_VERBS
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
new file mode 100644
index 00000000000..8cbdfaa9439
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -0,0 +1,149 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
+#include <unordered_set>
+#include "tensorflow/contrib/verbs/verbs_util.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
+ public:
+  RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
+                       int64 step_id, RdmaMgr* rdma_mgr)
+      : BaseRemoteRendezvous(env, worker_name, step_id, true),
+        rdma_mgr_(rdma_mgr) {}
+
+ protected:
+  void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
+                           const Rendezvous::Args& args,
+                           DoneCallback done) override;
+
+ private:
+  ~RdmaRemoteRendezvous() override {}
+  RdmaMgr* rdma_mgr_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RdmaRemoteRendezvous);
+};
+
+void RdmaRemoteRendezvous::RecvFromRemoteAsync(
+    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
+    DoneCallback done) {
+  Status s;
+  // parse src_name and dst_name
+  string src_name, dst_name, unused;
+  if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
+                                        &unused)) {
+    s = errors::Internal("Could not parse src name.");
+  }
+  CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+  if (!s.ok()) {
+    done(s, Args(), recv_args, Tensor{}, false);
+    return;
+  }
+  if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
+                                        &unused)) {
+    s = errors::Internal("Could not parse dst name.");
+  }
+  CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+  if (!s.ok()) {
+    done(s, Args(), recv_args, Tensor{}, false);
+    return;
+  }
+  CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
+  RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
+  string key(std::move(parsed.FullKey().ToString()));
+  string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
+  // insert callback
+  rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
+                                            recv_args, parsed, done]() {
+    Status s;
+    Device* src_dev;
+    s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
+    CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+    if (!s.ok()) {
+      done(s, Args(), recv_args, Tensor(), true);
+      return;
+    }
+    Device* dst_dev;
+    s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
+    CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
+    if (!s.ok()) {
+      done(s, Args(), recv_args, Tensor(), true);
+      return;
+    }
+    RdmaBuffer* rb = rc->FindBuffer(key);
+    RdmaMessage rm;
+    CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
+    RdmaMessage::ParseMessage(rm, rb->buffer_);
+    CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
+    Tensor val;
+    if (!rm.is_dead_) {
+      void* input = static_cast<char*>(rb->buffer_) +
+                    RdmaMessage::kTensorBufferStartIndex;
+      TensorProto proto;
+      CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
+            rb->size_);
+      CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
+          << "fail to parse proto from array";
+      s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
+    }
+
+    rc->RemoveRecvCallback(key_with_step_id);
+    // create message
+    RdmaMessage br;
+    br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
+    br.name_size_ = key.size();
+    br.name_ = key;
+    string message = RdmaMessage::CreateMessage(br);
+    RdmaBuffer* tb = rc->tx_message_buffer_;
+    tb->EnqueueItem(message);
+    tb->SendNextItem();
+    done(s, Args(), recv_args, val, rm.is_dead_);
+  });
+  // append key to message queue
+  RdmaBuffer* rb = rc->tx_message_buffer_;
+  RdmaMessage rm;
+  rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
+  rm.name_size_ = key.size();
+  rm.name_ = key;
+  rm.step_id_ = step_id_;
+  string message = RdmaMessage::CreateMessage(rm);
+  rb->EnqueueItem(message);
+  rb->SendNextItem();
+}
+
+RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env,
+                                     const string& worker_name,
+                                     WorkerCacheInterface* worker_cache)
+    : BaseRendezvousMgr(env, worker_name) {}
+
+BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
+                                                const WorkerEnv* worker_env,
+                                                const string& worker_name) {
+  return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_);
+}
+
+}  // end namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
new file mode 100644
index 00000000000..57cd4bf5e4e
--- /dev/null
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received.  Each global unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+//   fork execution of an graph executor using "rendez"  on thread 1;
+//   fork execution of another graph executor using "rendez" on thread 2;
+//   ...
+//   join threads 1 and 2;
+//
+// In the example above, execution in thread 1 and 2 communicates with
+// each other by send/recv operations through the "rend".
+//
+// Tensors sent and recved through rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
+class RdmaRendezvousMgr : public BaseRendezvousMgr {
+ public:
+  explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name,
+                             WorkerCacheInterface* worker_cache);
+  void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
+
+ protected:
+  BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
+                               const string& worker_name) override;
+
+ private:
+  RdmaMgr* rdma_mgr_;
+  TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
+};
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_USE_VERBS
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
new file mode 100644
index 00000000000..b061c81d2d8
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/verbs_server_lib.h"
+
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+namespace {
+// static utility function
+RendezvousMgrInterface* NewRdmaRendezvousMgr(
+    const WorkerEnv* env, const string& worker_name,
+    WorkerCacheInterface* worker_cache) {
+  return new RdmaRendezvousMgr(env, worker_name, worker_cache);
+}
+
+}  // namespace
+
+VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
+    : GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
+
+VerbsServer::~VerbsServer() {
+  TF_CHECK_OK(Stop());
+  TF_CHECK_OK(Join());
+  delete rdma_mgr_;
+  delete verbs_service_;
+  delete channel_cache_;
+}
+
+Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
+                                        GrpcChannelCache** channel_cache) {
+  string name_prefix =
+      strings::StrCat("/job:", server_def.job_name(), "/replica:0",
+                      "/task:", server_def.task_index());
+
+  GrpcChannelSpec channel_spec;
+  TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
+
+  *channel_cache =
+      NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def));
+
+  const string host_port = (*channel_cache)->TranslateTask(name_prefix);
+  int requested_port;
+
+  if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
+                             &requested_port)) {
+    return errors::Internal("Could not parse port for local server from \"",
+                            (*channel_cache)->TranslateTask(name_prefix),
+                            "\".");
+  }
+  if (requested_port != bound_port()) {
+    return errors::InvalidArgument("Requested port ", requested_port,
+                                   " differs from expected port ",
+                                   bound_port());
+  }
+
+  return Status::OK();
+}
+
+Status VerbsServer::Init(ServiceInitFunction service_func,
+                         RendezvousMgrCreationFunction rendezvous_mgr_func) {
+  Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
+  {
+    mutex_lock l(mu_);
+    CHECK_EQ(verbs_state_, DISCONNECTED);
+    CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
+    rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
+    // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
+    verbs_service_->SetRdmaMgr(rdma_mgr_);
+    // hardcoded to default session (legacy_session_)
+    // TODO: use WorkerSessionForSession
+    // need to pass in session handle
+    dynamic_cast<RdmaRendezvousMgr*>(
+        worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get())
+        ->SetRdmaMgr(rdma_mgr_);
+  }
+  return s;
+}
+
+Status VerbsServer::Start() {
+  Status s = GrpcServer::Start();
+  {
+    mutex_lock l(mu_);
+    if (verbs_state_ == DISCONNECTED) {
+      // verbs_thread needs to be initiated
+      // before rdma_mgr sets up the rdma channels.
+      verbs_thread_.reset(worker_env()->env->StartThread(
+          ThreadOptions(), "TF_verbs_service",
+          [this] { verbs_service_->HandleRPCsLoop(); }));
+      rdma_mgr_->SetupChannels();
+      verbs_state_ = CONNECTED;
+    }
+  }
+  return s;
+}
+
+Status VerbsServer::Join() {
+  Status s = GrpcServer::Join();
+  {
+    mutex_lock l(mu_);
+    if (verbs_state_ == CONNECTED) {
+      verbs_state_ = DISCONNECTED;
+      verbs_thread_.reset();
+    }
+  }
+  return s;
+}
+
+/* static */
+Status VerbsServer::Create(const ServerDef& server_def, Env* env,
+                           std::unique_ptr<ServerInterface>* out_server) {
+  std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
+  ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
+                                            ::grpc::ServerBuilder* builder) {
+    return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
+  };
+  TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
+  *out_server = std::move(ret);
+  return Status::OK();
+}
+
+namespace {
+
+class VerbsServerFactory : public ServerFactory {
+ public:
+  bool AcceptsOptions(const ServerDef& server_def) override {
+    return server_def.protocol() == "grpc+verbs";
+  }
+
+  Status NewServer(const ServerDef& server_def,
+                   std::unique_ptr<ServerInterface>* out_server) override {
+    return VerbsServer::Create(server_def, Env::Default(), out_server);
+  }
+};
+
+// Registers a `ServerFactory` for `VerbsServer` instances.
+class VerbsServerRegistrar {
+ public:
+  VerbsServerRegistrar() {
+    gpr_allocation_functions alloc_fns;
+    alloc_fns.malloc_fn = port::Malloc;
+    alloc_fns.realloc_fn = port::Realloc;
+    alloc_fns.free_fn = port::Free;
+    gpr_set_allocation_functions(alloc_fns);
+    ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
+  }
+};
+static VerbsServerRegistrar registrar;
+
+}  // namespace
+}  // namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.h b/tensorflow/contrib/verbs/verbs_server_lib.h
new file mode 100644
index 00000000000..855380129f2
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_server_lib.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
+
+#ifdef TENSORFLOW_USE_VERBS
+
+#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
+#include "tensorflow/contrib/verbs/rdma_mgr.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+namespace tensorflow {
+
+class VerbsServer : public GrpcServer {
+ protected:
+  VerbsServer(const ServerDef& server_def, Env* env);
+
+ public:
+  static Status Create(const ServerDef& server_def, Env* env,
+                       std::unique_ptr<ServerInterface>* out_server);
+
+  // Destruction is only supported in the factory method. Clean
+  // shutdown is not currently implemented for this server type.
+  virtual ~VerbsServer() override;
+
+  // Implementations of ServerInterface methods.
+  Status Start() override;
+  Status Join() override;
+
+ protected:
+  Status Init(ServiceInitFunction service_func,
+              RendezvousMgrCreationFunction rendezvous_mgr_func);
+  Status ChannelCacheFactory(const ServerDef& server_def,
+                             GrpcChannelCache** channel_cache);
+
+ private:
+  RdmaMgr* rdma_mgr_;
+
+  // Guards state transitions.
+  mutex mu_;
+
+  enum State { DISCONNECTED, CONNECTED };
+  State verbs_state_ GUARDED_BY(mu_);
+
+  GrpcVerbsService* verbs_service_ = nullptr;
+  std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
+  GrpcChannelCache* channel_cache_ = nullptr;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_USE_VERBS
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
diff --git a/tensorflow/contrib/verbs/verbs_service.proto b/tensorflow/contrib/verbs/verbs_service.proto
new file mode 100644
index 00000000000..b985febfb8c
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_service.proto
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option java_outer_classname = "VerbsServiceProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.contrib.verbs";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// GRPC Helper messages used to exchange RDMA information.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message Channel {
+  int32 lid = 1;
+  int32 qpn = 2;
+  int32 psn = 3;
+}
+
+message MemoryRegion {
+  uint64 remote_addr = 1;
+  uint32 rkey = 2;
+}
+message GetRemoteAddressRequest {
+  string host_name = 1;
+  Channel channel = 2;
+  repeated MemoryRegion mr = 3;
+}
+
+message GetRemoteAddressResponse {
+  string host_name = 1;
+  Channel channel = 2;
+  repeated MemoryRegion mr = 3;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// VerbsService
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service VerbsService {
+  rpc GetRemoteAddress(GetRemoteAddressRequest)
+      returns (GetRemoteAddressResponse);
+}
diff --git a/tensorflow/contrib/verbs/verbs_util.cc b/tensorflow/contrib/verbs/verbs_util.cc
new file mode 100644
index 00000000000..c3350f7958c
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_util.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/verbs/verbs_util.h"
+
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+namespace tensorflow {
+
+// static sync wrapper:
+Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
+                                      const DeviceContext* device_context,
+                                      TensorProto* proto, bool is_dead) {
+  Notification n;
+  Status status;
+  GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
+                           [&n, &status](const Status& s) {
+                             status = s;
+                             n.Notify();
+                           });
+  n.WaitForNotification();
+  return status;
+}
+
+// static
+string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
+  return strings::StrCat(key, ";", step_id);
+}
+
+// static
+void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
+                                int64& step_id) {
+  StringPiece s(key_with_step_id);
+  // a key (with step_id) has exact 6 parts if split by ";"
+  // part 1: src_device;
+  // part 2: src_incarnation;
+  // part 3: dst_device;
+  // part 4: name;
+  // part 5: frame_iter.frame_id:frame_iter.iter_id
+  // part 6: step_id
+  std::vector<string> parts = str_util::Split(s, ';');
+  CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
+  strings::safe_strto64(parts[5], &step_id);
+  parts.pop_back();                        // remove step_id
+  key.assign(str_util::Join(parts, ";"));  // stitch them together
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/verbs/verbs_util.h b/tensorflow/contrib/verbs/verbs_util.h
new file mode 100644
index 00000000000..cbc01adae49
--- /dev/null
+++ b/tensorflow/contrib/verbs/verbs_util.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+#define TENSORFLOW_CONTRIB_RDMA_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class TensorProto;
+
+class VerbsUtil {
+ public:
+  // synchronous wrapper of SetProtoFromGPU
+  static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
+                                    const DeviceContext* device_context,
+                                    TensorProto* proto, bool is_dead);
+  static string AppendStepidToKey(const string& key, int64 step_id);
+  static void GetKeyAndStepId(const string& key_with_step_id, string& key,
+                              int64& step_id);
+};
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CONTRIB_RDMA_UTIL_H_
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index d6143493877..71fba99aad1 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -108,6 +108,7 @@ load(
     "tf_additional_cloud_op_deps",
     "tf_additional_cloud_kernel_deps",
     "tf_lib_proto_parsing_deps",
+    "tf_additional_verbs_lib_defines",
 )
 load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
@@ -732,9 +733,13 @@ cc_library(
         "//tensorflow/core/kernels:math_not_windows",
         "//tensorflow/core/kernels:quantized_ops",
     ]) + if_mkl([
+        "//tensorflow/core/kernels:mkl_concat_op",
         "//tensorflow/core/kernels:mkl_conv_op",
+        "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
+        "//tensorflow/core/kernels:mkl_lrn_op",
         "//tensorflow/core/kernels:mkl_pooling_ops",
         "//tensorflow/core/kernels:mkl_relu_op",
+        "//tensorflow/core/kernels:mkl_reshape_op",
         "//tensorflow/core/kernels:mkl_tfconv_op",
     ]),
 )
@@ -1272,7 +1277,9 @@ cc_library(
         "platform/tracing.h",
     ],
     copts = tf_copts(),
-    defines = tf_additional_lib_defines() + ["SNAPPY"],
+    defines = tf_additional_lib_defines() + [
+        "SNAPPY",
+    ] + tf_additional_verbs_lib_defines(),
     linkopts = select({
         "//tensorflow:freebsd": [],
         "//conditions:default": ["-ldl"],
@@ -2089,7 +2096,6 @@ tf_cc_test_mkl(
     size = "small",
     srcs = [
         "graph/mkl_layout_pass_test.cc",
-        "graph/mkl_optimizer_merge_test.cc",
         "graph/mkl_tfconversion_pass_test.cc",
     ],
     linkstatic = tf_kernel_tests_linkstatic(),
@@ -2110,9 +2116,13 @@ tf_cc_test_mkl(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:scope",
         "//tensorflow/cc:sendrecv_ops",
+        "//tensorflow/core/kernels:mkl_concat_op",
         "//tensorflow/core/kernels:mkl_conv_op",
+        "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
+        "//tensorflow/core/kernels:mkl_lrn_op",
         "//tensorflow/core/kernels:mkl_pooling_ops",
         "//tensorflow/core/kernels:mkl_relu_op",
+        "//tensorflow/core/kernels:mkl_reshape_op",
         "//tensorflow/core/kernels:mkl_tfconv_op",
         "//tensorflow/core/kernels:ops_util",
         "//third_party/eigen3",
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 2be510ee9b8..a79ea1b45d7 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -17,7 +17,9 @@ limitations under the License.
 
 #include <vector>
 
+#if defined(PLATFORM_GOOGLE)
 #include "grpc++/create_channel.h"
+#endif
 
 #if defined(PLATFORM_WINDOWS)
 // winsock2.h is used in grpc, so Ws2_32.lib is needed
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 1aafa862cb5..7160962b168 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -62,6 +62,13 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
                          plugins) override {}
 };
 
+// static utility function
+RendezvousMgrInterface* NewRpcRendezvousMgr(
+    const WorkerEnv* env, const string& worker_name,
+    WorkerCacheInterface* worker_cache) {
+  return new RpcRendezvousMgr(env, worker_name, worker_cache);
+}
+
 }  // namespace
 
 GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
@@ -93,7 +100,8 @@ GrpcServer::~GrpcServer() {
   // - worker_env_.compute_pool
 }
 
-Status GrpcServer::Init() {
+Status GrpcServer::Init(ServiceInitFunction service_func,
+                        RendezvousMgrCreationFunction rendevous_mgr_func) {
   mutex_lock l(mu_);
   CHECK_EQ(state_, NEW);
   master_env_.env = env_;
@@ -170,6 +178,10 @@ Status GrpcServer::Init() {
   worker_impl_ = NewGrpcWorker(&worker_env_);
   worker_service_ =
       NewGrpcWorkerService(worker_impl_.get(), &builder).release();
+  // extra service:
+  if (service_func != nullptr) {
+    service_func(&worker_env_, &builder);
+  }
   server_ = builder.BuildAndStart();
 
   if (!server_) {
@@ -182,7 +194,9 @@ Status GrpcServer::Init() {
 
   // Set up worker environment.
   std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
-      new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache));
+      rendevous_mgr_func == nullptr ?
+      new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
+      rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
   worker_env_.session_mgr = new SessionMgr(
       &worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
       std::unique_ptr<WorkerCacheInterface>(worker_cache),
@@ -211,6 +225,10 @@ Status GrpcServer::Init() {
   return Status::OK();
 }
 
+Status GrpcServer::Init() {
+  return Init(nullptr, nullptr);
+}
+
 Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
                                     GrpcChannelSpec* channel_spec) {
   for (const auto& job : server_def.cluster().job()) {
@@ -248,6 +266,7 @@ Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
       channel_spec, GetChannelCreationFunction(server_def)));
   const string host_port = channel_cache->TranslateTask(name_prefix);
   int requested_port;
+
   if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
                              &requested_port)) {
     return errors::Internal("Could not parse port for local server from \"",
@@ -346,7 +365,8 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
                           std::unique_ptr<ServerInterface>* out_server) {
   std::unique_ptr<GrpcServer> ret(
       new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
-  TF_RETURN_IF_ERROR(ret->Init());
+  ServiceInitFunction service_func = nullptr;
+  TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
   *out_server = std::move(ret);
   return Status::OK();
 }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index c6ba2601041..3b66291a9ab 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -36,6 +36,17 @@ namespace tensorflow {
 class GrpcWorker;
 class Master;
 
+// function that creates a RendezvousMgr.
+typedef std::function<RendezvousMgrInterface*(
+    const WorkerEnv*, const std::string& worker_name,
+    WorkerCacheInterface* worker_cache)>
+    RendezvousMgrCreationFunction;
+
+// function that registers a service to the server. The service needs to
+// be registered before builder.BuildAndStart().
+typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
+    ServiceInitFunction;
+
 class GrpcServer : public ServerInterface {
  protected:
   GrpcServer(const ServerDef& server_def, Env* env);
@@ -55,6 +66,9 @@ class GrpcServer : public ServerInterface {
   const string target() const override;
 
  protected:
+  Status Init(ServiceInitFunction service_func,
+              RendezvousMgrCreationFunction rendezvous_mgr_func);
+
   Status Init();
 
   // A subclass can override this method to support secure credentials.
@@ -78,6 +92,10 @@ class GrpcServer : public ServerInterface {
   // This method may only be called after `this->Init()` returns successfully.
   int bound_port() const { return bound_port_; }
 
+  WorkerEnv* worker_env() { return &worker_env_; }
+
+  const ServerDef& server_def() const { return server_def_; }
+
  private:
   // The overall server configuration.
   const ServerDef server_def_;
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index fb1ad0102f6..e45f156e1e5 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -126,25 +126,33 @@ FunctionDef XTimes16() {
       {{"y", "y:y:0"}});
 }
 
-FunctionDef WXPlusB() {
-  return FDH::Define(
-      // Name
-      "WXPlusB",
-      // Args
-      {"w: T", "x: T", "b: T"},
-      // Return values
-      {"y: T"},
-      // Attr def
-      {"T: {float, double}"},
-      // Nodes
-      {{{"mm"},
-        "MatMul",
-        {"w", "x"},
-        {{"T", "$T"},
-         {"transpose_a", false},
-         {"transpose_b", false},
+FunctionDef WXPlusB(){return FDH::Define(
+    // Name
+    "WXPlusB",
+    // Args
+    {"w: T", "x: T", "b: T"},
+    // Return values
+    {"y: T"},
+    // Attr def
+    {"T: {float, double}"},
+    // Nodes
+    {
+      {{"mm"},
+       "MatMul",
+       {"w", "x"},
+       {
+           {"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
+#ifdef INTEL_MKL
+       }},
+#else
          {"_kernel", "eigen"}}},
-       {{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
+#endif
+      {
+        {"y"}, "Add", {"mm", "b"}, {
+          { "T", "$T" }
+        }
+      }
+    });
 }
 
 FunctionDef Swap() {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 309c4cd774c..09b632a1650 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -48,7 +48,7 @@ namespace tensorflow {
 //     1) Propagating Mkl layout as an additional output tensor
 //        (we will loosely call a tensor that carries Mkl layout as Mkl tensor
 //         henceforth.) from every Mkl supported NN layer.
-//     2) Context-based rewrite: This is neded in order to optimize
+//     2) Context-based rewrite: This is needed in order to optimize
 //        gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and
 //        MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
 //        Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
@@ -63,12 +63,12 @@ namespace tensorflow {
 //           P = BiasAdd(O, C)
 //
 // We merge them into Conv2DWithBias as:
-//           P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
+//           P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
 //
-// Meaning of A_m, B_m and C_m is explained in B.1.
+// The meaning of A_m, B_m and C_m is explained in B.1.
 //
 // Merge rules:
-//  - Merge for Conv2D and BiasAdd happens only when output of Conv2D _only_
+//  - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_
 //    goes to BiasAdd.
 //  - Also, the intersection of attributes of both the nodes must have same
 //    values.
@@ -76,7 +76,7 @@ namespace tensorflow {
 //
 // Example of B.1 : Rewriting nodes to Mkl nodes
 // ---------------------------------------------
-// Consider Relu layer. Current definition of Relu layer looks like:
+// Consider a Relu node. Current definition of Relu node looks like:
 //
 //           O = Relu(A)
 //
@@ -87,58 +87,59 @@ namespace tensorflow {
 //
 //          O, O_m = MklRelu(A, A_m)
 //
-// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here A input is
-// same as A input of Relu; O output is same as O output of Relu. O_m is the
+// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
+// same as input A of Relu; output O is same as output O of Relu. O_m is the
 // additional output tensor that will be set by MklRelu, and it represents
 // Mkl tensor corresponding to O -- in other words, O_m is some kind of
 // metadata for O. A_m is additional input of Relu, and it represents metadata
 // for A - as O_m is metadata for O, A_m is metadata for A. MklRelu receives
-// this metadata from previous layer (in the graph).
+// this metadata from previous node in the graph.
 //
-// When previous layer in the graph is Mkl layer, A_m will represent a valid
-// Mkl tensor. But when previous Mkl layer is not an Mkl layer, then A_m
-// represents a dummy Mkl tensor.
+// When a previous node in the graph is an Mkl node, A_m will represent a valid
+// Mkl tensor. But when a previous node is not an Mkl node, A_m will represent
+// a dummy Mkl tensor.
 //
 // Rewriting rules:
-//  - Selection of an op for rewriting happens by registering an op with this
-//     pass. If an op is not registered, then it is not rewritten.
+//  - Selection of a node for rewriting happens by registering the op type of
+//    the node with the rewriting pass. If the op type is not registered, then
+//    all nodes of this op type will not be rewritten.
 //  - Number of inputs after rewriting:
-//      Since for every input Tensorflow tensor, the rewritten layer gets Mkl
-//      tensor, rewritten op gets 2*N inputs, where N is the number of inputs
-//      for original op.
+//      Since for every input Tensorflow tensor, the rewritten node gets Mkl
+//      tensor(s), rewritten node gets 2*N inputs, where N is the number of
+//      inputs for the original node.
 //  - Number of outputs after rewriting:
-//      Since for every output Tensorflow tensor, the rewritten layer generates
-//      Mkl tensor, rewritten op generates 2*N outputs, where N is the number
-//      of outputs of original op.
+//      Since for every output Tensorflow tensor, the rewritten node generates
+//      Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the
+//      number of outputs of the original node.
 //  - Ordering of Tensorflow tensors and Mkl tensors:
-//      Since every op generates twice the number of inputs and outputs, one
-//      could imagine different ordering among Tensorflow tensors and Mkl
-//      tensors. E.g., let's assume an op 'Conv2D' takes (A, B) as input, then
-//      new op 'MklConv2D' can take (A, A_m, B, B_m) as input or it can also
-//      take (A, B, A_m, B_m) as input. Among N inputs one can get N!
-//      permutations.
+//      Since every rewritten node generates twice the number of inputs and
+//      outputs, one could imagine various orderings among Tensorflow tensors
+//      and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
+//      inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
+//      in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
+//      order. Among N inputs one can get N! permutations.
 //
-//      So the question is: which one do we follow? Currently, we follow an
-//      intuitive order where Mkl tensor follows a corresponding Tensorflow
-//      tensor immediately. In the context of above example, it will be: (A,
-//      A_m, B, B_m). We follow same ordering rule for output tensors.
-//
-// NOTE: Current rewriting approach rewrites an op to Mkl op without any
-//      conditions. But in the future, it may be possible to consider
-//      conditions such as input shapes and sizes to rewrite an op.
+//      So the question is: which order do we follow? We support 2 types of
+//      orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
+//      follows an intuitive order where an Mkl tensor follows the
+//      corresponding Tensorflow tensor immediately. In the context of the
+//      above example, it will be: A, A_m, B, B_m. Note that the ordering rule
+//      applies to both the inputs and outputs. Contiguous ordering means
+//      all the Tensorflow tensors are contiguous followed by all the Mkl
+//      tensors. We use contiguous ordering as default.
 //
 // Graph rewrite algorithm:
 //      Algorithm: Graph Rewrite
-//      Input: Graph G, Names of nodes to rewrite and their new nodes
-//      Output: Modified Graph G' if nodes are modified, G otherwise.
+//      Input: Graph G, Names of the nodes to rewrite and their new names
+//      Output: Modified Graph G' if the nodes are modified, G otherwise.
 //      Start:
-//        N = Topological_Sort(G) // N is set of nodes in toposort order.
+//        N = Topological_Sort(G) // N is a set of nodes in toposort order.
 //        foreach node n in N
 //        do
-//          if (Is_MKL_Layer(n))  // Can this layer accept Mkl layout as input.
+//          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
 //          then
 //            E = set of <incoming edge and its src_output slot> of n
-//            E' = {}   // new set of edges for rewritten node
+//            E' = {}   // a new set of edges for rewritten node
 //            foreach <e,s> in E
 //            do
 //              E' U {<e,s>}  // First copy edge which generates Tensorflow
@@ -146,42 +147,44 @@ namespace tensorflow {
 //              m = Source node of edge e
 //              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
 //              then
-//                E' U {<m,s+1>}    // If yes, then m will generate Mkl tensor
-//                                  // as output.
+//                E' U {<m,s+1>}    // If yes, then m will generate an Mkl
+//                                  // tensor as an additional output.
 //              else
-//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate dummy
+//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
 //                                                 // Mkl tensor.
-//                E' U {<d,0>}   // Dummy Mkl tensor has only 1 output slot.
+//                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
 //              fi
 //            done
 //            n' = Build_New_Node(G,new_name,E')
-//            Mark_Rewritten(n')  // Mark new node as being rewritten.
+//            Mark_Rewritten(n')  // Mark the new node as being rewritten.
 //          fi
 //        done
 //
 //      Explanation:
-//        For graph rewrite, we visit nodes of the graph in the topological
-//        sort order. With this ordering, we visit nodes in top-to-bottom
-//        fashion. We need this order because while visiting a node we want
-//        all of its input nodes (parents) visited (and rewritten if
-//        applicable). This is because if we need to rewrite a current node
+//        For graph rewrite, we visit nodes of the input graph in the
+//        topological sort order. With this ordering, we visit nodes in the
+//        top-to-bottom fashion. We need this order because while visiting a
+//        node we want that all of its input nodes are visited and rewritten if
+//        applicable. This is because if we need to rewrite a given node
 //        then all of its input nodes need to be fixed (in other words they
-//        cannot be removed later.)
+//        cannot be deleted later.)
 //
-//        While visiting each node, we first check if it is Mkl layer. If
-//        it is, then we rewrite that node after constructing new inputs to
-//        the node. If it is not Mkl layer, then we do not rewrite the node.
+//        While visiting a node, we first check if the op type of the node is
+//        an Mkl op. If it is, then we rewrite that node after constructing
+//        new inputs to the node. If the op type of the node is not Mkl op,
+//        then we do not rewrite that node.
 //
 // Handling workspace propagation for certain ops:
 //
 //        Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
-//        passing of workspace from their corresponding forward ops. But
-//        TensorFlow does not have a notion of workspace and as a result
-//        does not allow producing additional outputs from these forward ops.
-//        For these ops, we need to add an additional edge between forward
-//        ops and their corresponding backward ops, and this edge carries
-//        workspace tensor value and another edge carries Mkl tensor for
-//        workspace tensor.
+//        passing of a workspace from their respective forward ops. Workspace
+//        tensors provide memory for storing results of intermediate operations
+//        which are helpful in backward propagation. TensorFlow does not have
+//        a notion of a workspace and as a result does not allow producing
+//        additional outputs from these forward ops. For these ops, we need
+//        to add 2 extra edges between forward ops and their corresponding
+//        backward ops - the first extra edge carries a workspace tensor and
+//        the second one carries an Mkl tensor for the workspace tensor.
 //
 //        Example:
 //
@@ -190,59 +193,61 @@ namespace tensorflow {
 //        A = MaxPool(T)
 //        B = MaxPoolGrad(X, A, Y)
 //
-//        We will transform this graph to propagate workspace as:
+//        We will transform this graph to propagate the workspace as:
+//        (with the contiguous ordering)
 //
-//        A, A_m, W, W_m = MklMaxPool(T, T_m)
-//        B, B_m = MklMaxPoolGrad(X, X_m, A, A_m, Y, Y_m, W, W_m)
+//        A, W, A_m, W_m = MklMaxPool(T, T_m)
+//        B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
 //
-//        Here W is the workspace tensor. Transformed tensors with name
-//        suffix _m are Mkl tensors and this transformation has been done
+//        Here W is the workspace tensor. Transformed tensor names with the
+//        suffix _m are Mkl tensors, and this transformation has been done
 //        using the algorithm discussed earlier. The transformation for
-//        workspace only adds extra outputs (W, W_m) for forward op and
-//        connects them to corresponding backward ops.
+//        workspace propagation only adds extra outputs (W, W_m) for a forward
+//        op and connects them to the corresponding backward ops.
 //
 //        Terms:
 //
 //        Forward op name = name of the op in the forward pass
-//          where workspace originates (MaxPool in this example)
+//          where a workspace tensor originates (MaxPool in this example)
 //        Backward op name = name of the op in the backward pass that receives
-//          workspace from forward op (MaxPoolGrad in the example)
-//        Slot = Number of the output or input slot that will be
-//               used by the workspace (2 for MklMaxPool as W is 3rd
-//               output of MaxPool (0 is 1st); 6 for MklMaxPoolGrad)
+//          a workspace tensor from the forward op (MaxPoolGrad in the example)
+//        Slot = Position of the output or input slot that will be
+//               used by the workspace tensor (1 for MklMaxPool as W is the 2nd
+//               output of MaxPool (0 is 1st); 3 for MklMaxPoolGrad)
 //
 //        Question:
 //
-//        How do we associate backward op to forward op? There can be more
-//        than one op with exact same name.
+//        How do we associate a backward op to a forward op? There can be more
+//        than one op with the exact same name.
 //
-//        In this example we associate MaxPoolGrad with MaxPool. But there
+//        In this example, we associate MaxPoolGrad with MaxPool. But there
 //        could be more than one MaxPool ops. To solve this problem, we look
-//        for _direct_ edge between forward op and backward op (tensor A is
-//        flowing along this edge in the example.)
+//        for _direct_ edge between a forward op and a backward op (tensor A is
+//        flowing along this edge in the example).
 //
-//        How do we transform forward and backward op when there is no direct
-//        edge between them? In such case, we generate dummy tensors as
+//        How do we transform forward and backward ops when there is no direct
+//        edge between them? In such a case, we generate dummy tensors for
 //        workspace tensors. For the example, transformation of MaxPool will
-//        be exactly same --- it is just that MaxPool won't generate any
-//        workspace tensor. For MaxPoolGrad, transformation will also be same,
-//        but instead of connecting W and W_m with outputs of MaxPool, we will
-//        produce dummy tensors for them, and we will set workspace_enabled
-//        attribute to false.
+//        be exactly same as it would be when there is a direct edge between
+//        the forward and the backward op --- it is just that MaxPool won't
+//        generate any workspace tensor. For MaxPoolGrad, the transformation
+//        will also be same, but instead of connecting W and W_m with the
+//        outputs of MaxPool, we will produce dummy tensors for them, and we
+//        will set workspace_enabled attribute to false.
 //
 // Example of B.2 : Context-based node rewrite
 // -------------------------------------------
 // Consider BiasAddGrad op as:
 //
-//           O = MklConv2D(A, A_m, B, B_m, C, C_m)
+//           O = _MklConv2D(A, B, C, A_m, B_m, C_m)
 //           P = BiasAddGrad(O)
 //
-// Then we rewrite is as:
+// Then we rewrite it as:
 //
 //           P = Conv2DWithBiasBackpropBias(O, O_m)
 //
-// 'Distance' between input of BiasAddGrad and MklConv2D in terms of hops is
-// the context matching depth. If MklConv2DWithBias is not within the context
+// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
+// the context matching depth. If _MklConv2DWithBias is not within the context
 // matching depth, then we do not rewrite BiasAddGrad.
 
 // How many hops do we search for matching node in the backward dataflow graph?
@@ -255,53 +260,85 @@ static size_t kNodeMergeContextMaxDepth = 10;
 class MklLayoutRewritePass : public GraphOptimizationPass {
  public:
   MklLayoutRewritePass() {
+    // NOTE: names are alphabetically sorted.
+    csinfo_.avg_pool = "AvgPool";
+    csinfo_.avg_pool_grad = "AvgPoolGrad";
+    csinfo_.bias_add = "BiasAdd";
+    csinfo_.bias_add_grad = "BiasAddGrad";
+    csinfo_.concat = "Concat";
+    csinfo_.concatv2 = "ConcatV2";
     csinfo_.conv2d = "Conv2D";
-    csinfo_.mklconv2d = "MklConv2D";
-    csinfo_.mklconv2dwithbias = "MklConv2DWithBias";
-    csinfo_.mklconv2dwithbiasbackpropbias = "MklConv2DWithBiasBackpropBias";
-    csinfo_.biasadd = "BiasAdd";
+    csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
+    csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
+    csinfo_.fused_batch_norm = "FusedBatchNorm";
+    csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+    csinfo_.lrn = "LRN";
+    csinfo_.lrn_grad = "LRNGrad";
     csinfo_.matmul = "MatMul";
-    csinfo_.biasaddgrad = "BiasAddGrad";
+    csinfo_.max_pool = "MaxPool";
+    csinfo_.max_pool_grad = "MaxPoolGrad";
+    csinfo_.mkl_conv2d = "_MklConv2D";
+    csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
+    csinfo_.mkl_conv2d_with_bias_backprop_bias =
+        "_MklConv2DWithBiasBackpropBias";
     csinfo_.relu = "Relu";
-    csinfo_.relugrad = "ReluGrad";
-    csinfo_.maxpool = "MaxPool";
-    csinfo_.maxpoolgrad = "MaxPoolGrad";
-    csinfo_.avgpool = "AvgPool";
-    csinfo_.avgpoolgrad = "AvgPoolGrad";
-    csinfo_.conv2dgradinput = "Conv2DBackpropInput";
-    csinfo_.conv2dgradfilter = "Conv2DBackpropFilter";
+    csinfo_.reshape = "Reshape";
+    csinfo_.relu_grad = "ReluGrad";
+    csinfo_.split = "Split";
 
-    rinfo_.push_back(
-        {csinfo_.conv2d, csinfo_.mklconv2d, 2, CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2dgradfilter,
-                      GetMklOpName(csinfo_.conv2dgradfilter), 3,
+    // NOTE: names are alphabetically sorted.
+    rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
+                      CopyAttrsPooling, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.avg_pool_grad,
+                      GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
+                      AlwaysRewrite});
+    rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
+                      CopyAttrsConcat, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
+                      CopyAttrsConcatV2, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
                       CopyAttrsConv2D, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.conv2dgradinput,
-                      GetMklOpName(csinfo_.conv2dgradinput), 3, CopyAttrsConv2D,
+    rinfo_.push_back({csinfo_.conv2d_grad_filter,
+                      GetMklOpName(csinfo_.conv2d_grad_filter), 3,
+                      CopyAttrsConv2D, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.conv2d_grad_input,
+                      GetMklOpName(csinfo_.conv2d_grad_input), 3,
+                      CopyAttrsConv2D, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.fused_batch_norm,
+                      GetMklOpName(csinfo_.fused_batch_norm), 5,
+                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.fused_batch_norm_grad,
+                      GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
+                      CopyAttrsFusedBatchNorm, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
+                      AlwaysRewrite});
+    rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
+                      CopyAttrsLRN, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
+                      CopyAttrsPooling, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.max_pool_grad,
+                      GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
                       AlwaysRewrite});
     rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
                       CopyAttrsRelu, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool), 1,
-                      CopyAttrsPooling, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad), 3,
-                      CopyAttrsPooling, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.avgpool, GetMklOpName(csinfo_.avgpool), 1,
-                      CopyAttrsPooling, AlwaysRewrite});
-    rinfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad), 2,
-                      CopyAttrsPooling, AlwaysRewrite});
+    rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
+                      CopyAttrsReshape, AlwaysRewrite});
+
+    // TODO(inteltf): we do not support ReluGrad and BiasAddGrad yet.
 
     // Add info about which ops to add workspace edge to and the slots.
-    wsinfo_.push_back({csinfo_.maxpool, csinfo_.maxpoolgrad, 0, 1, 2, 6});
+    wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
+    wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
 
     // Add a rule for merging nodes
-    minfo_.push_back(
-        {csinfo_.mklconv2d, csinfo_.biasadd, 0, csinfo_.mklconv2dwithbias});
+    minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
+                      csinfo_.mkl_conv2d_with_bias});
 
     // We use maxhop of 10 based on empirical observations. Also, these are
     // maxhops in backward data-flow graph. Since input of forward nodes
     // (Conv2D) directly goes to backward nodes, we do not expect the
     // hop-distance would be more than few nodes.
-    cinfo_.push_back({csinfo_.biasaddgrad, csinfo_.mklconv2dwithbias,
+    cinfo_.push_back({csinfo_.bias_add_grad, csinfo_.mkl_conv2d_with_bias,
                       kNodeMergeContextMaxDepth});
   }
 
@@ -318,73 +355,80 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   bool RunPass(std::unique_ptr<Graph>* g);
 
  private:
-  /// Structure to specify name of original op, its new name after rewrite,
-  /// the number of inputs to the original op, and the function to be used
-  /// to copy attributes for the op
+  /// Structure to specify the name of an original node, its new name after
+  /// rewrite, the number of inputs to the original node, the function to
+  /// be used to copy attributes for the op, and the rule (if any) which
+  /// must hold for rewriting the node
   typedef struct {
-    string name;     // Original name of the op in the graph
-    string newname;  // New name of op in the graph
-    int numins;      // Number of inputs to the original op
-    // Function handler to copy attributes from old node to new node.
-    std::function<void(const Node*, NodeBuilder*)> copyattrs;
-    std::function<bool(const Node*)> rewriterule;  // Rule under which to
-                                                   // rewrite this node.
+    string name;      // Original name of op of the node in the graph
+    string new_name;  // New name of the op of the node in the graph
+    int num_ins;      // The number of inputs to the original op type
+    // A function handler to copy attributes from an old node to a new node.
+    std::function<void(const Node*, NodeBuilder*)> copy_attrs;
+    std::function<bool(const Node*)> rewrite_rule;  // A rule under which to
+                                                    // rewrite this node.
   } RewriteInfo;
 
-  /// Structure to specify forward op, backward op, and the slot numbers
-  /// in forward and backward op where we will add workspace edge.
+  /// Structure to specify a forward op, a backward op, and the slot numbers
+  /// in the forward and backward ops where we will add a workspace edge.
   typedef struct {
-    string fwdop;   // Name of the forward op in the graph
-    string bwdop;   // Name of the backward op in the graph
-    int fwdslot;    // Output slot in the forward op node where actual
-                    // output tensor resides
-    int bwdslot;    // Input slot in the backward op node where actual
-                    // input tensor resides
-    int wsfwdslot;  // Output slot in the forward op node where workspace
-                    // edge is added
-    int wsbwdslot;  // Input slot in the backward op node where workspace
-                    // edge is added
+    string fwd_op;    // Name of a forward op in the graph
+    string bwd_op;    // Name of a backward op in the graph
+    int fwd_slot;     // Output slot in the forward op node where actual
+                      // output tensor resides
+    int bwd_slot;     // Input slot in the backward op node where actual
+                      // input tensor resides
+    int ws_fwd_slot;  // Output slot in the forward op node where workspace
+                      // edge is added
+    int ws_bwd_slot;  // Input slot in the backward op node where workspace
+                      // edge is added
   } WorkSpaceInfo;
 
   /// Structure to specify information used in node merge
   typedef struct {
-    string pred;     // Predecessor node string
-    string succ;     // Successor node string
-    int op;          // What operand no the predecessor node corresponds
-                     // to successor node?
-    string newnode;  // Name of the node after merge
+    string pred;      // Predecessor node string
+    string succ;      // Successor node string
+    int op;           // The operand no the predecessor node corresponds
+                      // to the successor node
+    string new_node;  // Name of the node after merge
   } MergeInfo;
 
-  /// Structure to specify the context information used in node rewrite rule
+  /// Structure to specify the context information used in a node rewrite rule
   typedef struct {
-    string node;    // Name of the node to be rewritten
-    string fwd;     // Node name in forward pass that this node
-                    // corresponds to
-    size_t maxhop;  // Maximum number of hops the fwd is located
-                    // from this node. If fwd is farther than maxhop
-                    // then we do not rewrite the node.
+    string node;     // Name of the node to be rewritten
+    string fwd;      // Name of the node in the forward pass that this node
+                     // corresponds to
+    size_t max_hop;  // Maximum number of hops the fwd is located
+                     // from this node. If the fwd is farther than max_hop
+                     // then we do not rewrite the node.
   } ContextInfo;
 
   /// Structure to store all constant strings
+  /// NOTE: names are alphabetically sorted.
   struct {
-    string relu;
-    string relugrad;
-    // Conv ops
+    string avg_pool;
+    string avg_pool_grad;
+    string bias_add;
+    string bias_add_grad;
+    string concat;
+    string concatv2;
     string conv2d;
-    string mklconv2d;
-    string conv2dgradinput;
-    string conv2dgradfilter;
-    string mklconv2dwithbias;
-    string mklconv2dwithbiasbackpropbias;
-    // Pooling ops
-    string maxpool;
-    string maxpoolgrad;
-    string avgpool;
-    string avgpoolgrad;
-    // Others
-    string biasadd;
+    string conv2d_grad_input;
+    string conv2d_grad_filter;
+    string fused_batch_norm;
+    string fused_batch_norm_grad;
+    string lrn;
+    string lrn_grad;
     string matmul;
-    string biasaddgrad;
+    string max_pool;
+    string max_pool_grad;
+    string mkl_conv2d;
+    string mkl_conv2d_with_bias;
+    string mkl_conv2d_with_bias_backprop_bias;
+    string relu;
+    string relu_grad;
+    string split;
+    string reshape;
   } csinfo_;
 
   /// Maintain info about nodes to rewrite
@@ -393,7 +437,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   /// Maintain info about nodes to add workspace edge
   std::vector<WorkSpaceInfo> wsinfo_;
 
-  /// Maintain info  to be merged
+  /// Maintain info about nodes to be merged
   std::vector<MergeInfo> minfo_;
 
   /// Maintain info about nodes to rewrite
@@ -403,7 +447,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   std::unordered_set<const Node*> visited_nodes_;
 
  private:
-  // Predicate to check if we rewrote node 'n'
+  // Check if we rewrote node 'n'
   //
   // If we rewrote the node, then the rewritten node will produce
   // Mkl tensor as output. If we did not rewrite the node, then
@@ -420,12 +464,49 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // Clear all visited nodes
   inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }
 
+  // Is this a graph node that can accept variable number of inputs?
+  // Return true if yes, false otherwise.
+  //
+  // Concat, Split are vararg nodes.
+  inline bool IsVarArgNode(Node* n) {
+    if (n->type_string() == csinfo_.concat ||
+        n->type_string() == csinfo_.concatv2 ||
+        n->type_string() == csinfo_.split) {
+      return true;
+    }
+    return false;
+  }
+
+  // Is OpDef::ArgDef a list type? It could be N * T or list(type).
+  // Refer to opdef.proto for details of list type.
+  inline bool ArgIsList(const OpDef::ArgDef& arg) const {
+    return !arg.type_list_attr().empty() || !arg.number_attr().empty();
+  }
+
+  // Get length of a list in 'n' if 'arg' is of list type. Refer to
+  // description of ArgIsList for definition of list type.
+  inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
+    CHECK_EQ(ArgIsList(arg), true);
+    int N = 0;
+    const string attr_name = !arg.type_list_attr().empty()
+                                 ? arg.type_list_attr()
+                                 : arg.number_attr();
+    if (!arg.type_list_attr().empty()) {
+      std::vector<DataType> value;
+      TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
+      N = value.size();
+    } else {
+      TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
+    }
+    return N;
+  }
+
   // Get the name of Mkl op from original TensorFlow op
   // We prefix 'Mkl' to the original op to get Mkl op.
   // TODO(nhasabni) We should move this to mkl_util.h.
   inline string GetMklOpName(const string& name) const {
     // Prefix that we add to Tensorflow op name to construct Mkl op name.
-    const char* const kMklOpPrefix = "Mkl";
+    const char* const kMklOpPrefix = "_Mkl";
     return string(kMklOpPrefix) + name;
   }
 
@@ -440,7 +521,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   //
   // Input nodes succ and pred may be deleted if the call to
   // this function is successful. Attempt to use the pointers
-  // after the call to function may result is undefined behaviors.
+  // after the call to function may result in undefined behaviors.
   //
   // @input g - input graph, succ - successor node, pred - predecessor node
   // @return Status::OK(), if merging is successful and supported.
@@ -470,13 +551,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   // gradient op in the backward direction.
   //
   // @input n - Node (gradient op) whose contextinfo is to be searched,
-  //        fwdn - pointer to node from the forward pass that this node
-  //        belongs to. fwdn cannot be NULL.
+  //        fwd_node - pointer to node from the forward pass that this node
+  //        belongs to. fwd_node cannot be NULL.
   // @return Matching contextinfo in case a match is found; null otherwise.
-  //         Also updates *fwdn with pointer to forward node that this context
-  //         matches.
+  //         Also updates *fwd_node with pointer to forward node that this
+  //         context matches.
   static const ContextInfo* SearchMatchingContext(const Node* n,
-                                                  const Node** fwdn);
+                                                  const Node** fwd_node);
 
   // Rewrites input node to a new node specified by its matching rewrite info.
   //
@@ -494,46 +575,132 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   //         Otherwise, it is not updated.
   Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
 
-  // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
-  // in graph 'g'. Original node is input in 'orign'.
+  // Get nodes that will feed a list of TF tensors to the new
+  // node that we are constructing.
   //
-  // For details, refer to 'Number of inputs after rewriting' section in the
+  // @input g - input graph,
+  // @input inputs - inputs to old node that we are using for constructing
+  //                 new inputs,
+  // @input input_idx - the index in the 'inputs' vector pointing to the
+  //                    current input that we have processed so far
+  // @output input_idx - index will be incremented by the number of nodes
+  //                     from 'inputs' that are processed
+  // @input list_length - The expected length of list of TF tensors
+  // @output output_nodes - the list of new nodes creating TF tensors
+  //
+  // @return None
+  void GetNodesProducingTFTensorList(
+      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+      int* input_idx, int list_length,
+      std::vector<NodeBuilder::NodeOut>* output_nodes);
+
+  // Get nodes that will feed a list of Mkl tensors to the new
+  // node that we are constructing.
+  //
+  // @input g - input graph,
+  // @input inputs - inputs to old node that we are using for constructing
+  //                 new inputs,
+  // @input input_idx - the index in the 'inputs' vector pointing to the
+  //                    current input that we have processed so far
+  // @output input_idx - index will be incremented by the number of nodes
+  //                     from 'inputs' that are processed
+  // @input list_length - The expected length of list of Mkl tensors
+  // @output output_nodes - the list of new nodes creating Mkl tensors
+  //
+  // @return None
+  void GetNodesProducingMklTensorList(
+      std::unique_ptr<Graph>* g,
+      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
+      int* input_idx, int list_length,
+      std::vector<NodeBuilder::NodeOut>* output_nodes);
+
+  // Get a node that will feed an Mkl tensor to the new
+  // node that we are constructing. The output node could be (1) 'n'
+  // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
+  // if 'n' is not an Mkl layer.
+  //
+  // @input g - input graph,
+  // @input n - Node based on which we are creating Mkl node,
+  // @input n_output_slot - the output slot of node 'n'
+  //            which is feeding to the node that we are constructing
+  // @output mkl_node - the new node that will feed Mkl tensor
+  // @output mkl_node_output_slot - the slot number of mkl_node that
+  //                                will feed the tensor
+  // @return None
+  void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
+                                 int n_output_slot, Node** mkl_node,
+                                 int* mkl_node_output_slot);
+
+  // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
+  // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
+  // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
+  // producing workspace edges if 'are_workspace_tensors_available' is true.
+  // Otherwise, 'workspace_tensors' is empty vector.
+  //
+  // For details, refer to 'Ordering of inputs after rewriting' section in the
   // documentation above.
   //
   // Returns Status::OK() if setting up inputs is successful, otherwise
   // returns appropriate status code.
+  int SetUpContiguousInputs(
+      std::unique_ptr<Graph>* g,
+      const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+      NodeBuilder* nb, Node* old_node,
+      std::vector<NodeBuilder::NodeOut>* workspace_tensors,
+      bool are_workspace_tensors_available);
+
+  // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
+  // in graph 'g'. Original node is input in 'orig_node'.
+  //
+  // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
+  // section in the documentation above.
+  //
+  // Returns Status::OK() if setting up inputs is successful, otherwise
+  // returns appropriate status code.
   Status SetUpInputs(std::unique_ptr<Graph>* g,
                      const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
-                     NodeBuilder* nb, Node* orign);
+                     NodeBuilder* nb, Node* orig_node);
 
-  // Add workspace edge on the input or output side of Node 'orign' by using
-  // NodeBuilder 'nb' for the new node provided. If 'orign' does not dictate
-  // adding workspace edge then do not add it.
-  void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orign,
-                                NodeBuilder* nb);
+  // Add workspace edge on the input or output side of Node 'orig_node' by using
+  // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
+  // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
+  // tensors, if they need to be added, will be set into these tensors.
+  // If we set workspace tensors, then are_ws_tensors_added should be true.
+  void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
+                                NodeBuilder* nb,
+                                std::vector<NodeBuilder::NodeOut>* ws_tensors,
+                                bool* are_ws_tensors_added);
 
   // Functions specific to operators to copy attributes
   // We need operator-specific function to copy attributes because the framework
   // does not provide any generic function for it.
-  static void CopyAttrsConv2D(const Node* orign, NodeBuilder* nb);
-  static void CopyAttrsBiasAddGrad(const Node* orign, NodeBuilder* nb);
-  static void CopyAttrsPooling(const Node* orign, NodeBuilder* nb);
-  static void CopyAttrsRelu(const Node* orign, NodeBuilder* nb);
+  // NOTE: names are alphabetically sorted.
+  static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
 
   // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
-  // using node for original node 'orign' and return it in '*out'.
+  // using node for original node 'orig_node' and return it in '*out'.
   // TODO(nhasabni) We should move this to mkl_util.h
   void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
-                             Node* orign);
+                             Node* orig_node);
   void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
-                                   Node* orign);
+                                   Node* orig_node);
 };
 
 std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
 
-// We register Mkl rewrite pass for phase 1 in pre-placement group.
-// Do not change the ordering of the Mkl passes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
+// We register Mkl rewrite pass for phase 1 in post rewrite group.
+// We register it here so that we get a complete picture of all users of Mkl
+// nodes. Do not change the ordering of the Mkl passes.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 1,
                       MklLayoutRewritePass);
 
 //////////////////////////////////////////////////////////////////////////
@@ -543,7 +710,6 @@ REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
 static void FillInputs(const Node* n,
                        gtl::InlinedVector<Node*, 4>* control_edges,
                        gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
-  DCHECK_EQ(in->size(), n->num_inputs());
   control_edges->clear();
   for (const Edge* e : n->in_edges()) {
     if (e->IsControlEdge()) {
@@ -561,9 +727,43 @@ static void FillInputs(const Node* n,
   }
 }
 
+void MklLayoutRewritePass::GetNodesProducingTFTensorList(
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+    int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
+  CHECK_LT(*input_idx, inputs.size());
+  CHECK_GT(list_length, 0);
+  CHECK_NOTNULL(output_nodes);
+  output_nodes->reserve(list_length);
+
+  while (list_length != 0) {
+    CHECK_GT(list_length, 0);
+    CHECK_LE(*input_idx, inputs.size());
+    Node* n = inputs[*input_idx].first;
+    int slot = inputs[*input_idx].second;
+    const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
+    // If input node 'n' is producing a list/array output at output
+    // slot 'slot' then we need to find out the length of that list/array.
+    if (ArgIsList(arg)) {
+      int N = GetTensorListLength(arg, n);
+      CHECK_LE(N, list_length);
+      for (int j = 0; j < N; j++) {
+        output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
+      }
+      (*input_idx)++;
+      list_length -= N;
+    } else {
+      // But if input node 'n' is just producing a single tensor at
+      // output slot 'slot' then we just add that single node.
+      output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
+      (*input_idx)++;
+      list_length--;
+    }
+  }
+}
+
 // TODO(nhasabni) We should move this to mkl_util.h.
 void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
-                                                 Node** out, Node* orign) {
+                                                 Node** out, Node* orig_node) {
   // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
   // dummy Mkl tensor. 8 = 2*size_t.
   const DataType dt = DataTypeToEnum<uint8>::v();
@@ -574,63 +774,228 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
                            8);
   TensorShape dummy_shape({8});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
-  TF_CHECK_OK(
-      NodeBuilder((*g)->NewName("DMT"), "Const")
-          .Attr("value", proto)
-          .Attr("dtype", dt)
-          .Device(orign->def().device())  // We place this node on same
-                                          // device as device of original
-                                          // node.
-          .Finalize(&**g, out));
-  (*out)->set_assigned_device_name(orign->assigned_device_name());
+  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+                  .Attr("value", proto)
+                  .Attr("dtype", dt)
+                  .Device(orig_node->def().device())  // We place this node on
+                                                      // the same device as the
+                                                      // device of the original
+                                                      // node.
+                  .Finalize(&**g, out));
+  (*out)->set_assigned_device_name(orig_node->assigned_device_name());
+}
+
+void MklLayoutRewritePass::GetNodesProducingMklTensorList(
+    std::unique_ptr<Graph>* g,
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
+    int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
+  CHECK_LT(*input_idx, inputs.size());
+  CHECK_GT(list_length, 0);
+  CHECK_NOTNULL(output_nodes);
+  output_nodes->reserve(list_length);
+
+  while (list_length != 0) {
+    CHECK_GT(list_length, 0);
+    CHECK_LE(*input_idx, inputs.size());
+    Node* n = inputs[*input_idx].first;
+    int slot = inputs[*input_idx].second;
+    const OpDef::ArgDef& arg = n->op_def().output_arg(slot);
+    // We need to check first if the input edge is going to carry a
+    // single tensor or a list of tensors. If it is a list of tensors,
+    // then we need to create list of Mkl dummy nodes.
+    if (ArgIsList(arg)) {
+      // If input node 'n' is producing a list/array output at output
+      // slot 'slot' then we need to find out the length of that list/array.
+      int N = GetTensorListLength(arg, n);
+      CHECK_LE(N, list_length);
+      Node* mkl_node = nullptr;
+      int mkl_node_output_slot = 0;
+      // If it is a list, then create a list of Mkl dummy nodes.
+      for (int j = 0; j < N; j++) {
+        GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
+        output_nodes->push_back(
+            NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
+      }
+      (*input_idx)++;
+      list_length -= N;
+    } else {
+      // If it is not a list, then create a single Mkl tensor node.
+      Node* mkl_node = nullptr;
+      int mkl_node_output_slot = 0;
+      GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
+      output_nodes->push_back(
+          NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
+      (*input_idx)++;
+      list_length--;
+    }
+  }
+}
+
+// Get an input node that will feed Mkl tensor to the new
+// node that we are constructing. An input node could be (1) 'n'
+// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
+// if 'n' is not an Mkl layer.
+void MklLayoutRewritePass::GetNodeProducingMklTensor(
+    std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
+    int* mkl_node_output_slot) {
+  CHECK_NOTNULL(n);
+  CHECK_NOTNULL(mkl_node);
+  CHECK_NOTNULL(mkl_node_output_slot);
+  if (IsRewrittenNode(n)) {
+    // If we have visited this node and rewritten it, then it will generate
+    // an edge that will receive Mkl tensor from a node.
+    // First, let's assert that this op is Mkl layer.
+    DataType T;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
+    // If this op has been rewritten, then its name must have been same as
+    // Mkl op.
+    CHECK_EQ(mkl_op_registry::IsMklOp(n->type_string(), T), true);
+    // output slot number for Mkl tensor would be N+slot number of TensorFlow
+    // tensor, where N is total number of TensorFlow tensors.
+    *mkl_node = n;
+    *mkl_node_output_slot =
+        GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
+  } else {
+    // If we have not visited the node and rewritten it, then we need
+    // to create a dummy node that will feed a dummy Mkl tensor to this node.
+    // DummyMklTensor node has no input and generates only 1 output
+    // (dummy Mkl tensor) as output slot number 0.
+    GetDummyMklTensorNode(g, mkl_node, n);
+    CHECK_NOTNULL(*mkl_node);
+    *mkl_node_output_slot = 0;
+  }
+}
+
+int MklLayoutRewritePass::SetUpContiguousInputs(
+    std::unique_ptr<Graph>* g,
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+    NodeBuilder* nb, Node* old_node,
+    std::vector<NodeBuilder::NodeOut>* workspace_tensors,
+    bool are_workspace_tensors_available) {
+  CHECK_NOTNULL(workspace_tensors);
+  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+
+  // Number of input slots to original op
+  // Input slots are represented by .Input() calls in REGISTER_OP.
+  int old_node_input_slots = old_node->op_def().input_arg_size();
+  // Actual number of inputs can be greater than or equal to number
+  // of Input slots because inputs of type list could be unfolded.
+  CHECK_GE(old_node_inputs.size(), old_node_input_slots);
+  int nn_slot_idx = 0;  // slot index for inputs of new node
+
+  // Let's copy all inputs (TF tensors) of original node to new node.
+  int iidx = 0;
+  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
+    // An input slot could be a single tensor or a list. We need
+    // to handle this case accordingly.
+    CHECK_LT(iidx, old_node_inputs.size());
+    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
+    if (ArgIsList(arg)) {
+      std::vector<NodeBuilder::NodeOut> new_node_inputs;
+      int N = GetTensorListLength(arg, old_node);
+      GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
+                                    &new_node_inputs);
+      nb->Input(new_node_inputs);
+      nn_slot_idx++;
+    } else {
+      nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
+      iidx++;
+      nn_slot_idx++;
+    }
+  }
+
+  // If workspace tensors are available for this op and we are using
+  // contiguous ordering then we need to add Tensorflow tensor for
+  // workspace here because Tensorflow tensor for workspace is the
+  // last tensor in the list of Tensorflow tensors.
+  if (are_workspace_tensors_available) {
+    CHECK_EQ(workspace_tensors->size(), 2);
+    // Tensorflow tensor
+    nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
+    nn_slot_idx++;
+  }
+
+  // Let's now setup all Mkl inputs to new node.
+  // Number of Mkl inputs must be same as number of TF inputs.
+  iidx = 0;
+  for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
+    // An input slot could be a single tensor or a list. We need
+    // to handle this case accordingly.
+    CHECK_LT(iidx, old_node_inputs.size());
+    const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
+    if (ArgIsList(arg)) {
+      std::vector<NodeBuilder::NodeOut> new_node_inputs;
+      int N = GetTensorListLength(arg, old_node);
+      GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
+                                     &new_node_inputs);
+      nb->Input(new_node_inputs);
+      nn_slot_idx++;
+    } else {
+      Node* mkl_node = nullptr;
+      int mkl_node_output_slot = 0;
+      GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
+                                old_node_inputs[iidx].second, &mkl_node,
+                                &mkl_node_output_slot);
+      nb->Input(mkl_node, mkl_node_output_slot);
+      iidx++;
+      nn_slot_idx++;
+    }
+  }
+
+  // If workspace tensors are available for this op and we are using
+  // contiguous ordering then we need to add Mkl tensor for
+  // workspace here because Mkl tensor for workspace is the
+  // last tensor in the list of Mkl tensors.
+  if (are_workspace_tensors_available) {
+    CHECK_EQ(workspace_tensors->size(), 2);
+    // Mkl tensor
+    nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
+    nn_slot_idx++;
+  }
+
+  return nn_slot_idx;
 }
 
 Status MklLayoutRewritePass::SetUpInputs(
     std::unique_ptr<Graph>* g,
-    const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, NodeBuilder* nb,
-    Node* orign) {
-  std::vector<NodeBuilder::NodeOut> new_inputs;
+    const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
+    NodeBuilder* nb, Node* old_node) {
+  // Let's check if we need to add workspace tensors for this node.
+  // We add workspace edge only for MaxPool, LRN and BatchNorm.
+  std::vector<NodeBuilder::NodeOut> workspace_tensors;
+  bool are_workspace_tensors_available = false;
+  AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
+                           &are_workspace_tensors_available);
 
-  // 1. Let's setup inputs for the new node.
-  for (int i = 0; i < inputs.size(); i++) {
-    Node* n = inputs[i].first;
-    // First let's copy original TF tensor input as it is.
-    new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second));
-
-    // Second, let's add edge to propagate Mkl tensors from input Mkl layers,
-    // or generate a dummy Mkl tensor representing not-mkl-tensor case.
-    if (IsRewrittenNode(n)) {
-      // If we have visited this node and rewritten it, then it will generate
-      // an edge that will receive Mkl tensor from a node.
-      // First, let's assert that this op is Mkl layer.
-      DataType T;
-      TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
-      // If this op has been rewritten, then its name must have been same as
-      // Mkl op.
-      CHECK_EQ(mkl_layer_registry::IsMklLayer(n->type_string(), T), true);
-      // src slot number for Mkl tensor would be the one next to TF tensor
-      // slot number.
-      new_inputs.push_back(NodeBuilder::NodeOut(n, inputs[i].second + 1));
-    } else {
-      // If we have not visited the node and rewritten it, then we need
-      // to create a dummy node that will feed a non-Mkl tensor to this node.
-      // DummyMklTensor node has no input and generates only 1 output
-      // (dummy Mkl tensor) as output slot number 0.
-      Node* dmt = nullptr;
-      GetDummyMklTensorNode(g, &dmt, orign);
-      CHECK_NOTNULL(dmt);
-      new_inputs.push_back(NodeBuilder::NodeOut(dmt, 0));
-    }
+  int new_node_input_slots = 0;
+  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+    // TODO(nhasabni): implement this function just for same of completion.
+    // We do not use interleaved ordering right now.
+    return Status(
+        error::Code::UNIMPLEMENTED,
+        "Interleaved ordering of tensors is currently not supported.");
+  } else {
+    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+    new_node_input_slots = SetUpContiguousInputs(
+        g, old_node_inputs, nb, old_node, &workspace_tensors,
+        are_workspace_tensors_available);
   }
 
-  // The total number of inputs to new node _must_ be 2 times the number
-  // of inputs to the original node: N original Tensorflow tensors and
-  // N for Mkl tensors corresponding to each Tensorflow tensors.
-  CHECK_EQ(new_inputs.size(), inputs.size() * 2);
-
-  // 2. Let's add the new inputs.
-  for (auto ni : new_inputs) {
-    nb->Input(ni.node, ni.index);
+  // Sanity check
+  int old_node_input_slots = old_node->op_def().input_arg_size();
+  if (!are_workspace_tensors_available) {
+    // If we are not adding workspace tensors for this op, then the total
+    // number of input slots to the new node _must_ be 2 times the number
+    // of input slots to the original node: N original Tensorflow tensors and
+    // N for Mkl tensors corresponding to each Tensorflow tensors.
+    CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
+  } else {
+    // If we are adding workspace tensors for this op, then the total
+    // The total number of input slots to new node _must_ be 2 times the number
+    // of input slots to the original node: N original Tensorflow tensors and
+    // N for Mkl tensors corresponding to each Tensorflow tensors plus 2
+    // (for workspace Tensorflow tensor and workspace Mkl tensor).
+    CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
   }
 
   return Status::OK();
@@ -642,7 +1007,7 @@ Status MklLayoutRewritePass::SetUpInputs(
 
 // TODO(nhasabni) We should move this to mkl_util.h.
 void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
-    std::unique_ptr<Graph>* g, Node** out, Node* orign) {
+    std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
   // We use a tensor of shape {1} and value 0 to represent
   // dummy float tensor. We need this as a dummy workspace tensor.
   // Workspace tensor has type float.
@@ -654,39 +1019,42 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
                            4);
   TensorShape dummy_shape({1});
   dummy_shape.AsProto(proto.mutable_tensor_shape());
-  TF_CHECK_OK(
-      NodeBuilder((*g)->NewName("DMT"), "Const")
-          .Attr("value", proto)
-          .Attr("dtype", dt)
-          .Device(orign->def().device())  // We place this node on same
-                                          // device as device of original
-                                          // node.
-          .Finalize(&**g, out));
-  (*out)->set_assigned_device_name(orign->assigned_device_name());
+  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
+                  .Attr("value", proto)
+                  .Attr("dtype", dt)
+                  .Device(orig_node->def().device())  // We place this node on
+                                                      // same the device as the
+                                                      // device of the original
+                                                      // node.
+                  .Finalize(&**g, out));
+  (*out)->set_assigned_device_name(orig_node->assigned_device_name());
 }
 
-void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
-                                                    Node* orign,
-                                                    NodeBuilder* nb) {
-  bool workspace_edge_added = false;
+void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
+    std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
+    std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
+  bool workspace_edge_added = false;  // Default initializer
+  CHECK_NOTNULL(are_ws_tensors_added);
+  *are_ws_tensors_added = false;  // Default initializer
+
   DataType T;
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
   for (auto ws : wsinfo_) {
-    if (orign->type_string() == ws.fwdop &&
-        mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()), T)) {
+    if (orig_node->type_string() == ws.fwd_op &&
+        mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()), T)) {
       // If this op is a fwd op, then we need to check if there is an
-      // edge from this node's fwdslot to bwdop's bwdslot. If there is
+      // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
       // an edge, then we just add an attribute on this node for setting
       // workspace_passed to true. We don't add actual workspace edge
       // in this node. Actual workspace edge gets added in the backward
       // op for this node.
-      for (const Edge* e : orign->out_edges()) {
-        if (e->src_output() == ws.fwdslot &&
-            e->dst()->type_string() == ws.bwdop &&
-            e->dst_input() == ws.bwdslot) {
+      for (const Edge* e : orig_node->out_edges()) {
+        if (e->src_output() == ws.fwd_slot &&
+            e->dst()->type_string() == ws.bwd_op &&
+            e->dst_input() == ws.bwd_slot) {
           nb->Attr("workspace_enabled", true);
           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
-                  << orign->type_string();
+                  << orig_node->type_string();
           workspace_edge_added = true;
           // We found the edge that we were looking for, so break.
           break;
@@ -698,34 +1066,40 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
         // node.
         nb->Attr("workspace_enabled", false);
       }
-    } else if (orign->type_string() == ws.bwdop &&
-               mkl_layer_registry::IsMklLayer(
-                   GetMklOpName(orign->type_string()), T)) {
+    } else if (orig_node->type_string() == ws.bwd_op &&
+               mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()),
+                                        T)) {
       // If this op is a bwd op, then we need to add workspace edge and
       // it's Mkl tensor edge between its corresponding fwd op and this
-      // op. Corresponding fwd op is specified in 'fwdop' field of
-      // workspace info. fwdslot and bwdslot in workspace info specify
+      // op. Corresponding fwd op is specified in 'fwd_op' field of
+      // workspace info. fwd_slot and bwd_slot in workspace info specify
       // an edge between which slots connect forward and backward op.
       // Once all these criteria match, we add a workspace edge between
-      // wsfwdslot and wsbwdslot. It's corresponding Mkl tensor is added
-      // in wsfwdslot+1 and wsbwdslot+1.
-      for (const Edge* e : orign->in_edges()) {
-        if (e->src_output() == ws.fwdslot &&
+      // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
+      // determined by interleaved/contiguous ordering. Function
+      // DataIndexToMetaDataIndex tells us the location of Mkl tensor
+      // from the location of the Tensorflow tensor.
+      for (const Edge* e : orig_node->in_edges()) {
+        if (e->src_output() == ws.fwd_slot &&
             // We would have rewritten the forward op, so we need to use
             // GetMklOpName call to get its Mkl name.
-            e->src()->type_string() == GetMklOpName(ws.fwdop) &&
-            e->dst_input() == ws.bwdslot) {
+            e->src()->type_string() == GetMklOpName(ws.fwd_op) &&
+            e->dst_input() == ws.bwd_slot) {
           nb->Attr("workspace_enabled", true);
+          CHECK_NOTNULL(ws_tensors);
           // Add workspace edge between fwd op and bwd op.
-          nb->Input(e->src(), ws.wsfwdslot);
+          ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
           // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
-          nb->Input(e->src(), ws.wsfwdslot + 1);
+          ws_tensors->push_back(NodeBuilder::NodeOut(
+              e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
+                                                 e->src()->num_outputs())));
+          *are_ws_tensors_added = true;
           // In terms of input ordering, we add these calls to add Input
           // here because workspace edge (and its Mkl tensor) is the last
           // edge in the fwdop and bwdop. So all inputs before workspace
           // tensor have been added by SetUpInputs function.
           VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
-                  << orign->type_string();
+                  << orig_node->type_string();
           workspace_edge_added = true;
           // We found the edge that we were looking for, so break.
           break;
@@ -740,15 +1114,18 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
         nb->Attr("workspace_enabled", false);
         Node* dmt_ws = nullptr;      // Dummy tensor for workspace
         Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
-        GetDummyWorkspaceTensorNode(g, &dmt_ws, orign);
-        GetDummyMklTensorNode(g, &dmt_mkl_ws, orign);
+        GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
+        GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
         CHECK_NOTNULL(dmt_ws);
         CHECK_NOTNULL(dmt_mkl_ws);
-        nb->Input(dmt_ws, 0);      // We add dummy tensor as workspace tensor.
-        nb->Input(dmt_mkl_ws, 0);  // We add dummy tensor as Mkl
-                                   // tensor for workspace tensor.
+        CHECK_NOTNULL(ws_tensors);
+        // We add dummy tensor as workspace tensor.
+        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
+        // We add dummy tensor as Mkl tensor for workspace tensor.
+        ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
+        *are_ws_tensors_added = true;
         VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
-                << orign->type_string();
+                << orig_node->type_string();
       }
     } else {
       // If this node does not match any workspace info, then we do not
@@ -761,7 +1138,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
 // Op-specific functions to copy attributes from old node to new node
 //////////////////////////////////////////////////////////////////////////
 
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
+                                           NodeBuilder* nb) {
   DataType T;
   string data_format;
   string padding;
@@ -769,11 +1147,12 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
   bool use_cudnn_on_gpu;
 
   // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+  TF_CHECK_OK(
+      GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
 
   // Add attributes to new node.
   nb->Attr("T", T);
@@ -783,16 +1162,16 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
   nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
 }
 
-void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign,
+void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
                                                 NodeBuilder* nb) {
   DataType T;
   string data_format;
   std::vector<int32> strides;
 
   // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
 
   // Add attributes to new node.
   nb->Attr("T", T);
@@ -800,7 +1179,30 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign,
   nb->Attr("data_format", data_format);
 }
 
-void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
+void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
+                                        NodeBuilder* nb) {
+  DataType T;
+  int depth_radius;
+  float bias;
+  float alpha;
+  float beta;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "depth_radius", &depth_radius));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "bias", &bias));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "alpha", &alpha));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "beta", &beta));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("depth_radius", depth_radius);
+  nb->Attr("bias", bias);
+  nb->Attr("alpha", alpha);
+  nb->Attr("beta", beta);
+}
+
+void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
                                             NodeBuilder* nb) {
   DataType T;
   string data_format;
@@ -808,11 +1210,11 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
   std::vector<int32> ksize, strides;
 
   // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "ksize", &ksize));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "strides", &strides));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "padding", &padding));
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &data_format));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
 
   // Add attributes to new node.
   nb->Attr("T", T);
@@ -822,16 +1224,99 @@ void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
   nb->Attr("data_format", data_format);
 }
 
-void MklLayoutRewritePass::CopyAttrsRelu(const Node* orign, NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsRelu(const Node* orig_node,
+                                         NodeBuilder* nb) {
   DataType T;
 
   // Get all attributes from old node.
-  TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
 
   // Add attributes to new node.
   nb->Attr("T", T);
 }
 
+void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
+                                          NodeBuilder* nb) {
+  DataType T;
+  string data_format;
+  int num_split;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_split", &num_split));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("num_split", num_split);
+  nb->Attr("data_format", data_format);
+}
+
+void MklLayoutRewritePass::CopyAttrsConcat(const Node* orig_node,
+                                           NodeBuilder* nb) {
+  DataType T;
+  int N;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("N", N);
+}
+
+void MklLayoutRewritePass::CopyAttrsConcatV2(const Node* orig_node,
+                                             NodeBuilder* nb) {
+  DataType T;
+  int N;
+  DataType tidx;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "N", &N));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tidx", &tidx));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("N", N);
+  nb->Attr("Tidx", tidx);
+}
+
+void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
+                                                   NodeBuilder* nb) {
+  DataType T;
+  float epsilon;
+  string data_format;
+  bool is_training;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "is_training", &is_training));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("epsilon", epsilon);
+  nb->Attr("data_format", data_format);
+  nb->Attr("is_training", is_training);
+}
+
+void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
+                                            NodeBuilder* nb) {
+  DataType T;
+  DataType Tshape;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
+
+  // Add attributes to new node.
+  nb->Attr("T", T);
+  nb->Attr("Tshape", Tshape);
+}
+
 //////////////////////////////////////////////////////////////////////////
 //           Helper functions related to node merge pass
 //////////////////////////////////////////////////////////////////////////
@@ -889,8 +1374,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
   CHECK_NOTNULL(succ);
   CHECK_NOTNULL(pred);
 
-  if (succ->type_string() == csinfo_.biasadd &&
-      pred->type_string() == csinfo_.mklconv2d) {
+  if (succ->type_string() == csinfo_.bias_add &&
+      pred->type_string() == csinfo_.mkl_conv2d) {
     // 1. Get all attributes from input nodes.
     DataType T_pred, T_succ;
     string padding;
@@ -947,7 +1432,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
     // 2. Get inputs from both the nodes.
     // Find the 2 inputs from the conv and the bias from the add Bias.
     // Get operand 0, 1 of conv2D and their Mkl tensors.
-    CHECK_EQ(pred->in_edges().size(), 4);  // MklConv2D must have 4 inputs.
+    CHECK_EQ(pred->in_edges().size(), 4);  // _MklConv2D must have 4 inputs.
     // Get operand 1 of add_bias
     // BiasAdd must have 2 inputs: Conv, bias
     CHECK_EQ(succ->in_edges().size(), 2);
@@ -960,13 +1445,29 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
     // We will use the node name of BiasAdd as the name of new node
     // Build new node. We use same name as original node, but change the op
     // name.
-    NodeBuilder nb(succ->name(), csinfo_.mklconv2dwithbias);
-    nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
-    nb.Input(pred_in[1].first, pred_in[1].second);  // Mkl for In1
-    nb.Input(pred_in[2].first, pred_in[2].second);  // In2 of Conv2D
-    nb.Input(pred_in[3].first, pred_in[3].second);  // Mkl for In2
-    nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
-    nb.Input(oper3_mkl, oper3_mkl_slot);            // Mkl for In2 of BiasAdd
+    NodeBuilder nb(succ->name(), csinfo_.mkl_conv2d_with_bias);
+    if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+      nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
+      // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
+      // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
+      // we follow contiguous ordering.
+      nb.Input(pred_in[1].first, pred_in[1].second);  // Mkl for In1
+      nb.Input(pred_in[2].first, pred_in[2].second);  // In2 of Conv2D
+      nb.Input(pred_in[3].first, pred_in[3].second);  // Mkl for In2
+      nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
+      nb.Input(oper3_mkl, oper3_mkl_slot);            // Mkl for In2 of BiasAdd
+    } else {
+      CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+      nb.Input(pred_in[0].first, pred_in[0].second);  // In1 of Conv2D
+      // pred_in[1] will be Mkl tensor for In1 if we follow interleaved
+      // ordering, and it will be 2nd Tensorflow tensor for Conv2D if
+      // we follow contiguous ordering.
+      nb.Input(pred_in[1].first, pred_in[1].second);  // In2 of Conv2D
+      nb.Input(succ_in[1].first, succ_in[1].second);  // In2 of BiasAdd
+      nb.Input(pred_in[2].first, pred_in[2].second);  // Mkl for In1 of Conv2D
+      nb.Input(pred_in[3].first, pred_in[3].second);  // Mkl for In2 of Conv2D
+      nb.Input(oper3_mkl, oper3_mkl_slot);            // Mkl for In2 of BiasAdd
+    }
 
     // Copy attributes from Conv2D to Conv2DWithBias.
     CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
@@ -975,30 +1476,30 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
     nb.Device(succ->def().device());
 
     // Create node.
-    Node* newn;
-    nb.Finalize(&**g, &newn);
-    CHECK_NOTNULL(newn);
+    Node* new_node;
+    nb.Finalize(&**g, &new_node);
+    CHECK_NOTNULL(new_node);
 
     // Set the Mkl layer label for this op.
-    newn->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
+    new_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
 
     // Incoming edges are fixed, we will fix the outgoing edges now.
     for (const Edge* e : succ->out_edges()) {
-      (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+      (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
     }
 
     // Copy device assigned to old node to new node.
     // It's ok to use pred or succ as we have enforced a check that
     // both have same device assigned.
-    newn->set_assigned_device_name(pred->assigned_device_name());
+    new_node->set_assigned_device_name(pred->assigned_device_name());
 
     VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
             << ", and node: " << succ->DebugString()
-            << ", into node:" << newn->DebugString();
+            << ", into node:" << new_node->DebugString();
 
     (*g)->RemoveNode(succ);
     (*g)->RemoveNode(pred);
-    MarkRewrittenNode(newn);
+    MarkRewrittenNode(new_node);
 
     return Status::OK();
   }
@@ -1011,35 +1512,39 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
 //           Helper functions for node rewrite
 //////////////////////////////////////////////////////////////////////////
 
-Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
+Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
+                                         Node* orig_node,
                                          const RewriteInfo* ri) {
   CHECK_NOTNULL(ri);
-  CHECK_NOTNULL(orign);
+  CHECK_NOTNULL(orig_node);
 
-  VLOG(1) << "MklLayoutRewritePass: Original node:" << orign->DebugString();
+  VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
 
   // Check if this is scenario 2 (context-based rewrite).
   // Get the matching ContextInfo if it is.
-  const Node* fwdn = nullptr;
+  const Node* fwd_node = nullptr;
   const ContextInfo* ci = nullptr;
   bool is_context_based_rewrite = false;
-  if ((ci = SearchMatchingContext(orign, &fwdn)) != nullptr) {
-    CHECK_NOTNULL(fwdn);
+  if ((ci = SearchMatchingContext(orig_node, &fwd_node)) != nullptr) {
+    CHECK_NOTNULL(fwd_node);
     is_context_based_rewrite = true;
 
     // Sanity checks for context-based rewrite (if any)
-    if (orign->type_string() == csinfo_.biasaddgrad &&
-        ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) {
+    if (orig_node->type_string() == csinfo_.bias_add_grad &&
+        ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
       DataType orig_T, ctx_T;
       string orig_data_format, ctx_data_format;
-      TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &orig_T));
-      TF_CHECK_OK(GetNodeAttr(orign->def(), "data_format", &orig_data_format));
-      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "T", &ctx_T));
-      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "data_format", &ctx_data_format));
+      TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
+      TF_CHECK_OK(
+          GetNodeAttr(orig_node->def(), "data_format", &orig_data_format));
+      TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T));
+      TF_CHECK_OK(
+          GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format));
 
       if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
-          orign->assigned_device_name() != fwdn->assigned_device_name() ||
-          orign->def().device() != fwdn->def().device()) {
+          orig_node->assigned_device_name() !=
+              fwd_node->assigned_device_name() ||
+          orig_node->def().device() != fwd_node->def().device()) {
         return Status(
             error::Code::INVALID_ARGUMENT,
             "data_format or T attribute or devices of BiasAddGrad and "
@@ -1049,18 +1554,22 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
   }
 
   // Get all inputs.
-  const int num = orign->num_inputs();
-  CHECK_EQ(num, ri->numins);
+  const int num = orig_node->in_edges().size();
+  // Check the number of inputs against the user-specified value for non-vararg
+  // nodes.
+  if (!IsVarArgNode(orig_node)) {
+    CHECK_EQ(num, ri->num_ins);
+  }
   gtl::InlinedVector<Node*, 4> control_edges;
   gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num);
-  FillInputs(orign, &control_edges, &inputs);
+  FillInputs(orig_node, &control_edges, &inputs);
 
   // Build new node. We use same name as original node, but change the op name.
-  NodeBuilder nb(orign->name().c_str(), ri->newname.c_str());
+  NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
   // Copy user-specified device assigned to original node to new node.
-  nb.Device(orign->def().device());
+  nb.Device(orig_node->def().device());
   // Set up new inputs to the rewritten node.
-  Status s = SetUpInputs(g, inputs, &nb, orign);
+  Status s = SetUpInputs(g, inputs, &nb, orig_node);
   if (s != Status::OK()) {
     return s;
   }
@@ -1068,62 +1577,63 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
   // Copy attributes from original node to new node (for scenario 1).
   // For context-based rewrite, we use context to copy the attributes.
   if (is_context_based_rewrite) {
-    if (orign->type_string() == csinfo_.biasaddgrad &&
-        ri->newname == csinfo_.mklconv2dwithbiasbackpropbias) {
-      CHECK_NOTNULL(fwdn);
-      ri->copyattrs(fwdn, &nb);
+    if (orig_node->type_string() == csinfo_.bias_add_grad &&
+        ri->new_name == csinfo_.mkl_conv2d_with_bias_backprop_bias) {
+      CHECK_NOTNULL(fwd_node);
+      ri->copy_attrs(fwd_node, &nb);
     } else {
       return Status(error::Code::UNIMPLEMENTED,
                     "Unimplemented case for node rewrite optimization.");
     }
   } else {
-    ri->copyattrs(const_cast<const Node*>(orign), &nb);
+    ri->copy_attrs(const_cast<const Node*>(orig_node), &nb);
   }
   // Set the Mkl layer label for this op.
-  nb.Attr("_kernel", mkl_layer_registry::kMklLayerLabel);
-
-  // Add workspace edge to this node if needed.
-  // We add workspace edge only for MaxPool, LRN and BatchNorm.
-  AddWorkSpaceEdgeIfNeeded(g, orign, &nb);
+  nb.Attr("_kernel", mkl_op_registry::kMklOpLabel);
 
   // Finalize graph and get new node.
-  Node* newn = nullptr;
-  TF_CHECK_OK(nb.Finalize(&**g, &newn));
-  CHECK_NOTNULL(newn);
+  Node* new_node = nullptr;
+  TF_CHECK_OK(nb.Finalize(&**g, &new_node));
+  CHECK_NOTNULL(new_node);
 
-  // Incoming edges from 'orign' node to new 'newn' node are already copied
-  // in BuildNode. Copy outgoing edges from 'orign' node to new 'newn' node.
-  // Since the output also follows same ordering among Tensorflow tensors and
-  // Mkl tensors. We need to connect Tensorflow tensors appropriately.
-  // Specifically, nth output of original node will become 2*nth output of
-  // Mkl node. GetTensorDataIndex provides this mapping function.
-  for (const Edge* e : orign->out_edges()) {
+  // Incoming edges from 'orig_node' node to new 'new_node' node are already
+  // copied in BuildNode. Copy outgoing edges from 'orig_node' node to new
+  // 'new_node' node, since the output also follows same ordering among
+  // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow
+  // tensors appropriately. Specifically, nth output of the original node
+  // will become 2*nth output of the Mkl node for the interleaved ordering
+  // of the tensors. For the contiguous ordering of the tensors, it will be n.
+  // GetTensorDataIndex provides this mapping function.
+  for (const Edge* e : orig_node->out_edges()) {
     // We need to handle control-edges by using their original slot number.
     // Generally, -1 is reserved for control slot.
     if (e->src_output() < 0) {
-      (*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
+      (*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
     } else {
-      (*g)->AddEdge(newn, GetTensorDataIndex(e->src_output()), e->dst(),
-                    e->dst_input());
+      (*g)->AddEdge(
+          new_node,
+          GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
+          e->dst(), e->dst_input());
     }
   }
 
   // Copy the runtime device assigned from original code to new node.
-  newn->set_assigned_device_name(orign->assigned_device_name());
+  new_node->set_assigned_device_name(orig_node->assigned_device_name());
 
   // Delete original node and mark new node as rewritten.
-  (*g)->RemoveNode(orign);
-  MarkRewrittenNode(newn);
+  (*g)->RemoveNode(orig_node);
+  MarkRewrittenNode(new_node);
 
-  VLOG(1) << "MklLayoutRewritePass: New node:" << newn->DebugString();
+  VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
   return Status::OK();
 }
 
 const MklLayoutRewritePass::ContextInfo*
-MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
+MklLayoutRewritePass::SearchMatchingContext(const Node* n,
+                                            const Node** fwd_node) {
   CHECK_NOTNULL(n);
-  CHECK_NOTNULL(fwdn);
-  *fwdn = nullptr;
+  CHECK_NOTNULL(fwd_node);
+  *fwd_node = nullptr;
 
   // Search for matching contextinfo based on node name.
   // There could be more than one matching contextinfos.
@@ -1171,7 +1681,7 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
     // If we find a match, we return immediately.
     for (const ContextInfo* ci : mci) {
       if (curr_node->type_string() == ci->fwd) {
-        *fwdn = curr_node;
+        *fwd_node = curr_node;
         return ci;
       }
     }
@@ -1192,8 +1702,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
 }
 
 bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
-  const Node* fwdn = nullptr;
-  return SearchMatchingContext(n, &fwdn) != nullptr;
+  const Node* fwd_node = nullptr;
+  return SearchMatchingContext(n, &fwd_node) != nullptr;
 }
 
 const MklLayoutRewritePass::RewriteInfo*
@@ -1208,7 +1718,8 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
   if (!GetNodeAttr(n->def(), "T", &T).ok()) {
     return nullptr;
   }
-  if (!mkl_layer_registry::IsMklLayer(GetMklOpName(n->type_string()), T)) {
+
+  if (!mkl_op_registry::IsMklOp(GetMklOpName(n->type_string()), T)) {
     return nullptr;
   }
 
@@ -1219,7 +1730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
 
   // Find matching RewriteInfo and then check that rewrite rule applies.
   for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
-    if (n->type_string().compare(ri->name) == 0 && ri->rewriterule(n)) {
+    if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
       return &*ri;
     }
   }
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 142d60d6112..6e72baf84e2 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -110,9 +110,11 @@ class MklLayoutPassTest : public ::testing::Test {
 };
 
 REGISTER_OP("Input").Output("o: float").SetIsStateful();
+REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
-REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
-REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
+REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
+REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
+REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
 
 /////////////////////////////////////////////////////////////////////
 //  Unit tests related to node merge optiimization
@@ -133,20 +135,22 @@ TEST_F(MklLayoutPassTest, Basic) {
 
 // Test set 1: Conv2D + AddBias
 
-// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y)
+// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
+// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
+  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
+      " input: ['A', 'B', 'M', 'N']}"
       "node { name: 'D' op: 'Input'}"
       "node { name: 'E' op: 'BiasAdd'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
@@ -157,26 +161,28 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
       " attr {key: 'T'                 value { type: DT_FLOAT } }"
       " input: ['E', 'Y']}");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
-            "M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
-            "DMT/_0->E:5;E->Z;M->E:1;N->E:3;Y->Z:1");
+            "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
+            "M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
+            "DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
 }
 
-// C=MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y)
+// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
+// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
 // Test for correct output slots selected
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
+  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput2'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput2'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'M' op: '_MklInput2'}"
+      "node { name: 'N' op: '_MklInput2'}"
+      "node { name: 'C' op: '_MklConv2D'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M:1', 'B', 'N:1']}"
+      " input: ['A', 'B', 'M:1', 'N:1']}"
       "node { name: 'D' op: 'Input'}"
       "node { name: 'E' op: 'BiasAdd'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
@@ -187,16 +193,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
       " attr {key: 'T'                 value { type: DT_FLOAT } }"
       " input: ['E', 'Y']}");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
-            "M(MklInput2);N(MklInput2);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
-            "DMT/_0->E:5;E->Z;M:1->E:1;N:1->E:3;Y->Z:1");
+            "A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
+            "M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
+            "DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
 }
 
 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
 // This is a case of node rewrite followed by node merge.
-// We will first rewrite Conv2D to MklConv2D, and then merge MklConv2D
-// with BiasAdd to produce MklConv2DWithBias.
+// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
+// with BiasAdd to produce _MklConv2DWithBias.
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
+  CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
   InitGraph(
       "node { name: 'A' op: 'Input'}"
       "node { name: 'B' op: 'Input'}"
@@ -218,70 +225,70 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
       " input: ['E', 'Y']}");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
             "A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
-            "DMT/_2(Const);E(MklConv2DWithBias);Y(Input);Z(Sub)|"
-            "A->E;B->E:2;D->E:4;DMT/_0->E:1;DMT/_1->E:3;DMT/_2->E:5;"
+            "DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
+            "A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
             "E->Z;Y->Z:1");
 }
 
-// Graph contains only MklConv2D, no AddBias.
+// Graph contains only _MklConv2D, no AddBias.
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}");
+      " input: ['A', 'B', 'M', 'N']}");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
-            "A->C;B->C:2;M->C:1;N->C:3");
+            "A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
+            "A->C;B->C:1;M->C:2;N->C:3");
 }
 
-// MklConv2D output does not go to BiasAdd.
+// _MklConv2D output does not go to BiasAdd.
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
+      " input: ['A', 'B', 'M', 'N']}"
       "node { name: 'D' op: 'Input'}"
       "node { name: 'E' op: 'Input'}"
       "node { name: 'F' op: 'BiasAdd'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D', 'E'] }");  // Output of MklConv2D does not go to BiasAdd.
+      " input: ['D', 'E'] }");  // Output of _MklConv2D does not go to BiasAdd.
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
-            "M(MklInput);N(MklInput)|A->C;B->C:2;D->F;E->F:1;M->C:1;N->C:3");
+            "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
+            "M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
 }
 
-// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
+// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
 // Merge should not be done in such case.
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
+      " input: ['A', 'B', 'M', 'N']}"
       "node { name: 'D' op: 'Input'}"
       "node { name: 'E' op: 'Input'}"
       "node { name: 'F' op: 'BiasAdd'"
@@ -293,9 +300,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " input: ['C', 'E'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
-            "G(Add);M(MklInput);N(MklInput)|A->C;B->C:2;C->G;D->F;"
-            "E->F:1;E->G:1;M->C:1;N->C:3");
+            "A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
+            "G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
+            "E->F:1;E->G:1;M->C:2;N->C:3");
 }
 
 // data_format attribute value mismatch. Merge should not be done
@@ -303,43 +310,81 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
 TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
+      " input: ['A', 'B', 'M', 'N']}"
       "node { name: 'D' op: 'Input'}"
       "node { name: 'E' op: 'BiasAdd'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NHCW' } }"
       " input: ['C', 'D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
-            "N(MklInput)|A->C;B->C:2;C->E;D->E:1;M->C:1;N->C:3");
+            "A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
+            "N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
 }
 
-// No MklConv2D in context, but Conv2D in context.
-// Only Conv2D would be rewritten to MklConv2D, but no rewrite
-// for BiasAddGrad should happen.
-// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
+// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
+// of BiasAddGrad into BackpropBias
+#if 0
+// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
+// rewrite tests
+
+// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
   InitGraph(
       "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
       "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'O' op: '_MklInput'}"
+      "node { name: 'D' op: '_MklConv2DWithBias'"
       " attr { key: 'T'                value { type: DT_FLOAT } }"
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
       " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
       " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
+      " input: ['A', 'B', 'C', 'M', 'N', 'O']}"
+      "node { name: 'E' op: 'Sub'"
+      " attr {key: 'T'                 value { type: DT_FLOAT } }"
+      " input: ['D', 'A']}"
+      "node { name: 'F' op: 'BiasAddGrad'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " input: ['E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
+            "E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
+            "O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
+            "M->D:3;N->D:4;O->D:5");
+}
+#endif
+
+// No _MklConv2D in context, but Conv2D in context.
+// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
+// for BiasAddGrad should happen.
+// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
+// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
+TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: '_MklInput'}"
+      "node { name: 'N' op: '_MklInput'}"
+      "node { name: 'C' op: '_MklConv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['A', 'B', 'M', 'N']}"
       "node { name: 'D' op: 'Sub'"
       " attr {key: 'T'                 value { type: DT_FLOAT } }"
       " input: ['C', 'A']}"
@@ -348,9 +393,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
       " attr { key: 'data_format'      value { s: 'NCHW' } }"
       " input: ['D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);D(Sub);E(BiasAddGrad);"
-            "M(MklInput);N(MklInput)|A->C;A->D:1;B->C:2;C->D;D->E;"
-            "M->C:1;N->C:3");
+            "A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);"
+            "M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
+            "M->C:2;N->C:3");
 }
 
 // No Conv2D in the context for BiasAddGrad. No rewrite should happen.
@@ -462,8 +507,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
       "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['B', 'C'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
-            "A->C;B->C:2;B->D;C->D:1;DMT/_0->C:1;DMT/_1->C:3");
+            "A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
+            "A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
 }
 
 // 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
@@ -489,9 +534,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
       "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:2;C->D:2;C->E;"
-            "C:1->D:3;D->E:1;DMT/_0->C:1;DMT/_1->C:3;DMT/_2->D:1");
+            "A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
+            "C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
 }
 
 // Conv2D with INT32 which is not supported by Mkl
@@ -513,10 +558,374 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
             "A->C;B->C:1;B->D;C->D:1");
 }
 
+// Concat Op test: Concat with no Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
+  InitGraph(
+      "node { name: 'A' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_INT32 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'B' op: 'InputList'"
+      " attr { key: 'N'                value { i: 2 } }}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Concat'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'N'                value { i: 2 } }"
+      " input: ['A', 'B']}"
+      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['C', 'D'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
+            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
+// Concat with 2 Mkl layers feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'Conv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['A', 'B']}"
+      "node { name: 'F' op: 'Conv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['C', 'D']}"
+      "node { name: 'G' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_INT32 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'H' op: 'Concat'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'N'                value { i: 2 } }"
+      " input: ['G', 'E', 'F']}"
+      "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'H'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
+            "F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
+            "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
+            "DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
+}
+
+// Concat with 1 Mkl and 1 non-Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'Conv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['A', 'B']}"
+      "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['C', 'D']}"
+      "node { name: 'G' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_INT32 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'H' op: 'Concat'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'N'                value { i: 2 } }"
+      " input: ['G', 'E', 'F']}"
+      "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'H'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
+            "H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+            "DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
+            "G->H;H->I:1");
+}
+
+#if 0
+// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
+  InitGraph(
+      "node { name: 'A' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_INT32 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'B' op: 'InputList'"
+      " attr { key: 'N'                value { i: 2 } }}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'ConcatV2'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
+      " attr { key: 'N'                value { i: 2 } }"
+      " input: ['B:0', 'B:1', 'A']}"
+      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['C', 'D'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
+            "D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+#endif
+
+// ConcatV2 with 2 Mkl layers feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'Conv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['A', 'B']}"
+      "node { name: 'F' op: 'Conv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['C', 'D']}"
+      "node { name: 'G' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_INT32 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'H' op: 'ConcatV2'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
+      " attr { key: 'N'                value { i: 2 } }"
+      " input: ['E', 'F', 'G']}"
+      "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'H'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
+            "F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
+            "D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
+            "DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
+}
+
+// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
+TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'Conv2D'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'      value { s: 'NCHW' } }"
+      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding'          value { s: 'SAME' } }"
+      " input: ['A', 'B']}"
+      "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['C', 'D']}"
+      "node { name: 'G' op: 'Const' "
+      " attr { key: 'dtype' value { type: DT_INT32 } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'H' op: 'ConcatV2'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " attr { key: 'Tidx'             value { type: DT_INT32 } }"
+      " attr { key: 'N'                value { i: 2 } }"
+      " input: ['E', 'F', 'G']}"
+      "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'H'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
+            "H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
+            "DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
+            "G->H:2;H->I:1");
+}
+
 /////////////////////////////////////////////////////////////////////
 //  Unit tests related to rewriting node for workspace edges
 /////////////////////////////////////////////////////////////////////
 
+/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
+TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'LRN'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'MaxPool'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
+      " attr { key: 'padding'      value { s: 'VALID' } }"
+      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
+      " input: ['B'] }"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'MaxPoolGrad'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'ksize'        value { list: {i: 1, i:1, i:3, i:3} } }"
+      " attr { key: 'padding'      value { s: 'VALID' } }"
+      " attr { key: 'strides'      value { list: {i: 1, i:1, i:2, i:2} } }"
+      " input: ['B', 'C', 'D'] }"
+      "node { name: 'F' op: 'Input'}"
+      "node { name: 'G' op: 'LRNGrad'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['E', 'F', 'B'] }"
+      "node { name: 'H' op: 'Input'}"
+      "node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['H', 'G'] }");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
+      "DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
+      "A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
+      "C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
+      "DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
+}
+
+/* Test LRN->LRNGrad replacement by workspace nodes. */
+TEST_F(MklLayoutPassTest, LRN_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'LRN'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'LRNGrad'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['C', 'D', 'B'] }"
+      "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['C', 'E'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
+            "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
+            "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
+}
+
+/* Test LRN->LRNGrad replacement when only one of them is present. */
+TEST_F(MklLayoutPassTest, LRN_Negative1) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'LRN'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
+            "A->B;A->C;B->C:1;DMT/_0->B:1");
+}
+
+/* Test LRN->LRNGrad replacement when only one of them is present. */
+TEST_F(MklLayoutPassTest, LRN_Negative2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'LRNGrad'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['A', 'B', 'C'] }"
+      "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'D'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
+            "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
+            "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
+}
+
+/* Test LRN->LRNGrad negative case, where single LRN feeds
+   2 LRNGrad nodes at different slots. */
+TEST_F(MklLayoutPassTest, LRN_Negative3) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'LRN'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['A'] }"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: 'Input'}"
+      "node { name: 'E' op: 'LRNGrad'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['C', 'D', 'B'] }"
+      "node { name: 'F' op: 'LRNGrad'"
+      " attr { key: 'T'            value { type: DT_FLOAT } }"
+      " attr { key: 'alpha'        value { f: 0.001 } }"
+      " attr { key: 'beta'         value { f: 0.75 } }"
+      " attr { key: 'bias'         value { f: 1.0 } }"
+      " attr { key: 'data_format'  value { s: 'NCHW' } }"
+      " attr { key: 'depth_radius' value { i: 2 } }"
+      " input: ['C', 'B', 'D'] }"
+      "node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['E', 'F'] }");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
+            "DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
+            "DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
+            "B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
+            "D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
+            "DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
+}
+
 /* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
   InitGraph(
@@ -540,10 +949,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
       "node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['C', 'E'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
-            "DMT/_1(Const);DMT/_2(Const);E(MklMaxPoolGrad);F(Mul)|"
-            "A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:4;"
-            "DMT/_0->B:1;DMT/_1->E:1;DMT/_2->E:5;E->F:1");
+            "A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
+            "DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
+            "A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
+            "DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
 }
 
 // Test MaxPool>MaxPoolGrad replacement when only one of them is present.
@@ -562,11 +971,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
       "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['A', 'B'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(MklMaxPool);C(Mul);DMT/_0(Const)|"
+            "A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
             "A->B;A->C;B->C:1;DMT/_0->B:1");
 }
 
-// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
+// Test MaxPoolGrad replacement when only one of them is present.
 // In this case, we will rewrite MaxPoolGrad and for workspace tensor and
 // its Mkl part, we will generate dummy tensor.
 TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
@@ -584,10 +993,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
       "node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
       " input: ['A', 'D'] }");
   EXPECT_EQ(DoMklLayoutOptimizationPass(),
-            "A(Input);B(Input);C(Input);D(MklMaxPoolGrad);DMT/_0(Const);"
+            "A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
             "DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
-            "A->D;A->E;B->D:2;C->D:4;D->E:1;DMT/_0->D:1;DMT/_1->D:3;"
-            "DMT/_2->D:5;DMT/_3->D:6;DMT/_4->D:7");
+            "A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
+            "DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
 }
 
 /////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/graph/mkl_optimizer_merge.cc b/tensorflow/core/graph/mkl_optimizer_merge.cc
deleted file mode 100644
index a171a27d8f5..00000000000
--- a/tensorflow/core/graph/mkl_optimizer_merge.cc
+++ /dev/null
@@ -1,651 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef INTEL_MKL
-// This module implements node merging optimization on the graph.
-// We process the nodes in the graph in reverse postorder
-// (i.e. inputs before their downstream dependencies).
-//
-#include <memory>
-#include <queue>
-#include <set>
-#include <string>
-#include <utility>
-#include <vector>
-
-#include "tensorflow/core/graph/mkl_optimizer_merge.h"
-
-#include "tensorflow/core/common_runtime/function.h"
-#include "tensorflow/core/common_runtime/optimization_registry.h"
-#include "tensorflow/core/framework/node_def_util.h"
-#include "tensorflow/core/graph/algorithm.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/hash/hash.h"
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-
-// How many hops do we search for matching node in the backward dataflow graph?
-// We use maxhop of 10 based on empirical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D)
-// directly goes to backward nodes, we do not expect the hop-distance
-// would be more than few nodes.
-static size_t kNodeMergeContextMaxDepth = 10;
-
-// This optimization pass performs two tasks: merge
-// nodes in the forward pass, and rewrite the gradient ops
-// corresponding to merged forward ops.
-//
-// Merging nodes in the graph: Currently, it merges Conv2D+AddBias together.
-//
-// Rewriting nodes in the graph: This is neded in order to optimize
-// gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and
-// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
-// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
-// This is context-specific optimization, where the context is the
-// forward operator that the BiasAddGrad corresponds to.
-class NodeMergeRewritePass : public GraphOptimizationPass {
- public:
-  NodeMergeRewritePass() {
-    csinfo_.conv2d = "MklConv2D";
-    csinfo_.conv2dwithbias = "MklConv2DWithBias";
-    csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
-    csinfo_.biasadd = "BiasAdd";
-    csinfo_.matmul = "MatMul";
-    csinfo_.biasaddgrad = "BiasAddGrad";
-
-    minfo_.push_back(
-        {csinfo_.conv2d, csinfo_.biasadd, 0, csinfo_.conv2dwithbias});
-
-// We use maxhop of 10 based on emperical observations. Also, these are
-// maxhops in backward data-flow graph. Since input of forward nodes
-// (Conv2D) directly goes to backward nodes, we do not expect the
-// hop-distance would be more than few nodes.
-// TODO(nhasabni) Temporarily disabling rewrite of BiasAddGrad.
-// Will enable it once we support Conv2DWithBiasBackpropBias op.
-#if 0
-    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
-                  {csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
-    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
-                  {csinfo_.conv2d, kNodeMergeContextMaxDepth}});
-    // For now, we are rewriting BiasAddGrad to BiasAddGrad for MatMul. This is
-    // because we do not have a separate Op for MatMulwithBias.
-    rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
-                      {csinfo_.matmul, kNodeMergeContextMaxDepth}});
-#endif
-  }
-
-  // Standard interface to run optimization pass
-  Status Run(const GraphOptimizationPassOptions& options);
-
-  // Helper function which does most of heavy lifting for node merge
-  //
-  // Extracts common functionality between Run public interface and
-  // test interface.
-  //
-  // @return true, if and only if graph is mutated; false otherwise.
-  bool RunPass(std::unique_ptr<Graph>* g);
-
- private:
-  /// Structure to specify information used in node merge
-  typedef struct {
-    string pred;     // Predecessor node string
-    string succ;     // Successor node string
-    int op;          // What operand no the predecessor node corresponds
-                     // to successor node?
-    string newnode;  // Name of the node after merge
-  } MergeInfo;
-
-  /// Structure to specify information used in node rewrite
-  typedef struct {
-    string node;     // Name of the node to be rewritten
-    string rewrite;  // New name of the node after rewrite
-    typedef struct {
-      string fwd;     // Node name in forward pass that this node
-                      // corresponds to
-      size_t maxhop;  // Maximum number of hops the mfwd_ is located
-                      // from this node. If mfwd_ is farther than mmaxhop_
-                      // then we do not rewrite the node.
-    } ContextInfo;
-    ContextInfo cinfo;  // Context for rewrite
-  } RewriteInfo;
-
-  /// Structure to store all constant strings
-  typedef struct {
-    string conv2d;
-    string conv2dwithbias;
-    string conv2dwithbiasbackpropbias;
-    string biasadd;
-    string matmul;
-    string biasaddgrad;
-  } ConstStringInfo;
-
-  ConstStringInfo csinfo_;
-  std::vector<MergeInfo> minfo_;
-  std::vector<RewriteInfo> rinfo_;
-
- private:
-  // Return a node that can be merged with input node
-  //
-  // @return pointer to the node if we can find such a
-  // node. Otherwise, it returns nullptr.
-  Node* FindNodeForMerge(const Node* a) const;
-
-  // Merge predecessor node with its successor.
-  // Currently, we merge Conv2D with AddBias only.
-  //
-  // Input nodes succ and pred may be deleted if the call to
-  // this function is successful. Attempt to use the pointers
-  // after the call to function may result is undefined behaviors.
-  //
-  // @input g - input graph, succ - successor node, pred - predecessor node
-  // @return Status::OK(), if merging is successful and supported.
-  //         Returns appropriate Status error code otherwise.
-  //         Graph is updated in case nodes are merged. Otherwise, it is
-  //         not updated.
-  Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);
-
-  // Is input node (n) a candidate for rewrite?
-  //
-  // @return true, if it can be rewritten; false, otherwise.
-  bool IsApplicableRewriteNode(const Node* n) const;
-
-  // Rewrites input node to a new node specified by its matching rewrite info.
-  //
-  // Method first searches matching rewrite info for input node and then
-  // uses that info to rewrite.
-  //
-  // Input node may be deleted in case of rewrite. Attempt to use the node
-  // after the call can result in undefined behaviors.
-  //
-  // @input  g - input graph, n - Node to be rewritten
-  // @return Status::OK(), if the input node is rewritten;
-  //         Returns appropriate Status error code otherwise.
-  //         Graph is updated in case the input node is rewritten.
-  //         Otherwise, it is not updated.
-  Status RewriteNode(std::unique_ptr<Graph>* g, Node* n);
-
-  // Helper function that searches the matching rewriteinfo for the node.
-  // Implements depth-first search in the data dependence graph for the
-  // gradient op in backward direction.
-  //
-  // @input n - Node (gradient op) whose rewriteinfo is to be searched,
-  //        fwdn - pointer to node from the forward pass that this node
-  //        belongs to
-  // @return Matching rewriteinfo in case a match is found; null otherwise.
-  const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
-                                             const Node** fwdn) const;
-
-  // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
-  // and return it in '*out'.
-  // TODO(nhasabni) We should move this to mkl_util.h
-  void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out);
-};
-
-// We register merge optimizer for phase 2 in pre-placement group.
-// Do not change the ordering of the Mkl passes.
-REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 2,
-                      NodeMergeRewritePass);
-
-static void FillInputs(const Node* n,
-                       gtl::InlinedVector<Node*, 4>* control_edges,
-                       gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
-  DCHECK_EQ(in->size(), n->num_inputs());
-  control_edges->clear();
-  for (const Edge* e : n->in_edges()) {
-    if (e->IsControlEdge()) {
-      control_edges->push_back(e->src());
-    } else {
-      (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
-    }
-  }
-  std::sort(control_edges->begin(), control_edges->end());
-  if (n->op_def().is_commutative()) {
-    // For commutative inputs, we sort the input by the input Node*
-    // to get a canonical ordering (so that add(a,b) and add(b, a) will
-    // hash to the same value if is_commutative is true for 'add').
-    std::sort(in->begin(), in->end());
-  }
-}
-
-Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
-  // Search for all matching mergeinfo.
-  // We allow more than one match for extensibility.
-  std::vector<const MergeInfo*> matching_mi;
-  for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
-    if (a->type_string() == mi->succ) {
-      matching_mi.push_back(&*mi);
-    }
-  }
-
-  for (const MergeInfo* mi : matching_mi) {
-    const int N_in = a->num_inputs();
-    if (mi->op >= N_in) {
-      continue;
-    }
-
-    // Get the control edges and input of node
-    gtl::InlinedVector<Node*, 4> a_control_edges;
-    gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
-    FillInputs(a, &a_control_edges, &a_in);
-
-    // Get operand op of the operator
-    Node* b = nullptr;
-    b = a_in[mi->op].first;
-    if (b == nullptr || (b->type_string() != mi->pred)) {
-      // NOTE: Should the first check be assert?
-      continue;
-    }
-
-    gtl::InlinedVector<Node*, 4> b_control_edges;
-    gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
-    FillInputs(b, &b_control_edges, &b_in);
-
-    // Shouldn't merge if a and b have different control edges.
-    if (a_control_edges != b_control_edges) {
-      continue;
-    } else {
-      // We found a match.
-      return b;
-    }
-  }
-
-  return nullptr;
-}
-
-void NodeMergeRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
-                                                 Node** out) {
-  const DataType dt = DataTypeToEnum<uint8>::v();
-  TensorProto proto;
-  proto.set_dtype(dt);
-  uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
-  proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
-                           8);
-  TensorShape dummy_shape({8});
-  dummy_shape.AsProto(proto.mutable_tensor_shape());
-  TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
-                  .Attr("value", proto)
-                  .Attr("dtype", dt)
-                  .Finalize(&**g, out));
-}
-
-Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
-                                       Node* pred) {
-  CHECK_NOTNULL(succ);
-  CHECK_NOTNULL(pred);
-
-  if (succ->type_string() == csinfo_.biasadd &&
-      pred->type_string() == csinfo_.conv2d) {
-    // 1. Get all attributes from input nodes.
-    DataType T_pred, T_succ;
-    string padding;
-    std::vector<int32> strides;
-    string data_format_pred, data_format_succ;
-    bool use_cudnn_on_gnu;
-    TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
-    TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
-    TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
-    TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
-    TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
-    TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
-    TF_CHECK_OK(
-        GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
-    // We check to ensure that data formats of both succ and pred are same.
-    // We expect them to be same, so we can enforce this as assert.
-    // But assert can be too strict, so we enforce this as a check.
-    // If the check fails, then we do not merge two nodes.
-    // We also do same check for devices.
-    if (data_format_pred != data_format_succ || T_pred != T_succ ||
-        pred->assigned_device_name() != succ->assigned_device_name() ||
-        pred->def().device() != succ->def().device()) {
-      return Status(error::Code::INVALID_ARGUMENT,
-                    "data_format or T attribute or devices of Conv2D and "
-                    "BiasAdd do not match. Will skip node merge optimization");
-    }
-
-    // 2. Get inputs from both the nodes.
-    // Find the 2 inputs from the conv and the bias from the add Bias.
-    Node* oper1 = nullptr;
-    Node* oper1_mkl = nullptr;  // Mkl tensor corresponding to oper1
-    Node* oper2 = nullptr;
-    Node* oper2_mkl = nullptr;  // Mkl tensor corresponding to oper2
-    Node* oper3 = nullptr;
-    Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
-
-    const int succ_num = succ->num_inputs();
-    gtl::InlinedVector<Node*, 4> succ_control_edges;
-    gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
-    FillInputs(succ, &succ_control_edges, &succ_in);
-
-    const int pred_num = pred->num_inputs();
-    gtl::InlinedVector<Node*, 4> pred_control_edges;
-    gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
-    FillInputs(pred, &pred_control_edges, &pred_in);
-
-    // We need to ensure that there is only 1 edge between Conv2D and AddBias.
-    // Otherwise, merging is semantically incorrect.
-    if (pred->out_edges().size() != 1) {
-      return Status(error::Code::INVALID_ARGUMENT,
-                    "Conv2D has multiple outputs."
-                    "Will skip node merge optimization");
-    }
-
-    for (const Edge* e : pred->out_edges()) {
-      if (e->dst() != succ) {
-        return Status(error::Code::INVALID_ARGUMENT,
-                      "Conv2D does not feed to BiasAdd."
-                      "Will skip node merge optimization");
-      }
-    }
-
-    // Get operand 0, 1 of conv2D and their Mkl tensors.
-    CHECK_EQ(pred->in_edges().size(), 4);  // MklConv2D must have 4 inputs.
-    oper1 = pred_in[0].first;
-    oper1_mkl = pred_in[1].first;
-    oper2 = pred_in[2].first;
-    oper2_mkl = pred_in[3].first;
-    // Get operand 1 of add_bias
-    // BiasAdd must have 2 inputs: Conv, bias
-    CHECK_EQ(succ->in_edges().size(), 2);
-    oper3 = succ_in[1].first;
-    GetDummyMklTensorNode(g, &oper3_mkl);  // Get dummy Mkl tensor node
-    // as BiasAdd does not have Mkl tensor as input.
-    CHECK_NOTNULL(oper3_mkl);
-
-    Node* ret;
-    // We will use the node name of BiasAdd as the name of new node
-    TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
-                    .Input(oper1)
-                    .Input(oper1_mkl)
-                    .Input(oper2)
-                    .Input(oper2_mkl)
-                    .Input(oper3)
-                    .Input(oper3_mkl)
-                    .Attr("T", T_pred)
-                    .Attr("strides", strides)
-                    .Attr("padding", padding)
-                    .Attr("data_format", data_format_pred)
-                    .Attr("use_cudnn_on_gpu", use_cudnn_on_gnu)
-                    .Device(succ->def().device())
-                    .Finalize(&**g, &ret));
-    CHECK_NOTNULL(ret);
-
-    // Incoming edges are fixed, we will fix the outgoing edges now.
-    for (const Edge* e : succ->out_edges()) {
-      (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
-    }
-
-    // Copy device assigned to old node to new node.
-    // It's ok to use pred or succ as we have enforced a check that
-    // both have same device assigned.
-    ret->set_assigned_device_name(pred->assigned_device_name());
-
-    VLOG(1) << "NodeMergeRewritePass: Merged old node:" << pred->DebugString()
-            << ", and node: " << succ->DebugString()
-            << ", into node:" << ret->DebugString();
-
-    (*g)->RemoveNode(succ);
-    (*g)->RemoveNode(pred);
-
-    return Status::OK();
-  }
-
-  return Status(error::Code::UNIMPLEMENTED,
-                "Unimplemented case for node merge optimization.");
-}
-
-Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* n) {
-  CHECK_NOTNULL(n);
-
-  // Get the matching rewriteinfo for the node
-  const Node* fwdn = nullptr;
-  const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
-  if (ri == nullptr || fwdn == nullptr) {
-    VLOG(2) << "NodeMergeRewritePass: Rewriteinfo not found for: "
-            << n->type_string();
-    return Status(error::Code::INVALID_ARGUMENT,
-                  "Rewrite info not found for the node."
-                  "Will skip node rewrite optimization");
-  }
-
-  VLOG(1) << "NodeMergeRewritePass: Rewrite called for: " << n->type_string();
-
-  if (n->type_string() == csinfo_.biasaddgrad &&
-      ri->node == csinfo_.biasaddgrad &&
-      (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias ||
-       ri->rewrite == csinfo_.biasaddgrad)) {
-    DataType T;
-    string data_format;
-    TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
-    TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format));
-
-    int n_num = n->num_inputs();  // this must be 1.
-    CHECK_EQ(n_num, 1);
-
-    gtl::InlinedVector<Node*, 4> n_control_edges;
-    gtl::InlinedVector<std::pair<Node*, int>, 4> n_in(n_num);
-    FillInputs(n, &n_control_edges, &n_in);
-
-    Node *ret = nullptr, *op = n_in[0].first;
-
-    if (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias) {
-      // Get strides info from Conv2D (node in the forward pass that this
-      // node corresponds to).
-      std::vector<int32> strides;
-      TF_CHECK_OK(GetNodeAttr(fwdn->def(), "strides", &strides));
-
-      // We use same name as original node name as there may be fetchoutputs
-      // associated with it.
-      TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
-                      .Input(op)
-                      .Attr("T", T)
-                      .Attr("data_format", data_format)
-                      .Attr("strides", strides)
-                      .Device(n->def().device())
-                      .Finalize(&**g, &ret));
-    } else {
-      CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad);
-      TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
-                      .Input(op)
-                      .Attr("T", T)
-                      .Attr("data_format", data_format)
-                      .Device(n->def().device())
-                      .Finalize(&**g, &ret));
-    }
-
-    CHECK_NOTNULL(ret);
-
-    // Incoming edges are fixed, we will fix the outgoing edges now.
-    for (const Edge* e : n->out_edges()) {
-      (*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
-    }
-
-    // Copy device assigned to old node to new node.
-    ret->set_assigned_device_name(n->assigned_device_name());
-
-    VLOG(1) << "MKLOptimizerMergePass: Rewrote old node:" << n->DebugString()
-            << ", into node:" << ret->DebugString();
-    (*g)->RemoveNode(n);
-
-    return Status::OK();
-  }
-
-  return Status(error::Code::UNIMPLEMENTED,
-                "Unimplemented case for node rewrite optimization.");
-}
-
-const NodeMergeRewritePass::RewriteInfo*
-NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
-                                              const Node** fwdn) const {
-  CHECK_NOTNULL(n);
-  CHECK_NOTNULL(fwdn);
-  *fwdn = nullptr;
-
-  // Search for matching rewriteinfo based on node name.
-  // There could be more than one matching rewriteinfos.
-  std::vector<const RewriteInfo*> matching_ri;
-  for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
-    if (n->type_string() == ri->node) {
-      matching_ri.push_back(&*ri);
-    }
-  }
-
-  VLOG(1) << "NodeMergeRewritePass: Searching graph for: " << n->type_string()
-          << " in backwards.";
-
-  // Now we will check for forward op name for rewrite info in data
-  // flow graph. Get the max hops we should search for the fwd node
-  // We are now going to search (breadth-first) backwards in data
-  // dependence graph (for up to max hops) from n for the node
-  // specified in fwd.
-  // queue to maintain nodes to be visited and depth info for
-  // breadth-first search
-  std::queue<std::pair<const Node*, int>> nqueue;
-  const Node* curr_node = n;
-  size_t curr_depth = 0;
-  nqueue.push(std::make_pair(curr_node, curr_depth));
-
-  while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
-    std::pair<const Node*, int> curr_pair = nqueue.front();
-    nqueue.pop();
-
-    std::set<const Node*> visited_nodes;
-    curr_node = curr_pair.first;
-    curr_depth = curr_pair.second;
-    CHECK_NOTNULL(curr_node);
-
-    VLOG(1) << "NodeMergeRewritePass: Visiting node: "
-            << curr_node->type_string() << " at depth: " << curr_depth
-            << " for node: " << n->type_string();
-
-    // If we find a match, we return immediately with the matching rewrite
-    // info.
-    for (const RewriteInfo* ri : matching_ri) {
-      if (curr_node->type_string() == ri->cinfo.fwd) {
-        *fwdn = curr_node;
-        return ri;
-      }
-    }
-
-    // Else we explore backward edges from current node.
-    // Add the source nodes of all incoming edges of the node to the queue.
-    for (const Edge* e : curr_node->in_edges()) {
-      // We do not visit already visited node.
-      if (visited_nodes.find(e->src()) == visited_nodes.end()) {
-        // Depth of these nodes is 1 more than the depth of current node.
-        nqueue.push(std::make_pair(e->src(), curr_depth + 1));
-        visited_nodes.insert(e->src());
-      }
-    }
-  } /* while */
-
-  return nullptr;
-}
-
-bool NodeMergeRewritePass::IsApplicableRewriteNode(const Node* n) const {
-  CHECK_NOTNULL(n);
-
-  // Search for matching rewriteinfo
-  // Even if we find one match, we return true.
-  bool match_found = false;
-  for (const RewriteInfo& ri : rinfo_) {
-    if (n->type_string() == ri.node) {
-      match_found = true;
-      break;
-    }
-  }
-
-  return match_found;
-}
-
-bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
-  bool result = false;
-  CHECK_NOTNULL(g);
-
-  DumpGraph("Before OptimizeMerge", &**g);
-
-  std::vector<Node*> order;
-  GetReversePostOrder(**g, &order);
-  std::vector<std::pair<Node*, Node*>> nodes_to_be_merged;
-  std::vector<Node*> nodes_to_be_rewritten;
-
-  for (Node* n : order) {
-    if (!n->IsOp()) continue;
-    Node* n1 = nullptr;
-    if ((n1 = FindNodeForMerge(n)) != nullptr) {
-      VLOG(1) << "NodeMergeRewritePass: Scheduled nodes " << n->name()
-              << " and " << n1->name() << " for merging";
-      nodes_to_be_merged.push_back(std::make_pair(n, n1));
-    } else if (IsApplicableRewriteNode(n)) {
-      VLOG(1) << "NodeMergeRewritePass: Scheduled node " << n->name()
-              << " for rewrite";
-      nodes_to_be_rewritten.push_back(n);
-    }
-  }
-
-  for (std::pair<Node*, Node*> i : nodes_to_be_merged) {
-    // Even if MergeNode merges single pair of nodes, we
-    // need to return true.
-    string n1_name = i.first->name();
-    string n2_name = i.second->name();
-    if (MergeNode(g, i.first, i.second) == Status::OK()) {
-      VLOG(1) << "NodeMergeRewritePass: Merged nodes " << n1_name << " and "
-              << n2_name;
-      result = true;
-    }
-  }
-
-  DumpGraph("After OptimizeMerge(nodemerge)", &**g);
-
-  for (Node* i : nodes_to_be_rewritten) {
-    string name = i->name();
-    if (RewriteNode(g, i) == Status::OK()) {
-      VLOG(1) << "NodeMergeRewritePass: Rewrite node: " << name
-              << " successful.";
-      result = true;
-    }
-  }
-
-  DumpGraph("After OptimizeMerge(noderewrite)", &**g);
-
-  return result;
-}
-
-bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
-  return NodeMergeRewritePass().RunPass(g);
-}
-
-Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
-  if (options.graph == nullptr) {
-    return Status::OK();
-  }
-
-  // Get the ownership of graph
-  std::unique_ptr<Graph>* g = std::move(options.graph);
-
-  RunPass(g);
-
-  // Return the ownership of graph back
-  options.graph->reset(g->release());
-
-  return Status::OK();
-}
-
-}  // namespace tensorflow
-
-#endif
diff --git a/tensorflow/core/graph/mkl_optimizer_merge.h b/tensorflow/core/graph/mkl_optimizer_merge.h
deleted file mode 100644
index b2caec58aff..00000000000
--- a/tensorflow/core/graph/mkl_optimizer_merge.h
+++ /dev/null
@@ -1,36 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// An optimization pass that performs node merging and rewrite on graph nodes
-
-#ifndef TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
-#define TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
-
-#ifdef INTEL_MKL
-
-#include <sys/types.h>
-#include <memory>
-#include "tensorflow/core/graph/graph.h"
-
-namespace tensorflow {
-// Interface to invoke the pass for unit test
-//
-// Returns true if and only if 'g' is mutated.
-extern bool OptimizeNodeMerge(std::unique_ptr<Graph>* g);
-}  // namespace tensorflow
-
-#endif  // INTEL_MKL
-
-#endif  // TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
diff --git a/tensorflow/core/graph/mkl_optimizer_merge_test.cc b/tensorflow/core/graph/mkl_optimizer_merge_test.cc
deleted file mode 100644
index f752721d6e0..00000000000
--- a/tensorflow/core/graph/mkl_optimizer_merge_test.cc
+++ /dev/null
@@ -1,470 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifdef INTEL_MKL
-
-#include "tensorflow/core/graph/mkl_optimizer_merge.h"
-
-#include <vector>
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/testlib.h"
-#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/lib/random/simple_philox.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-
-namespace tensorflow {
-namespace {
-
-class OptimizerMergeTest : public ::testing::Test {
- public:
-  OptimizerMergeTest() : graph_(OpRegistry::Global()) {}
-
-  static void InitGraph(const string& s, Graph* graph) {
-    GraphDef graph_def;
-
-    auto parser = protobuf::TextFormat::Parser();
-    CHECK(parser.MergeFromString(s, &graph_def)) << s;
-    GraphConstructorOptions opts;
-    TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
-  }
-
-  void InitGraph(const string& s) {
-    InitGraph(s, &graph_);
-    original_ = CanonicalGraphString(&graph_);
-  }
-
-  static bool IncludeNode(const Node* n) { return n->IsOp(); }
-
-  static string EdgeId(const Node* n, int index) {
-    if (index == 0) {
-      return n->name();
-    } else if (index == Graph::kControlSlot) {
-      return strings::StrCat(n->name(), ":control");
-    } else {
-      return strings::StrCat(n->name(), ":", index);
-    }
-  }
-
-  string CanonicalGraphString(Graph* g) {
-    std::vector<string> nodes;
-    std::vector<string> edges;
-    for (const Node* n : g->nodes()) {
-      if (IncludeNode(n)) {
-        nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
-      }
-    }
-    for (const Edge* e : g->edges()) {
-      if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
-        edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
-                                        EdgeId(e->dst(), e->dst_input())));
-      }
-    }
-    // Canonicalize
-    std::sort(nodes.begin(), nodes.end());
-    std::sort(edges.begin(), edges.end());
-    return strings::StrCat(str_util::Join(nodes, ";"), "|",
-                           str_util::Join(edges, ";"));
-  }
-
-  string DoNodeMerge() {
-    string before = CanonicalGraphString(&graph_);
-    LOG(ERROR) << "Before node merge optimize: " << before;
-
-    std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
-    OptimizeNodeMerge(ug);
-
-    string result = CanonicalGraphString(&graph_);
-    LOG(ERROR) << "After node merge optimize:  " << result;
-    return result;
-  }
-
-  const string& OriginalGraph() const { return original_; }
-
-  Graph graph_;
-  string original_;
-};
-
-REGISTER_OP("Input").Output("o: float").SetIsStateful();
-REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
-
-TEST_F(OptimizerMergeTest, Basic) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }"
-      "node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
-      " input: ['A', 'B'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(Mul);D(Mul)|"
-            "A->C;A->D;B->C:1;B->D:1");
-}
-
-// Test set 1: Conv2D + AddBias
-
-// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y)
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Positive) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'Input'}"
-      "node { name: 'E' op: 'BiasAdd'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['C', 'D'] }"
-      "node { name: 'Y' op: 'Input'}"
-      "node { name: 'Z' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['E', 'Y']}");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
-            "M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
-            "DMT/_0->E:5;E->Z;M->E:1;N->E:3;Y->Z:1");
-}
-
-// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
-// We do not merge in this case as op is Conv2D and not MklConv2D.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoMklConv2D) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'Conv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'B']}"
-      "node { name: 'D' op: 'Input'}"
-      "node { name: 'E' op: 'BiasAdd'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['C', 'D'] }"
-      "node { name: 'Y' op: 'Input'}"
-      "node { name: 'Z' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['E', 'Y']}");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|"
-            "A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
-}
-
-// Graph contains only MklConv2D, no AddBias.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoAddBias) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
-            "A->C;B->C:2;M->C:1;N->C:3");
-}
-
-// MklConv2D output does not go to BiasAdd.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow1) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'Input'}"
-      "node { name: 'E' op: 'Input'}"
-      "node { name: 'F' op: 'BiasAdd'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D', 'E'] }");  // Output of MklConv2D does not go to BiasAdd.
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
-            "M(MklInput);N(MklInput)|A->C;B->C:2;D->F;E->F:1;M->C:1;N->C:3");
-}
-
-// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
-// Merge should not be done in such case.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow2) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'Input'}"
-      "node { name: 'E' op: 'Input'}"
-      "node { name: 'F' op: 'BiasAdd'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D', 'E'] }"  // Conv2D has two outputs.
-                              // No merge should happen.
-      "node { name: 'G' op: 'Add'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " input: ['C', 'E'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
-            "G(Add);M(MklInput);N(MklInput)|A->C;B->C:2;C->G;D->F;"
-            "E->F:1;E->G:1;M->C:1;N->C:3");
-}
-
-// data_format attribute value mismatch. Merge should not be done
-// in such case.
-TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_AttrMismatch) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'Input'}"
-      "node { name: 'E' op: 'BiasAdd'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NHCW' } }"
-      " input: ['C', 'D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
-            "N(MklInput)|A->C;B->C:2;C->E;D->E:1;M->C:1;N->C:3");
-}
-
-#if 0
-// This test set is disabled temporarily as we do not enable node rewrite.
-// This test set will be enabled when we support Mkl-specific kernels for
-// backward bias.
-//
-// Test set 2: MklConv2D..BiasAddGrad -> Conv2DWithBiasBackpropBias
-// rewrite tests
-
-// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Positive) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'A']}"
-      "node { name: 'E' op: 'BiasAddGrad'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MklConv2D);D(Sub);E(Conv2DWithBiasBackpropBias);"
-            "M(MklInput);N(MklInput)|A->C;A->D:1;B->C:2;C->D;D->E;M->C:1;N->C:3");
-}
-
-// No MklConv2D in context, but Conv2D in context. No rewrite should happen.
-// C=Conv2D(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoMklConv2D) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'Conv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'B']}"
-      "node { name: 'D' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'A']}"
-      "node { name: 'E' op: 'BiasAddGrad'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(Conv2D);D(Sub);E(BiasAddGrad)|"
-             "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
-// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'Add'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " input: ['A', 'B']}"
-      "node { name: 'D' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'A']}"
-      "node { name: 'E' op: 'BiasAddGrad'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
-             "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No Conv2D in the context for BiasAddGrad, but MatMul in context.
-// Rewrite should happen, but name of BiasAddGrad does not change.
-// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D_MatMul) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'MatMul'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'transpose_a'      value { b: false } }"
-      " attr { key: 'transpose_b'      value { b: false } }"
-      " input: ['A', 'B']}"
-      "node { name: 'D' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'A']}"
-      "node { name: 'E' op: 'BiasAddGrad'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
-             "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
-// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Positive) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'MatMul'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'transpose_a'      value { b: false } }"
-      " attr { key: 'transpose_b'      value { b: false } }"
-      " input: ['A', 'B']}"
-      "node { name: 'D' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'A']}"
-      "node { name: 'E' op: 'BiasAddGrad'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
-             "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-
-// No MatMul in the context for BiasAddGrad. No rewrite should happen.
-// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
-TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Negative_NoMatMul) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'C' op: 'Add'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " input: ['A', 'B']}"
-      "node { name: 'D' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'A']}"
-      "node { name: 'E' op: 'BiasAddGrad'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['D'] }");
-  EXPECT_EQ(DoNodeMerge(),
-            "A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
-             "A->C;A->D:1;B->C:1;C->D;D->E");
-}
-#endif
-
-static void BM_NodeMerge(int iters, int op_nodes) {
-  testing::StopTiming();
-  string s;
-  for (int in = 0; in < 10; in++) {
-    s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
-  }
-  random::PhiloxRandom philox(301, 17);
-  random::SimplePhilox rnd(&philox);
-  for (int op = 0; op < op_nodes; op++) {
-    s += strings::Printf(
-        "node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
-        "type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
-        op, rnd.Uniform(10), rnd.Uniform(10));
-  }
-
-  bool first = true;
-  while (iters > 0) {
-    Graph* graph = new Graph(OpRegistry::Global());
-    OptimizerMergeTest::InitGraph(s, graph);
-    int N = graph->num_node_ids();
-    if (first) {
-      testing::SetLabel(strings::StrCat("Per graph node.  Nodes: ", N));
-      first = false;
-    }
-    {
-      testing::StartTiming();
-      std::unique_ptr<Graph> ug(graph);
-      OptimizeNodeMerge(&ug);
-      testing::StopTiming();
-    }
-    iters -= N;  // Our benchmark units are individual graph nodes,
-                 // not whole graphs
-    // delete graph;
-  }
-}
-BENCHMARK(BM_NodeMerge)->Arg(1000)->Arg(10000);
-
-}  // namespace
-}  // namespace tensorflow
-
-#endif /* INTEL_MKL */
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index 7c3836b3089..55c280719c3 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -40,16 +40,16 @@ namespace tensorflow {
 
 // This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
 // in the graph in between A and B, where A and B match any one
-// of the following
-// cases:
-//  1) A = layer/Op that generates output in Mkl format and,
-//     B = layer/Op that does not accept input in Mkl format and,
+// of the following cases:
+//
+//  1) A = a node that generates output in the Mkl format and,
+//     B = a node that does not accept input in the Mkl format and,
 //     A -> B (there is a direct edge between A and B, then
 //     We will insert C such that A->C->B.
 //
-//  2) A = layer/Op that generates output in Mkl format and,
-//     B = NULL (in other words, A is the last layer in the graph), then
-//     We will insert C such that A->C->B. (C will be the last layer.)
+//  2) A = a node that generates output in the Mkl format and,
+//     B = NULL (in other words, A is the last node in the graph), then
+//     We will insert C such that A->C->B. (C will be the last node.)
 //
 //  Note that case 1 applies to all outputs of A that are input to B.
 //  In other words, the conversions will be required for every output
@@ -59,9 +59,9 @@ namespace tensorflow {
 //  do the conversion for A1 and A2 only. We do not need to do any conversion
 //  for A3.
 //
-// This pass relies on layers registering themselves about their Mkl compliant.
-// Mkl compliant layer can accept inputs in Mkl format, and produce output in
-// Mkl format. Non-compliant layer accepts inputs and outputs in
+// This pass relies on ops registering themselves about their Mkl compliance.
+// An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs
+// in the Mkl format. Non-compliant ops accept inputs and outputs in the
 // TensorFlow format.
 //
 class MklToTfConversionPass : public GraphOptimizationPass {
@@ -84,7 +84,7 @@ class MklToTfConversionPass : public GraphOptimizationPass {
   // @input T Datatype to use for checking input op
   // @return true if op is Mkl supported; false, otherwise.
   inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
-    return mkl_layer_registry::IsMklLayer(op_name, T);
+    return mkl_op_registry::IsMklOp(op_name, T);
   }
 
   // Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
@@ -129,14 +129,16 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
     return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
   }
 
-  // Lets build the conversion node and specify src as input.
+  // Build the conversion node and specify src as input.
   TF_CHECK_OK(
-      NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf")
+      NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
           .Input(src, e->src_output())
-          .Input(src, e->src_output() + 1)  // Mkl tensor immediately
-                                            // follows Tf tensor.
-          .Device(src->def().device())      // We want to get conversion node
-                                            // on same device as source node.
+          .Input(src, DataIndexToMetaDataIndex(
+                          e->src_output(),
+                          src->num_outputs()))  // Get an Mkl tensor slot
+                                                // from the Tf tensor slot.
+          .Device(src->def().device())  // We want to get conversion node
+                                        // on same device as source node.
           .Attr("T", src_datatype)
           .Finalize(&**g, &conversion_node));
 
@@ -149,8 +151,8 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
   // We want conversion node to be on the same device as the source node.
   conversion_node->set_assigned_device_name(src->assigned_device_name());
 
-  // Set the Mkl layer label for this op.
-  conversion_node->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
+  // Set the Mkl op label for this op.
+  conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
 
   // Now that we have added edge from src->conversion_node, let's add edge from
   // output of conversion_node to the dest node. Since conversion_node
@@ -173,11 +175,11 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
 
   DumpGraph("Before MklToTfConversionPass", &**g);
 
-  // Since we are looking for mkl-supported op node immediately
-  // followed by non-mkl op node, we will just iterate over edge
+  // Since we are looking for an Mkl-supported op node immediately
+  // followed by a non-Mkl op node, we will just iterate over edge
   // set of the graph.
-  // vector to maintain candiadate edges whose source and destination
-  // are candidate for inserting conversion node
+  // edge set whose source and destination are candidates for
+  // inserting conversion node
   std::vector<Edge*> candidate_edges;
 
   for (const Edge* e : (*g)->edges()) {
@@ -190,9 +192,9 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
     }
 
     // We skip adding MklToTf on an edge between X->MklToTf or
-    // MklToTf->X, where X is any layer.
-    if (src->type_string().compare("MklToTf") == 0 ||
-        dst->type_string().compare("MklToTf") == 0) {
+    // MklToTf->X, where X is any node.
+    if (src->type_string().compare("_MklToTf") == 0 ||
+        dst->type_string().compare("_MklToTf") == 0) {
       continue;
     }
 
@@ -210,7 +212,6 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
     GetNodeAttr(dst->def(), "T", &dst_datatype);
 
     // Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
-
     if (IsMklSupportedOp(src->type_string(), src_datatype) &&
         !IsMklSupportedOp(dst->type_string(), dst_datatype)) {
       VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index 7d9237f8454..bd2cb0989c1 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #ifdef INTEL_MKL
 
 #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
+#include "tensorflow/core/util/mkl_util.h"
 
 #include <algorithm>
 #include <string>
@@ -109,7 +110,7 @@ class MklToTfConversionPass : public ::testing::Test {
 
 REGISTER_OP("Input").Output("o: float").SetIsStateful();
 REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
-REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
+REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
 
 TEST_F(MklToTfConversionPass, Basic) {
   InitGraph(
@@ -125,58 +126,116 @@ TEST_F(MklToTfConversionPass, Basic) {
 }
 
 // MklConv2D followed by Non-Mkl layer
-// C=MklConv2D(A,M,B,N); E=Sub(C,D)
+// C=MklConv2D(A,M,B,N); E=Sub(C,D) (for interleaved ordering)
+// C=MklConv2D(A,B,M,N); E=Sub(C,D) (for contiguous ordering)
 TEST_F(MklToTfConversionPass, Positive) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'Input'}"
-      "node { name: 'E' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['C', 'D']}");
-  EXPECT_EQ(DoRunMklToTfConversionPass(),
-            "A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
-            "Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
-            "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
+  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+    InitGraph(
+        "node { name: 'A' op: 'Input'}"
+        "node { name: 'M' op: '_MklInput'}"
+        "node { name: 'B' op: 'Input'}"
+        "node { name: 'N' op: '_MklInput'}"
+        "node { name: 'C' op: '_MklConv2D'"
+        " attr { key: 'T'                value { type: DT_FLOAT } }"
+        " attr { key: 'data_format'      value { s: 'NCHW' } }"
+        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+        " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
+        "}"
+        " attr { key: 'padding'          value { s: 'SAME' } }"
+        " input: ['A', 'M', 'B', 'N']}"
+        "node { name: 'D' op: 'Input'}"
+        "node { name: 'E' op: 'Sub'"
+        " attr {key: 'T'                 value { type: DT_FLOAT } }"
+        " input: ['C', 'D']}");
+    EXPECT_EQ(DoRunMklToTfConversionPass(),
+              "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
+              "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
+              "C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
+  } else {
+    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+    InitGraph(
+        "node { name: 'A' op: 'Input'}"
+        "node { name: 'B' op: 'Input'}"
+        "node { name: 'M' op: '_MklInput'}"
+        "node { name: 'N' op: '_MklInput'}"
+        "node { name: 'C' op: '_MklConv2D'"
+        " attr { key: 'T'                value { type: DT_FLOAT } }"
+        " attr { key: 'data_format'      value { s: 'NCHW' } }"
+        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+        " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
+        "}"
+        " attr { key: 'padding'          value { s: 'SAME' } }"
+        " input: ['A', 'B', 'M', 'N']}"
+        "node { name: 'D' op: 'Input'}"
+        "node { name: 'E' op: 'Sub'"
+        " attr {key: 'T'                 value { type: DT_FLOAT } }"
+        " input: ['C', 'D']}");
+    EXPECT_EQ(DoRunMklToTfConversionPass(),
+              "A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
+              "_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
+              "C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
+  }
 }
 
 // MklConv2D followed by MklToTf op followed by Non-Mkl layer.
-// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E)
+// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved)
+// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for contiguous)
 // MklToTf node should not be inserted again.
 TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
-  InitGraph(
-      "node { name: 'A' op: 'Input'}"
-      "node { name: 'M' op: 'MklInput'}"
-      "node { name: 'B' op: 'Input'}"
-      "node { name: 'N' op: 'MklInput'}"
-      "node { name: 'C' op: 'MklConv2D'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
-      " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } }"
-      " attr { key: 'padding'          value { s: 'SAME' } }"
-      " input: ['A', 'M', 'B', 'N']}"
-      "node { name: 'D' op: 'MklToTf'"
-      " attr { key: 'T'                value { type: DT_FLOAT } }"
-      " attr { key: 'data_format'      value { s: 'NCHW' } }"
-      " input: ['C:0', 'C:1']}"
-      "node { name: 'E' op: 'Input'}"
-      "node { name: 'F' op: 'Sub'"
-      " attr {key: 'T'                 value { type: DT_FLOAT } }"
-      " input: ['D', 'E']}");
-  EXPECT_EQ(DoRunMklToTfConversionPass(),
-            "A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
-            "F(Sub);M(MklInput);N(MklInput)|"
-            "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
+  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+    InitGraph(
+        "node { name: 'A' op: 'Input'}"
+        "node { name: 'M' op: '_MklInput'}"
+        "node { name: 'B' op: 'Input'}"
+        "node { name: 'N' op: '_MklInput'}"
+        "node { name: 'C' op: '_MklConv2D'"
+        " attr { key: 'T'                value { type: DT_FLOAT } }"
+        " attr { key: 'data_format'      value { s: 'NCHW' } }"
+        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+        " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
+        "}"
+        " attr { key: 'padding'          value { s: 'SAME' } }"
+        " input: ['A', 'M', 'B', 'N']}"
+        "node { name: 'D' op: '_MklToTf'"
+        " attr { key: 'T'                value { type: DT_FLOAT } }"
+        " attr { key: 'data_format'      value { s: 'NCHW' } }"
+        " input: ['C:0', 'C:1']}"
+        "node { name: 'E' op: 'Input'}"
+        "node { name: 'F' op: 'Sub'"
+        " attr {key: 'T'                 value { type: DT_FLOAT } }"
+        " input: ['D', 'E']}");
+    EXPECT_EQ(DoRunMklToTfConversionPass(),
+              "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
+              "F(Sub);M(_MklInput);N(_MklInput)|"
+              "A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
+  } else {
+    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+    InitGraph(
+        "node { name: 'A' op: 'Input'}"
+        "node { name: 'B' op: 'Input'}"
+        "node { name: 'M' op: '_MklInput'}"
+        "node { name: 'N' op: '_MklInput'}"
+        "node { name: 'C' op: '_MklConv2D'"
+        " attr { key: 'T'                value { type: DT_FLOAT } }"
+        " attr { key: 'data_format'      value { s: 'NCHW' } }"
+        " attr { key: 'use_cudnn_on_gpu' value { b: false } }"
+        " attr { key: 'strides'          value { list: {i: 1, i:1, i:1, i:1} } "
+        "}"
+        " attr { key: 'padding'          value { s: 'SAME' } }"
+        " input: ['A', 'B', 'M', 'N']}"
+        "node { name: 'D' op: '_MklToTf'"
+        " attr { key: 'T'                value { type: DT_FLOAT } }"
+        " attr { key: 'data_format'      value { s: 'NCHW' } }"
+        " input: ['C:0', 'C:1']}"
+        "node { name: 'E' op: 'Input'}"
+        "node { name: 'F' op: 'Sub'"
+        " attr {key: 'T'                 value { type: DT_FLOAT } }"
+        " input: ['D', 'E']}");
+    EXPECT_EQ(DoRunMklToTfConversionPass(),
+              "A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
+              "F(Sub);M(_MklInput);N(_MklInput)|"
+              "A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
+  }
 }
 
 // C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc
index b5497d35947..078fb10bc95 100644
--- a/tensorflow/core/grappler/optimizers/auto_parallel.cc
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc
@@ -230,7 +230,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
     AddOneReplica(graph, i);
   }
   std::set<string> fetches;
-  for (int i = 0; i < item_->fetch.size(); i++) {
+  for (size_t i = 0; i < item_->fetch.size(); i++) {
     for (int j = 0; j < num_replicas_; j++) {
       string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
       string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e32f51a3a2a..49b12df7aa9 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4828,6 +4828,38 @@ tf_mkl_kernel_library(
     ],
 )
 
+tf_mkl_kernel_library(
+    name = "mkl_fused_batch_norm_op",
+    srcs = ["mkl_fused_batch_norm_op.cc"],
+    deps = NN_DEPS + [
+        "//third_party/mkl:intel_binary_blob",
+    ],
+)
+
+tf_mkl_kernel_library(
+    name = "mkl_concat_op",
+    prefix = "mkl_concat_op",
+    deps = ARRAY_DEPS + [
+        "//third_party/mkl:intel_binary_blob",
+    ],
+)
+
+tf_mkl_kernel_library(
+    name = "mkl_reshape_op",
+    prefix = "mkl_reshape_op",
+    deps = ARRAY_DEPS + [
+        "//third_party/mkl:intel_binary_blob",
+    ],
+)
+
+tf_mkl_kernel_library(
+    name = "mkl_lrn_op",
+    prefix = "mkl_lrn_op",
+    deps = NN_DEPS + [
+        "//third_party/mkl:intel_binary_blob",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 # Google-internal targets.  These must be at the end for syncrepo.
 
diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc
index 637a6cef95d..ce7fb9c332b 100644
--- a/tensorflow/core/kernels/fixed_length_record_reader_op.cc
+++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc
@@ -28,12 +28,14 @@ namespace tensorflow {
 class FixedLengthRecordReader : public ReaderBase {
  public:
   FixedLengthRecordReader(const string& node_name, int64 header_bytes,
-                          int64 record_bytes, int64 footer_bytes, Env* env)
+                          int64 record_bytes, int64 footer_bytes,
+                          int64 hop_bytes, Env* env)
       : ReaderBase(
             strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
         header_bytes_(header_bytes),
         record_bytes_(record_bytes),
         footer_bytes_(footer_bytes),
+        hop_bytes_(hop_bytes),
         env_(env),
         file_pos_limit_(-1),
         record_number_(0) {}
@@ -62,14 +64,31 @@ class FixedLengthRecordReader : public ReaderBase {
 
   Status ReadLocked(string* key, string* value, bool* produced,
                     bool* at_end) override {
-    if (input_buffer_->Tell() >= file_pos_limit_) {
+    // The condition `input_buffer_->Tell() + record_bytes_ > file_pos_limit_`
+    // is to confirm that none of record bytes is out of the range of
+    // file_pos_limit_.
+    // This is necessary for the condition `hop_bytes > 0`. For example.
+    // File: "0123456"
+    // Reader setting: `record_bytes=3`, `hop_bytes=2`, `footer_bytes=0`,
+    //     `header_bytes=0`
+    // Without this checking condition, the forth time the reader will at
+    // this position: "012345|6" and the reading operation will result in
+    // an error.
+    if (input_buffer_->Tell() >= file_pos_limit_ ||
+        input_buffer_->Tell() + record_bytes_ > file_pos_limit_) {
       *at_end = true;
       return Status::OK();
     }
+    const int64 pos_before_read = input_buffer_->Tell();
     TF_RETURN_IF_ERROR(input_buffer_->ReadNBytes(record_bytes_, value));
     *key = strings::StrCat(current_work(), ":", record_number_);
     *produced = true;
     ++record_number_;
+
+    if (hop_bytes_ > 0) {
+      input_buffer_->Seek(pos_before_read + hop_bytes_).IgnoreError();
+    }
+
     return Status::OK();
   }
 
@@ -87,6 +106,7 @@ class FixedLengthRecordReader : public ReaderBase {
   const int64 header_bytes_;
   const int64 record_bytes_;
   const int64 footer_bytes_;
+  const int64 hop_bytes_;
   Env* const env_;
   int64 file_pos_limit_;
   int64 record_number_;
@@ -98,10 +118,12 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
  public:
   explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
       : ReaderOpKernel(context) {
-    int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1;
+    int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1,
+          hop_bytes = -1;
     OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
     OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
     OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
+    OP_REQUIRES_OK(context, context->GetAttr("hop_bytes", &hop_bytes));
     OP_REQUIRES(context, header_bytes >= 0,
                 errors::InvalidArgument("header_bytes must be >= 0 not ",
                                         header_bytes));
@@ -111,11 +133,15 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
     OP_REQUIRES(context, footer_bytes >= 0,
                 errors::InvalidArgument("footer_bytes must be >= 0 not ",
                                         footer_bytes));
+    OP_REQUIRES(
+        context, hop_bytes >= 0,
+        errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes));
     Env* env = context->env();
-    SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, env]() {
-      return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
-                                         footer_bytes, env);
-    });
+    SetReaderFactory(
+        [this, header_bytes, record_bytes, footer_bytes, hop_bytes, env]() {
+          return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
+                                             footer_bytes, hop_bytes, env);
+        });
   }
 };
 
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index 71918fe269c..8bd1724e321 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -29,10 +29,9 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 
 template <typename Device, typename T>
-class MklAvgPoolingOp : public UnaryOp<T> {
+class MklAvgPoolingOp : public OpKernel {
  public:
-  explicit MklAvgPoolingOp(OpKernelConstruction* context)
-      : UnaryOp<T>(context) {
+  explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
     string data_format;
     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
@@ -78,6 +77,7 @@ class MklAvgPoolingOp : public UnaryOp<T> {
     Tensor mkl_tmp_input_buf_tensor_;
     mkl_context.MklCreateLayoutsAndPrimitives(context,
                                               &mkl_tmp_input_buf_tensor_);
+    OP_REQUIRES_OK(context, context->status());
 
     Tensor workspace_tensor;
     void* workspace_buf;
@@ -120,7 +120,7 @@ class MklAvgPoolingOp : public UnaryOp<T> {
                                 mkl_out_shape.GetMklLayout())) /
                             sizeof(T));
 
-    AllocateOutputSetMklshape(context, 0, &output, tensor_out_shape,
+    AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
                               mkl_out_shape);
     mkl_context.pooling_res[dnnResourceDst] =
         static_cast<void*>(output->flat<T>().data());
@@ -138,9 +138,10 @@ class MklAvgPoolingOp : public UnaryOp<T> {
   typedef struct {
     MklPoolingOpParams params;
     MklShape input_shape;
-    dnnPrimitive_t prim_pooling_fwd, convert_input;
-    dnnLayout_t lt_user_input, lt_prim_input, lt_workspace;
-    void* input_buf;
+    dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr;
+    dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr,
+                lt_workspace = nullptr;
+    void* input_buf = nullptr;
     void* pooling_res[dnnResourceNumber];
 
     void MklCreateLayoutsAndPrimitives(OpKernelContext* context,
@@ -243,6 +244,11 @@ class MklAvgPoolingGradOp : public OpKernel {
     pool_params.Init(context, ksize_, stride_, padding_, data_format_,
                      output_shape);
 
+    if (outbackprop_in_mkl_format == false)
+      mkl_context.params.in_dim = out_backprop.dims();
+    else
+      mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension();
+
     // Extract the parameters for the op from the pooling specs
     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
 
@@ -250,6 +256,7 @@ class MklAvgPoolingGradOp : public OpKernel {
     Tensor outbackprop_buf_tensor;
     void* outbackprop_buf;
     mkl_context.MklCreateLayoutsAndPrimitives(context);
+    OP_REQUIRES_OK(context, context->status());
 
     // Check if outbackprop layout requires conversion.
     if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop,
@@ -304,7 +311,7 @@ class MklAvgPoolingGradOp : public OpKernel {
                                 mkl_out_shape.GetMklLayout())) /
                             sizeof(T));
 
-    AllocateOutputSetMklshape(context, 0, &output, tensor_out_shape,
+    AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
                               mkl_out_shape);
 
     // Set output tensor.
@@ -323,10 +330,10 @@ class MklAvgPoolingGradOp : public OpKernel {
   typedef struct {
     MklPoolingOpParams params;
     MklShape out_backprop_shape;
-    dnnPrimitive_t prim_pooling_bwd, convert_outbackprop;
+    dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr;
     void* pooling_res[dnnResourceNumber];
-    dnnLayout_t lt_user_input, lt_user_outbackprop, lt_prim_outbackprop,
-        lt_workspace;
+    dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr,
+                lt_prim_outbackprop = nullptr, lt_workspace = nullptr;
 
     void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
       const Tensor& tensor_in_shape = MklGetInput(context, 0);
@@ -348,11 +355,6 @@ class MklAvgPoolingGradOp : public OpKernel {
                                             "4-dimensional"));
       } else {
         // Input in MKL format.
-        OP_REQUIRES(
-            context, out_backprop.dims() == 2,
-            errors::InvalidArgument("out_backprop in MKL format must be "
-                                    "2-dimensional"));
-
         // For avgpooling, out_backprop should have 4 dimensions.
         OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4,
                     errors::InvalidArgument("out_backprop must be "
@@ -412,16 +414,16 @@ class MklAvgPoolingGradOp : public OpKernel {
   TensorFormat data_format_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("MklAvgPool")
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
                             .Device(DEVICE_CPU)
                             .TypeConstraint<float>("T")
-                            .Label(mkl_layer_registry::kMklLayerLabel),
+                            .Label(mkl_op_registry::kMklOpLabel),
                         MklAvgPoolingOp<CPUDevice, float>);
 
-REGISTER_KERNEL_BUILDER(Name("MklAvgPoolGrad")
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad")
                             .Device(DEVICE_CPU)
                             .TypeConstraint<float>("T")
-                            .Label(mkl_layer_registry::kMklLayerLabel),
+                            .Label(mkl_op_registry::kMklOpLabel),
                         MklAvgPoolingGradOp<CPUDevice, float>);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
new file mode 100644
index 00000000000..27930c44a65
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -0,0 +1,458 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include <limits>
+#include <vector>
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
+
+// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
+// reference inputs.
+// --------------------------------------------------------------------------
+//                      Eigen Concat Op
+// --------------------------------------------------------------------------
+template <typename Device, typename T, AxisArgumentName AxisArgName>
+class EigenConcatBaseOp : public OpKernel {
+ public:
+  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+      ConstMatrixVector;
+
+  explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
+
+  // Although, we modify Compute for this call to accept one extra param,
+  // we need to have empty Compute because Compute is pure virtual function.
+  void Compute(OpKernelContext* c) {}
+
+  void Compute(OpKernelContext* c, const std::vector<Tensor>& values) {
+    const Tensor* concat_dim_tensor;
+    const char* axis_attribute_name =
+        AxisArgName == NAME_IS_AXIS
+            ? "axis"
+            : AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
+    OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
+    OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
+                errors::InvalidArgument(
+                    axis_attribute_name,
+                    " tensor should be a scalar integer, but got shape ",
+                    concat_dim_tensor->shape().DebugString()));
+    const int32 concat_dim =
+        internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
+    // Instead of accessing values from context, we use input to Compute.
+    const int N = values.size();
+    const int input_dims = values[0].dims();
+    const TensorShape& input_shape = values[0].shape();
+
+    int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
+    OP_REQUIRES(c,
+                (0 <= axis && axis < input_dims) ||
+                    (allow_legacy_scalars() && concat_dim == 0),
+                errors::InvalidArgument(
+                    "ConcatOp : Expected concatenating dimensions in the range "
+                    "[",
+                    -input_dims, ", ", input_dims, "), but got ", concat_dim));
+    // Note that we reduce the concat of n-dimensional tensors into a two
+    // dimensional concat. Assuming the dimensions of any input/output
+    // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
+    // the dimension indicated with size y0, we flatten it to {x, y}, where y =
+    // Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
+    ConstMatrixVector inputs_flat;
+    inputs_flat.reserve(N);
+    int64 inputs_flat_dim0 = 1;
+    for (int d = 0; d < axis; ++d) {
+      inputs_flat_dim0 *= input_shape.dim_size(d);
+    }
+    int64 output_concat_dim = 0;
+    const bool input_is_scalar = IsLegacyScalar(input_shape);
+    for (int i = 0; i < N; ++i) {
+      const auto in = values[i];
+      const bool in_is_scalar = IsLegacyScalar(in.shape());
+      OP_REQUIRES(
+          c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
+          errors::InvalidArgument(
+              "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
+              input_shape.DebugString(), " vs. shape[", i,
+              "] = ", in.shape().DebugString()));
+      for (int j = 0; j < input_dims; ++j) {
+        if (j == axis) {
+          continue;
+        }
+        OP_REQUIRES(
+            c, in.dim_size(j) == input_shape.dim_size(j),
+            errors::InvalidArgument(
+                "ConcatOp : Dimensions of inputs should match: shape[0] = ",
+                input_shape.DebugString(), " vs. shape[", i,
+                "] = ", in.shape().DebugString()));
+      }
+      if (in.NumElements() > 0) {
+        int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
+        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+            in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
+      }
+      // TODO(irving): Remove check once !allow_legacy_scalars().
+      output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
+    }
+
+    TensorShape output_shape(input_shape);
+    // TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
+    if (output_shape.dims() == 0) {
+      output_shape.AddDim(output_concat_dim);
+    } else {
+      output_shape.set_dim(axis, output_concat_dim);
+    }
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
+    if (output->NumElements() > 0) {
+      int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
+      auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
+      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+    }
+  }
+};
+
+// --------------------------------------------------------------------------
+//                      Mkl Concat Op
+// --------------------------------------------------------------------------
+
+template <typename Device, typename T, AxisArgumentName AxisArgName>
+class MklConcatOp : public OpKernel {
+ private:
+  TensorFormat data_format_;
+  EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;
+
+ public:
+  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+      ConstMatrixVector;
+
+  explicit MklConcatOp(OpKernelConstruction* c)
+      : OpKernel(c), eigen_concat_op_(c) {}
+
+  void Compute(OpKernelContext* context) override {
+    MklConcatOpContext mkl_context;
+
+    // Get input tensors.
+    OpInputList input_tensors;
+    GetMklInputList(context, "values", &input_tensors);
+    const int N = input_tensors.size();
+    // Get MKL shapes.
+    MklShapeList input_shapes(N);
+    GetMklShapeList(context, "values", &input_shapes);
+
+    // If this is Concat, then concat_dim is 0th input.
+    // If this is ConcatV2, then axis is Nth input.
+    const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM
+                                          ? MklGetInput(context, 0)
+                                          : MklGetInput(context, N);
+
+    // Sanity checks
+    OP_REQUIRES(
+        context, IsLegacyScalar(concat_dim_tensor.shape()),
+        errors::InvalidArgument(
+            "Concat dim tensor should be a scalar integer, but got shape ",
+            concat_dim_tensor.shape().DebugString()));
+    int32 concat_dim =
+        internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
+
+    MklShape& inpshape0 = input_shapes[0];
+
+    // Check that all tensors are Mkl, if not we call Eigen version.
+    bool invoke_eigen = false;
+    bool is_concat_dim_channel = true;
+    if (!AreAllMklTensors(input_shapes)) {
+      invoke_eigen = true;
+    }
+
+    // Check that total number of dimensions is 4, if not call Eigen.
+    if (!invoke_eigen) {
+      for (auto& s : input_shapes) {
+        if (s.GetDimension() != 4) {
+          invoke_eigen = true;
+          break;
+        }
+      }
+    }
+
+    // check that concat_dim is channel, if not call Eigen version.
+    if (!invoke_eigen) {
+      for (auto& s : input_shapes) {
+        if (!s.IsMklChannelDim(concat_dim)) {
+          invoke_eigen = true;
+          is_concat_dim_channel = false;
+          break;
+        }
+      }
+    }
+
+    if (invoke_eigen) {
+      string msg = std::string("Invoking Eigen version of Concat. Reason:") +
+                   (!is_concat_dim_channel
+                        ? std::string("Concat dimension is not channel")
+                        : std::string("Not all tensors are in Mkl layout"));
+      VLOG(1) << "_MklConcatOp: " << msg;
+      CallEigenVersion(context, input_tensors, input_shapes);
+      return;
+    }
+
+    // For MKL format, the channel is dimension number 2.
+    // So if we are concating over channel and _all_ inputs are in MKL
+    // format, then we set concat_dim to 2.
+    // Since we have reached till here, it means we are concating
+    // over channel.
+    concat_dim = MklDims::C;
+
+    // One more sanity check: check that ranks of all tensors match
+    // and that their shapes match except for concat_dim.
+    int i = 0;
+    for (auto& s : input_shapes) {
+      size_t exp_dims = inpshape0.GetDimension();
+      OP_REQUIRES(context, s.GetDimension() == exp_dims,
+                  errors::InvalidArgument(
+                      "_MklConcatOp : Ranks of all input tensors should match:"
+                      " input dimensions = ",
+                      s.GetDimension(), " vs. expected rank = ", exp_dims));
+
+      for (int d = 0; d < exp_dims; ++d) {
+        if (d == concat_dim) {
+          continue;
+        }
+
+        size_t exp_size = inpshape0.GetSizes()[d];
+        OP_REQUIRES(
+            context, exp_size == s.GetSizes()[d],
+            errors::InvalidArgument("_MklConcatOp : Dimensions of inputs"
+                                    "should match: shape[0][",
+                                    d, "]= ", exp_size, " vs. shape[", i, "][",
+                                    d, "] = ", s.GetSizes()[d]));
+      }
+      ++i;
+    }
+
+    // Use input MKL layout instead of creating new layouts.
+    int64 output_concat_dim_size = 0;
+    for (auto& s : input_shapes) {
+      output_concat_dim_size +=
+          s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
+    }
+    mkl_context.MklCreateInputLayouts(context, input_shapes);
+
+    CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
+                                 &mkl_context.lt_inputs[0]),
+             E_SUCCESS);
+
+    // Calculate output sizes and strides
+    TensorFormat data_format;
+    if (inpshape0.IsTensorInNHWCFormat()) {
+      data_format = FORMAT_NHWC;
+    } else {
+      OP_REQUIRES(
+          context, inpshape0.IsTensorInNCHWFormat(),
+          errors::InvalidArgument(
+              "_MklConcat only supports all inputs in NCHW or NHWC format "));
+      data_format = FORMAT_NCHW;
+    }
+
+    // Since all tensors are in Mkl layout, we copy sizes from input tensor.
+    mkl_context.out_sizes[MklDims::W] = inpshape0.GetSizes()[MklDims::W];
+    mkl_context.out_sizes[MklDims::H] = inpshape0.GetSizes()[MklDims::H];
+    mkl_context.out_sizes[MklDims::C] = output_concat_dim_size;
+    mkl_context.out_sizes[MklDims::N] = inpshape0.GetSizes()[MklDims::N];
+    GetStridesFromSizes(data_format, mkl_context.out_strides,
+                        mkl_context.out_sizes);
+
+    // Set output Mkl shape.
+    int64 dim = 4;
+    MklShape mkl_output_mkl_shape;
+    mkl_output_mkl_shape.SetMklTensor(true);
+    mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_concat, dnnResourceDst);
+    mkl_output_mkl_shape.SetTfLayout(dim, mkl_context.out_sizes,
+                                     mkl_context.out_strides);
+    mkl_output_mkl_shape.SetTfDimOrder(dim, inpshape0.GetTfToMklDimMap());
+
+    TensorShape mkl_output_tf_shape;
+    mkl_output_tf_shape.AddDim(1);
+    mkl_output_tf_shape.AddDim(
+        dnnLayoutGetMemorySize_F32(
+            static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
+        sizeof(T));
+
+    Tensor* output = nullptr;
+    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
+                              mkl_output_mkl_shape);
+
+    // Set destination resource.
+    mkl_context.concat_res[dnnResourceDst] =
+        const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
+
+    mkl_context.mkl_tmp_tensors.resize(N);
+    mkl_context.MklPrepareConcatInputs(context, input_tensors);
+
+    // Execute primitive.
+    CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
+             E_SUCCESS);
+
+    mkl_context.MklCleanup();
+  }
+
+ private:
+  typedef struct {
+    TensorFormat data_format;
+    size_t out_sizes[4];
+    size_t out_strides[4];
+    dnnPrimitive_t prim_concat;
+    void* concat_res[dnnResourceNumber];
+    std::vector<dnnLayout_t> lt_inputs;
+    std::vector<Tensor> mkl_tmp_tensors;
+
+    // Create MKL dnnLayout_t objects for tensors coming into the layer
+    // We only support case where input tensors are all in Mkl layout.
+    void MklCreateInputLayouts(OpKernelContext* context,
+                               MklShapeList& input_shapes) {
+      for (auto& is : input_shapes) {
+        CHECK_EQ(is.IsMklTensor(), true);
+        lt_inputs.push_back((dnnLayout_t)is.GetCurLayout());
+      }
+    }
+
+    void MklPrepareConcatInputs(OpKernelContext* context,
+                                OpInputList& input_tensors) {
+      CHECK_EQ(lt_inputs.size(), mkl_tmp_tensors.size());
+
+      for (int i = 0; i < lt_inputs.size(); ++i) {
+        dnnPrimitive_t mkl_prim_convert_input;
+        dnnLayout_t mkl_lt_internal_input;
+        void* mkl_buf_convert_input = nullptr;
+
+        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+                     &mkl_lt_internal_input, prim_concat,
+                     (dnnResourceType_t)(dnnResourceMultipleSrc + i)),
+                 E_SUCCESS);
+
+        if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) {
+          CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
+                                           lt_inputs[i], mkl_lt_internal_input),
+                   E_SUCCESS);
+
+          AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input,
+                         &mkl_buf_convert_input);
+
+          CHECK_EQ(dnnConversionExecute_F32(
+                       mkl_prim_convert_input,
+                       const_cast<void*>(static_cast<const void*>(
+                           input_tensors[i].flat<T>().data())),
+                       mkl_buf_convert_input),
+                   E_SUCCESS);
+
+          concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input;
+          CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS);
+        } else {
+          concat_res[dnnResourceMultipleSrc + i] = const_cast<void*>(
+              static_cast<const void*>(input_tensors[i].flat<T>().data()));
+        }
+
+        CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS);
+      }
+    }
+
+    void MklCleanup() {
+      for (auto& lt : lt_inputs) {
+        lt = nullptr;
+      }
+      CHECK_EQ(dnnDelete_F32(prim_concat), E_SUCCESS);
+    }
+  } MklConcatOpContext;
+
+  void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
+                        const MklShapeList& input_shapes) {
+    // Before calling Eigen version, we need to convert Mkl tensors to TF.
+    // First check that the number of input tensors and the number of Mkl
+    // shapes match.
+    CHECK_EQ(values.size(), input_shapes.size());
+
+    std::vector<Tensor> converted_values;
+    for (int i = 0; i < input_shapes.size(); i++) {
+      if (input_shapes[i].IsMklTensor()) {
+        // If input tensor is Mkl, then do the conversion.
+        Tensor tmp_tensor =
+            ConvertMklToTF<T>(context, values[i], input_shapes[i]);
+        converted_values.push_back(tmp_tensor);
+      } else {
+        // If input tensor is TF already, then we do not need any conversion.
+        converted_values.push_back(values[i]);
+      }
+    }
+
+    // Call Eigen concat.
+    eigen_concat_op_.Compute(context, converted_values);
+
+    // Set dummy Mkl tensor as output Mkl tensor for this op.
+    MklShape mkl_tensor_mkl_shape;
+    mkl_tensor_mkl_shape.SetMklTensor(false);
+    mkl_tensor_mkl_shape.SetDimensions(4);
+    mkl_tensor_mkl_shape.SetTfDimOrder(4);  // Dimensions
+    Tensor* mkl_tensor = nullptr;
+    TensorShape mkl_tensor_tf_shape;
+    mkl_tensor_tf_shape.AddDim(
+        SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
+    int tf_output_index = 0;
+    context->allocate_output(
+        GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
+        mkl_tensor_tf_shape, &mkl_tensor);
+    mkl_tensor_mkl_shape.SerializeMklShape(
+        mkl_tensor->flat<uint8>().data(),
+        mkl_tensor->flat<uint8>().size() * sizeof(uint8));
+  }
+};
+
+/* Use optimized concat for float type only */
+#define REGISTER_MKL_CPU(type)                                              \
+  REGISTER_KERNEL_BUILDER(Name("_MklConcat")                                \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .HostMemory("concat_dim")                     \
+                              .Label(mkl_op_registry::kMklOpLabel),         \
+                          MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
+  REGISTER_KERNEL_BUILDER(Name("_MklConcatV2")                               \
+                              .Device(DEVICE_CPU)                           \
+                              .TypeConstraint<type>("T")                    \
+                              .TypeConstraint<int32>("Tidx")                \
+                              .HostMemory("axis")                           \
+                              .Label(mkl_op_registry::kMklOpLabel),         \
+                          MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)
+
+TF_CALL_float(REGISTER_MKL_CPU);
+
+#undef REGISTER_CONCAT_MKL
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
index 627fd83b0d7..8a1006a8e95 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
@@ -87,7 +87,7 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
     Tensor* bias_backprop = nullptr;
     MklShape output_mkl_shape;
     output_mkl_shape.SetMklTensor(false);
-    AllocateOutputSetMklshape(context, 0, &bias_backprop, output_shape,
+    AllocateOutputSetMklShape(context, 0, &bias_backprop, output_shape,
                               output_mkl_shape);
 
     mkl_context.in_dims = 4;
@@ -251,11 +251,11 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(MklConv2DCustomBackpropBiasOp);
 };
 
-#define REGISTER_CPU_KERNELS(T)                                           \
-  REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBiasBackpropBias")           \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<T>("T")                     \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_CPU_KERNELS(T)                                     \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias")     \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
                           MklConv2DCustomBackpropBiasOp<CPUDevice, T>);
 
 TF_CALL_float(REGISTER_CPU_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 85198d89f56..6381b527a1b 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -217,7 +217,7 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
     mkl_context.grad_filter_shape.SetTfLayout(mkl_context.filter_dims,
                                               mkl_context.filter_sizes,
                                               mkl_context.filter_strides);
-    AllocateOutputSetMklshape(context, 0, &grad_filter, filter_shape,
+    AllocateOutputSetMklShape(context, 0, &grad_filter, filter_shape,
                               mkl_context.grad_filter_shape);
 
     // Need to set member variable for TF layout
@@ -408,11 +408,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
   TensorFormat data_format_;
 };
 
-#define REGISTER_MKL_FILTER_KERNELS(T)                                    \
-  REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropFilter")                 \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<T>("T")                     \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_MKL_FILTER_KERNELS(T)                              \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter")           \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
                           MklConv2DCustomBackpropFilterOp<CPUDevice, T>);
 
 TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index c7d95c86bcd..638ce4c0243 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -202,7 +202,7 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
     mkl_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                              mklOutputShape.GetMklLayout())) /
                          sizeof(T));
-    AllocateOutputSetMklshape(context, 0, &in_backprop, mkl_out_shape,
+    AllocateOutputSetMklShape(context, 0, &in_backprop, mkl_out_shape,
                               mklOutputShape);
 
     mkl_context.conv_res[dnnResourceDiffSrc] =
@@ -341,11 +341,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
   TensorFormat data_format;
 };
 
-#define REGISTER_MKL_CPU_KERNELS(T)                                       \
-  REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropInput")                  \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<T>("T")                     \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_MKL_CPU_KERNELS(T)                                 \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")            \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
                           MklConv2DCustomBackpropInputOp<CPUDevice, T>);
 
 TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index e5c4c21a10a..b818819b020 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -178,7 +178,7 @@ class MklConv2DOp : public OpKernel {
       // Nothing to do, allocate output tensor and return
       MklShape mkl_output_mkl_shape;
       mkl_output_mkl_shape.SetMklTensor(false);
-      AllocateOutputSetMklshape(context, 0, &output, input.shape(),
+      AllocateOutputSetMklShape(context, 0, &output, input.shape(),
                                 mkl_output_mkl_shape);
       return;
     }
@@ -264,7 +264,7 @@ class MklConv2DOp : public OpKernel {
         dnnLayoutGetMemorySize_F32(
             static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
         sizeof(T));
-    AllocateOutputSetMklshape(context, 0, &output, mkl_output_tf_shape,
+    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
                               mkl_output_mkl_shape);
     mkl_context.conv_res[dnnResourceDst] =
         static_cast<void*>(output->flat<T>().data());
@@ -437,16 +437,16 @@ class MklConv2DOp : public OpKernel {
   TensorFormat data_format_;
 };
 
-#define REGISTER_MKL_CPU(T)                                               \
-  REGISTER_KERNEL_BUILDER(Name("MklConv2D")                               \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<T>("T")                     \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
-                          MklConv2DOp<CPUDevice, T, false>);              \
-  REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBias")                       \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<T>("T")                     \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_MKL_CPU(T)                                         \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2D")                         \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklConv2DOp<CPUDevice, T, false>);        \
+  REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")                 \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
                           MklConv2DOp<CPUDevice, T, true>);
 
 TF_CALL_float(REGISTER_MKL_CPU);
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
new file mode 100644
index 00000000000..512e799d152
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -0,0 +1,689 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifdef INTEL_MKL
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+// TODO(inteltf) Address comments from PR 8968.
+
+namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+template <typename Device, typename T>
+class MklFusedBatchNormOp : public OpKernel {
+ public:
+  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    float epsilon;
+    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
+    epsilon_ = T(epsilon);
+    string tensor_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
+    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
+                errors::InvalidArgument("Invalid data format"));
+    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    MklFusedBatchNormOpContext mkl_context;
+
+    const Tensor& input = MklGetInput(context, 0);
+    const Tensor& scale = MklGetInput(context, 1);
+    const Tensor& shift = MklGetInput(context, 2);
+    const Tensor& est_mean = MklGetInput(context, 3);
+    const Tensor& est_variance = MklGetInput(context, 4);
+
+    GetMklShape(context, 0, &(mkl_context.mkl_shape_input_shape));
+    bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
+    if (!input_in_mkl_format) {
+      OP_REQUIRES(context, input.dims() == 4,
+                  errors::InvalidArgument("input must be 4-dimensional",
+                                          input.shape().DebugString()));
+    }
+    OP_REQUIRES(context, scale.dims() == 1,
+                errors::InvalidArgument("scale must be 1-dimensional",
+                                        scale.shape().DebugString()));
+    OP_REQUIRES(context, shift.dims() == 1,
+                errors::InvalidArgument("offset must be 1-dimensional",
+                                        shift.shape().DebugString()));
+    OP_REQUIRES(context, est_mean.dims() == 1,
+                errors::InvalidArgument("estimated_mean must be 1-dimensional",
+                                        est_mean.shape().DebugString()));
+    OP_REQUIRES(
+        context, est_variance.dims() == 1,
+        errors::InvalidArgument("estimated_variance must be 1-dimensional",
+                                est_variance.shape().DebugString()));
+    if (is_training_) {
+      OP_REQUIRES(context, est_mean.dim_size(0) == 0,
+                  errors::InvalidArgument("estimated_mean empty for training",
+                                          est_mean.shape().DebugString()));
+      OP_REQUIRES(context, est_variance.dim_size(0) == 0,
+                  errors::InvalidArgument(
+                      "estimated_variance must be empty for training",
+                      est_variance.shape().DebugString()));
+    }
+
+    unsigned int flag_batch_norm =
+        is_training_ ? dnnUseScaleShift
+                     : (dnnUseInputMeanVariance | dnnUseScaleShift);
+
+    mkl_context.MklExtractParams(context, tensor_format_);
+
+    // Create layout only for input data as it is used in Op primitive.
+    mkl_context.MklCreateInputLayout(context);
+
+    // Create Op primitive.
+    CHECK_EQ(dnnBatchNormalizationCreateForward_v2_F32(
+                 &(mkl_context.mkl_prim_batchnorm), nullptr,
+                 mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
+                 flag_batch_norm),
+             E_SUCCESS);
+
+    // Temporary tensors with buffers for the context inputs, if
+    // conversion to MKL-Op specific layouts are required. It is assumed here
+    // that TF's 1D tensors (scale, shift, est_mean, and est_variance) won't
+    // require any conversion.
+    // Since scale-shift is combined in MKL, a buffer is required.
+    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_scale_shift_buf_tensor;
+    mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
+                                        &mkl_tmp_scale_shift_buf_tensor);
+
+    // Output data in MKL layout
+    Tensor* output = nullptr;
+    TensorShape tf_shape_output;
+    MklShape mkl_shape_output;
+    mkl_shape_output.SetMklTensor(true);
+    mkl_shape_output.SetMklLayout(mkl_context.mkl_prim_batchnorm,
+                                  dnnResourceDst);
+    mkl_shape_output.SetTfLayout(mkl_context.mkl_params.in_dim,
+                                 mkl_context.mkl_params.in_sizes,
+                                 mkl_context.mkl_params.in_strides);
+    mkl_shape_output.SetTfDimOrder(mkl_context.mkl_params.in_dim,
+                                   tensor_format_);
+    tf_shape_output.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+                               mkl_shape_output.GetMklLayout())) /
+                           sizeof(T));
+    AllocateOutputSetMklShape(context, 0, &output, tf_shape_output,
+                              mkl_shape_output);
+    mkl_context.mkl_res_batchnorm[dnnResourceDst] =
+        static_cast<void*>(output->flat<T>().data());
+
+    // Batch mean in TF layout
+    Tensor* batch_mean = nullptr;
+    MklShape mkl_shape_batch_mean;
+    mkl_shape_batch_mean.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, 1, &batch_mean, scale.shape(),
+                              mkl_shape_batch_mean);
+    // Batch variance in TF layout
+    Tensor* batch_variance = nullptr;
+    MklShape mkl_shape_batch_variance;
+    mkl_shape_batch_variance.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, 2, &batch_variance, scale.shape(),
+                              mkl_shape_batch_variance);
+    // If training mode, set dnnResourceMean and dnnResourceVariance to
+    // output tensors for batch mean and variance.
+    // Otherwise, set dnnResourceMean and dnnResourceVariance to
+    // estimated mean and variance.
+    if (is_training_)
+      mkl_context.MklSetMeanVariance(*batch_mean, *batch_variance);
+    else
+      mkl_context.MklSetMeanVariance(est_mean, est_variance);
+
+    // Now that all resources are set, it is ready for dnnExecute
+    CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm,
+                            mkl_context.mkl_res_batchnorm),
+             E_SUCCESS);
+
+    // Mean and variance (without Bessel's correction) saved for backward
+    // computation to serve as pre-computed mean and variance.
+    Tensor* saved_mean = nullptr;
+    MklShape mkl_shape_saved_mean;
+    mkl_shape_saved_mean.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, 3, &saved_mean, scale.shape(),
+                              mkl_shape_saved_mean);
+    std::memcpy(
+        reinterpret_cast<char*>(saved_mean->flat<float>().data()),
+        reinterpret_cast<char*>(mkl_context.mkl_res_batchnorm[dnnResourceMean]),
+        scale.NumElements() * sizeof(float));
+    Tensor* saved_variance = nullptr;
+    MklShape mkl_shape_saved_variance;
+    mkl_shape_saved_variance.SetMklTensor(false);
+    AllocateOutputSetMklShape(context, 4, &saved_variance, scale.shape(),
+                              mkl_shape_saved_variance);
+    std::memcpy(reinterpret_cast<char*>(saved_variance->flat<float>().data()),
+                reinterpret_cast<char*>(
+                    mkl_context.mkl_res_batchnorm[dnnResourceVariance]),
+                scale.NumElements() * sizeof(float));
+
+    // Bessel's correction on variance, if training mode is on
+    if (is_training_) {
+      float* p_var = static_cast<float*>(batch_variance->flat<T>().data());
+      auto depth = mkl_context.mkl_params.depth;
+      size_t orig_size = mkl_context.mkl_params.in_sizes[0] *
+                         mkl_context.mkl_params.in_sizes[1] *
+                         mkl_context.mkl_params.in_sizes[3];
+      size_t adjust_size = orig_size - 1;
+      float adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
+      for (int i = 0; i < depth; i++) p_var[i] = adjust_factor * p_var[i];
+    }
+
+    mkl_context.MklCleanup();
+  }
+
+ private:
+  T epsilon_;
+  TensorFormat tensor_format_;
+  bool is_training_;
+
+  // Structure containing all info for MklOp
+  typedef struct {
+    // Parameters used for input and output layouts
+    struct MklBatchNormParams {
+      // BatchNormOp src and
+      size_t in_dim;
+      size_t in_sizes[4];
+      size_t in_strides[4];
+      size_t depth;  // Batch normalization is done for per channel.
+    } mkl_params;
+
+    MklShape mkl_shape_input_shape;
+
+    // MKL primitive and resources for BatchNormOp
+    dnnPrimitive_t mkl_prim_batchnorm = nullptr;
+    void* mkl_res_batchnorm[dnnResourceNumber];
+
+    // MKL layouts for inputs in the context
+    dnnLayout_t mkl_lt_input = nullptr;
+
+    void MklCleanup() {
+      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+      if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
+      if (mkl_prim_batchnorm != nullptr) dnnDelete_F32(mkl_prim_batchnorm);
+    }
+
+    void MklExtractParams(OpKernelContext* context,
+                          const TensorFormat& tensor_format) {
+      const Tensor& input = MklGetInput(context, 0);
+      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+      mkl_params.in_dim = input_in_mkl_format
+                              ? mkl_shape_input_shape.GetDimension()
+                              : input.dims();
+      mkl_params.in_sizes[0] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
+                              : GetTensorDim(input, tensor_format, 'W'));
+      mkl_params.in_sizes[1] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
+                              : GetTensorDim(input, tensor_format, 'H'));
+      mkl_params.in_sizes[2] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
+                              : GetTensorDim(input, tensor_format, 'C'));
+      mkl_params.in_sizes[3] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
+                              : GetTensorDim(input, tensor_format, 'N'));
+      mkl_params.depth = mkl_params.in_sizes[2];
+      GetStridesFromSizes(tensor_format, mkl_params.in_strides,
+                          mkl_params.in_sizes);
+    }
+
+    void MklCreateInputLayout(OpKernelContext* context) {
+      const Tensor& input = MklGetInput(context, 0);
+      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+      if (input_in_mkl_format) {
+        mkl_lt_input =
+            static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
+      } else {
+        CHECK_EQ(
+            dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dim,
+                                mkl_params.in_sizes, mkl_params.in_strides),
+            E_SUCCESS);
+      }
+    }
+
+    void MklPrepareContextInputs(OpKernelContext* context,
+                                 Tensor* mkl_tmp_input_buf_tensor,
+                                 Tensor* mkl_tmp_scale_shift_buf_tensor) {
+      bool mkl_convert_input;
+      dnnPrimitive_t mkl_prim_convert_input = nullptr;
+      dnnLayout_t mkl_lt_internal_input = nullptr;
+      void* mkl_buf_converted_input = nullptr;
+      // Compare with internal layouts and convert if needed
+      const Tensor& input = MklGetInput(context, 0);
+      void* mkl_buf_input =
+          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
+                   &mkl_lt_internal_input, mkl_prim_batchnorm, dnnResourceSrc),
+               E_SUCCESS);
+      mkl_convert_input =
+          !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
+      if (mkl_convert_input) {
+        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
+                                         mkl_lt_internal_input),
+                 E_SUCCESS);
+        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
+                       &mkl_buf_converted_input);
+        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
+                                          mkl_buf_converted_input),
+                 E_SUCCESS);
+        dnnDelete_F32(mkl_prim_convert_input);
+      }
+      dnnLayoutDelete_F32(mkl_lt_internal_input);
+      mkl_res_batchnorm[dnnResourceSrc] =
+          (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
+
+      // scale-shift layout is created from primitive. So no conversion
+      // is needed, however, a buffer has to be allocated.
+      dnnLayout_t mkl_lt_scale_shift = nullptr;
+      void* mkl_buf_scale_shift = nullptr;
+      CHECK_EQ(
+          dnnLayoutCreateFromPrimitive_F32(
+              &mkl_lt_scale_shift, mkl_prim_batchnorm, dnnResourceScaleShift),
+          E_SUCCESS);
+      AllocTmpBuffer(context, mkl_tmp_scale_shift_buf_tensor,
+                     mkl_lt_scale_shift, &mkl_buf_scale_shift);
+      // Fill the scale-shift buffer with data, presumably buffer is 2D array
+      const Tensor& scale = MklGetInput(context, 1);
+      const Tensor& shift = MklGetInput(context, 2);
+      float* buf_scale_shift = static_cast<float*>(mkl_buf_scale_shift);
+      float* buf_scale = const_cast<float*>(
+          static_cast<const float*>(scale.flat<float>().data()));
+      float* buf_shift = const_cast<float*>(
+          static_cast<const float*>(shift.flat<float>().data()));
+      auto depth = mkl_params.depth;
+      for (int i = 0; i < depth; i++) {
+        buf_scale_shift[i] = buf_scale[i];
+        buf_scale_shift[i + depth] = buf_shift[i];
+      }
+      mkl_res_batchnorm[dnnResourceScaleShift] = mkl_buf_scale_shift;
+    }
+
+    inline void MklSetMeanVariance(const Tensor& mean, const Tensor& variance) {
+      mkl_res_batchnorm[dnnResourceMean] = const_cast<void*>(
+          static_cast<const void*>(mean.flat<float>().data()));
+      mkl_res_batchnorm[dnnResourceVariance] = const_cast<void*>(
+          static_cast<const void*>(variance.flat<float>().data()));
+    }
+  } MklFusedBatchNormOpContext;
+};
+
+#define REGISTER_MKL_CPU(T)                                         \
+  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm")                 \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklFusedBatchNormOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+
+template <typename Device, typename T>
+class MklFusedBatchNormGradOp : public OpKernel {
+ public:
+  explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    float epsilon;
+    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
+    epsilon_ = T(epsilon);
+    string tensor_format;
+    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
+    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
+                errors::InvalidArgument("Invalid data format"));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    MklFusedBatchNormGradOpContext mkl_context;
+
+    const Tensor& out_backprop = MklGetInput(context, 0);
+    const Tensor& input = MklGetInput(context, 1);
+    const Tensor& scale = MklGetInput(context, 2);
+    const Tensor& saved_mean = MklGetInput(context, 3);
+    const Tensor& saved_var = MklGetInput(context, 4);
+
+    // Here scale, mean, and variance are 1D and considered
+    // those having same layout in MKL and TF
+    GetMklShape(context, 0, &(mkl_context.mkl_shape_out_backprop));
+    GetMklShape(context, 1, &(mkl_context.mkl_shape_input_shape));
+
+    bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
+    bool out_backprop_in_mkl_format =
+        mkl_context.mkl_shape_out_backprop.IsMklTensor();
+    if (!out_backprop_in_mkl_format) {
+      OP_REQUIRES(context, out_backprop.dims() == 4,
+                  errors::InvalidArgument("input must be 4-dimensional",
+                                          out_backprop.shape().DebugString()));
+    }
+    if (!input_in_mkl_format) {
+      OP_REQUIRES(context, input.dims() == 4,
+                  errors::InvalidArgument("input must be 4-dimensional",
+                                          input.shape().DebugString()));
+    }
+    OP_REQUIRES(context, scale.dims() == 1,
+                errors::InvalidArgument("scale must be 1-dimensional",
+                                        scale.shape().DebugString()));
+    OP_REQUIRES(context, saved_mean.dims() == 1,
+                errors::InvalidArgument("saved mean must be 1-dimensional",
+                                        saved_mean.shape().DebugString()));
+    OP_REQUIRES(context, saved_var.dims() == 1,
+                errors::InvalidArgument("saved variance must be 1-dimensional",
+                                        saved_var.shape().DebugString()));
+
+    mkl_context.MklExtractParams(context, tensor_format_);
+
+    mkl_context.MklCreateInputLayout(context);
+
+    unsigned int flag_batch_norm_grad = dnnUseScaleShift;
+
+    // Create Backward Op primitive.
+    CHECK_EQ(dnnBatchNormalizationCreateBackward_v2_F32(
+                 &(mkl_context.mkl_prim_batchnorm_bwd), nullptr,
+                 mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
+                 flag_batch_norm_grad),
+             E_SUCCESS);
+
+    // Temporary tensors and their buffers if conversion is required
+    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_outbackprop_buf_tensor,
+        mkl_tmp_scaleshift_buf_tensor;
+    mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
+                                        &mkl_tmp_outbackprop_buf_tensor,
+                                        &mkl_tmp_scaleshift_buf_tensor);
+
+    // Allocate tensor for grad w.r.t. input(x)
+    Tensor* in_backprop = nullptr;
+    TensorShape tf_shape_in_backprop;
+    MklShape mkl_shape_in_backprop;
+    mkl_shape_in_backprop.SetMklTensor(true);
+    mkl_shape_in_backprop.SetMklLayout(mkl_context.mkl_prim_batchnorm_bwd,
+                                       dnnResourceDiffSrc);
+    mkl_shape_in_backprop.SetTfLayout(mkl_context.mkl_params.in_dims,
+                                      mkl_context.mkl_params.in_sizes,
+                                      mkl_context.mkl_params.in_strides);
+    mkl_shape_in_backprop.SetTfDimOrder(mkl_context.mkl_params.in_dims,
+                                        tensor_format_);
+    tf_shape_in_backprop.AddDim(
+        dnnLayoutGetMemorySize_F32(
+            static_cast<dnnLayout_t>(mkl_shape_in_backprop.GetMklLayout())) /
+        sizeof(T));
+    AllocateOutputSetMklShape(context, 0, &in_backprop, tf_shape_in_backprop,
+                              mkl_shape_in_backprop);
+    mkl_context.mkl_res_batchnorm_bwd[dnnResourceDiffSrc] =
+        static_cast<void*>(in_backprop->flat<T>().data());
+
+    // grad_scale and grad_shift are combined together in MKL
+    // So create a single temporary buffer for those.
+    // Also set dnnResourceDiffScaleShift to the temporary buffer
+    Tensor mkl_tmp_grad_scale_shift_buf_tensor;
+    mkl_context.MklPrepareGradScaleShift(context,
+                                         &mkl_tmp_grad_scale_shift_buf_tensor);
+
+    // All dnn resources are set now, ready to execute
+    CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm_bwd,
+                            mkl_context.mkl_res_batchnorm_bwd),
+             E_SUCCESS);
+
+    // Now separate out scale and shift grad and copy to individual tensors
+    const TensorShape& tf_shape_scale_shift = scale.shape();
+    // Allocate tensor for grad w.r.t. scale (beta)
+    Tensor* scale_backprop = nullptr;
+    MklShape mkl_shape_scale_backprop;
+    AllocateOutputSetMklShape(context, 1, &scale_backprop, tf_shape_scale_shift,
+                              mkl_shape_scale_backprop);
+
+    // Allocate tensor for grad w.r.t. shift(gamma)
+    Tensor* shift_backprop = nullptr;
+    MklShape mkl_shape_shift_backprop;
+    AllocateOutputSetMklShape(context, 2, &shift_backprop, tf_shape_scale_shift,
+                              mkl_shape_shift_backprop);
+
+    // copy scale and shift grads to tensors
+    float* mkl_buf_scale_shift = const_cast<float*>(static_cast<const float*>(
+        mkl_tmp_grad_scale_shift_buf_tensor.flat<T>().data()));
+    float* tf_buf_scale = const_cast<float*>(
+        static_cast<const float*>(scale_backprop->flat<T>().data()));
+    float* tf_buf_shift = const_cast<float*>(
+        static_cast<const float*>(shift_backprop->flat<T>().data()));
+    auto depth = mkl_context.mkl_params.depth;
+    for (int i = 0; i < depth; i++) {
+      tf_buf_scale[i] = mkl_buf_scale_shift[i];
+      tf_buf_shift[i] = mkl_buf_scale_shift[i + depth];
+    }
+
+    // Two placeholders for estimated_mean and estimated_variance, which are
+    // used for inference and thus not needed here for gradient computation.
+    Tensor* placeholder_1 = nullptr;
+    MklShape mkl_shape_placeholder_1;
+    AllocateOutputSetMklShape(context, 3, &placeholder_1, TensorShape({}),
+                              mkl_shape_placeholder_1);
+    Tensor* placeholder_2 = nullptr;
+    MklShape mkl_shape_placeholder_2;
+    AllocateOutputSetMklShape(context, 4, &placeholder_2, TensorShape({}),
+                              mkl_shape_placeholder_2);
+
+    mkl_context.MklCleanup();
+  }
+
+ private:
+  T epsilon_;
+  TensorFormat tensor_format_;
+
+  // Structure containing all info for MklOp
+  typedef struct {
+    // Parameters used for input and output layouts
+    struct MklBatchNormParams {
+      // BatchNormOp src and
+      size_t in_dims;
+      size_t in_sizes[4];
+      size_t in_strides[4];
+      size_t depth;  // Batch normalization is done for per channel.
+    } mkl_params;
+
+    MklShape mkl_shape_out_backprop;
+    MklShape mkl_shape_input_shape;
+
+    // MKL primitive and resources for BatchNormOp
+    dnnPrimitive_t mkl_prim_batchnorm_bwd = nullptr;
+    void* mkl_res_batchnorm_bwd[dnnResourceNumber];
+
+    // MKL layouts for inputs in the context
+    dnnLayout_t mkl_lt_out_backprop = nullptr;
+    dnnLayout_t mkl_lt_input = nullptr;
+
+    void MklCleanup() {
+      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+      bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
+      if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
+      if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_out_backprop);
+
+      dnnDelete_F32(mkl_prim_batchnorm_bwd);
+    }
+
+    void MklExtractParams(OpKernelContext* context,
+                          const TensorFormat& tensor_format) {
+      const Tensor& input = MklGetInput(context, 1);
+      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+      mkl_params.in_dims = input_in_mkl_format
+                               ? mkl_shape_input_shape.GetDimension()
+                               : input.dims();
+      mkl_params.in_sizes[0] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
+                              : GetTensorDim(input, tensor_format, 'W'));
+      mkl_params.in_sizes[1] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
+                              : GetTensorDim(input, tensor_format, 'H'));
+      mkl_params.in_sizes[2] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
+                              : GetTensorDim(input, tensor_format, 'C'));
+      mkl_params.in_sizes[3] = static_cast<size_t>(
+          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
+                              : GetTensorDim(input, tensor_format, 'N'));
+      mkl_params.depth = mkl_params.in_sizes[2];
+      GetStridesFromSizes(tensor_format, mkl_params.in_strides,
+                          mkl_params.in_sizes);
+    }
+
+    void MklCreateInputLayout(OpKernelContext* context) {
+      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
+      if (input_in_mkl_format) {
+        mkl_lt_input =
+            static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
+      } else {
+        CHECK_EQ(
+            dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dims,
+                                mkl_params.in_sizes, mkl_params.in_strides),
+            E_SUCCESS);
+      }
+
+      bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
+      if (out_backprop_in_mkl_format) {
+        mkl_lt_out_backprop =
+            static_cast<dnnLayout_t>(mkl_shape_out_backprop.GetCurLayout());
+      } else {
+        CHECK_EQ(
+            dnnLayoutCreate_F32(&mkl_lt_out_backprop, mkl_params.in_dims,
+                                mkl_params.in_sizes, mkl_params.in_strides),
+            E_SUCCESS);
+      }
+    }
+
+    void MklPrepareContextInputs(OpKernelContext* context,
+                                 Tensor* mkl_tmp_input_buf_tensor,
+                                 Tensor* mkl_tmp_outbackprop_buf_tensor,
+                                 Tensor* mkl_tmp_scaleshift_buf_tensor) {
+      bool mkl_convert_input;
+      dnnPrimitive_t mkl_prim_convert_input = nullptr;
+      dnnLayout_t mkl_lt_internal_input = nullptr;
+      void* mkl_buf_converted_input = nullptr;
+      // Compare with internal layouts and convert if needed
+      const Tensor& input = MklGetInput(context, 1);
+      void* mkl_buf_input =
+          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+      CHECK_EQ(
+          dnnLayoutCreateFromPrimitive_F32(
+              &mkl_lt_internal_input, mkl_prim_batchnorm_bwd, dnnResourceSrc),
+          E_SUCCESS);
+      mkl_convert_input =
+          !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
+      if (mkl_convert_input) {
+        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
+                                         mkl_lt_internal_input),
+                 E_SUCCESS);
+        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
+                       &mkl_buf_converted_input);
+        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
+                                          mkl_buf_converted_input),
+                 E_SUCCESS);
+        dnnDelete_F32(mkl_prim_convert_input);
+      }
+      dnnLayoutDelete_F32(mkl_lt_internal_input);
+      mkl_res_batchnorm_bwd[dnnResourceSrc] =
+          (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
+
+      bool mkl_convert_out_backprop;
+      dnnPrimitive_t mkl_prim_convert_out_backprop = nullptr;
+      dnnLayout_t mkl_lt_internal_out_backprop = nullptr;
+      void* mkl_buf_converted_out_backprop = nullptr;
+      // Compare with internal layouts and convert if needed
+      const Tensor& out_backprop = MklGetInput(context, 0);
+      void* mkl_buf_out_backprop = const_cast<void*>(
+          static_cast<const void*>(out_backprop.flat<T>().data()));
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
+                                                mkl_prim_batchnorm_bwd,
+                                                dnnResourceDiffDst),
+               E_SUCCESS);
+      mkl_convert_out_backprop = !dnnLayoutCompare_F32(
+          mkl_lt_internal_out_backprop, mkl_lt_out_backprop);
+      if (mkl_convert_out_backprop) {
+        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
+                                         mkl_lt_out_backprop,
+                                         mkl_lt_internal_out_backprop),
+                 E_SUCCESS);
+        AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
+                       mkl_lt_internal_out_backprop,
+                       &mkl_buf_converted_out_backprop);
+        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
+                                          mkl_buf_out_backprop,
+                                          mkl_buf_converted_out_backprop),
+                 E_SUCCESS);
+        dnnDelete_F32(mkl_prim_convert_out_backprop);
+      }
+      dnnLayoutDelete_F32(mkl_lt_internal_out_backprop);
+      mkl_res_batchnorm_bwd[dnnResourceDiffDst] =
+          (mkl_convert_out_backprop) ? mkl_buf_converted_out_backprop
+                                     : mkl_buf_out_backprop;
+
+      // Set dnnResourceMean and dnnResourceVariance
+      const Tensor& saved_mean = MklGetInput(context, 3);
+      const Tensor& saved_var = MklGetInput(context, 4);
+      void* mkl_buf_saved_mean = const_cast<void*>(
+          static_cast<const void*>(saved_mean.flat<T>().data()));
+      void* mkl_buf_saved_var = const_cast<void*>(
+          static_cast<const void*>(saved_var.flat<T>().data()));
+      mkl_res_batchnorm_bwd[dnnResourceMean] = mkl_buf_saved_mean;
+      mkl_res_batchnorm_bwd[dnnResourceVariance] = mkl_buf_saved_var;
+
+      // Set dnnResourceScaleShift
+      // Note backward Op needs only current values of scale parameters,
+      // shift parameters could be garbage and won't be used
+      const Tensor& scale = MklGetInput(context, 2);
+      dnnLayout_t mkl_lt_scale_shift = nullptr;
+      void* mkl_buf_scale_shift = nullptr;
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_scale_shift,
+                                                mkl_prim_batchnorm_bwd,
+                                                dnnResourceScaleShift),
+               E_SUCCESS);
+      AllocTmpBuffer(context, mkl_tmp_scaleshift_buf_tensor, mkl_lt_scale_shift,
+                     &mkl_buf_scale_shift);
+      float* pscale =
+          const_cast<float*>(static_cast<const float*>(scale.flat<T>().data()));
+      float* pscale_shift = static_cast<float*>(mkl_buf_scale_shift);
+      auto depth = mkl_params.depth;
+      for (int i = 0; i < depth; i++) pscale_shift[i] = pscale[i];
+      mkl_res_batchnorm_bwd[dnnResourceScaleShift] = mkl_buf_scale_shift;
+      dnnLayoutDelete_F32(mkl_lt_scale_shift);
+    }
+
+    void MklPrepareGradScaleShift(OpKernelContext* context,
+                                  Tensor* mkl_tmp_grad_scale_shift_buf_tensor) {
+      dnnLayout_t mkl_lt_grad_scaleshift = nullptr;
+      void* mkl_buf_grad_scaleshift = nullptr;
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_grad_scaleshift,
+                                                mkl_prim_batchnorm_bwd,
+                                                dnnResourceDiffScaleShift),
+               E_SUCCESS);
+      AllocTmpBuffer(context, mkl_tmp_grad_scale_shift_buf_tensor,
+                     mkl_lt_grad_scaleshift, &mkl_buf_grad_scaleshift);
+      mkl_res_batchnorm_bwd[dnnResourceDiffScaleShift] =
+          mkl_buf_grad_scaleshift;
+      dnnLayoutDelete_F32(mkl_lt_grad_scaleshift);
+    }
+  } MklFusedBatchNormGradOpContext;
+};
+
+#define REGISTER_MKL_CPU(T)                                         \
+  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad")             \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklFusedBatchNormGradOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
new file mode 100644
index 00000000000..edca8e2553d
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -0,0 +1,722 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// LRN = Local Response Normalization
+// See docs in ../ops/nn_ops.cc. This opkernel uses MKL library, create MKL
+// layout and primitives, use MKL dnn primitives to compute local
+// response normalization
+
+#ifdef INTEL_MKL
+
+#define EIGEN_USE_THREADS
+#include <vector>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if !defined(IS_MOBILE_PLATFORM)
+#include "tensorflow/core/util/work_sharder.h"
+#endif
+
+namespace tensorflow {
+
+namespace {
+// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
+// depth_radius + 1) around the diagonal.
+template <typename T>
+void GetBandMatrix(int depth, int depth_radius,
+                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
+  result->setZero();
+  for (int row = 0; row < depth; ++row) {
+    const int begin = std::max<int>(0, row - depth_radius);
+    const int end = std::min<int>(depth, row + depth_radius + 1);
+    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
+    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
+    result->slice(start, sizes).setConstant(T(1));
+  }
+}
+
+}  // namespace
+
+template <typename T>
+class MklLRNOp : public OpKernel {
+ public:
+  ~MklLRNOp() {}
+
+  explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+    int64 depth_radius64;
+    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
+        errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                " larger than int max"));
+    depth_radius_ = static_cast<size_t>(depth_radius64);
+
+    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+    workspace_enabled_ = false;
+    context->GetAttr("workspace_enabled", &workspace_enabled_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    MklLRNOpContext mkl_context;
+
+    const Tensor& input = MklGetInput(context, 0);
+    GetMklShape(context, 0, &mkl_context.input_shape);
+    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
+
+    // Sanity checks
+    mkl_context.in_dims = input_in_mkl_format
+                              ? mkl_context.input_shape.GetDimension()
+                              : input.dims();
+    OP_REQUIRES(context, mkl_context.in_dims == 4,
+                errors::InvalidArgument("input must be 4-dimensional"));
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(input.NumElements(), std::numeric_limits<int>::max()),
+        errors::InvalidArgument("argument to LRN too large"));
+
+    if (!input_in_mkl_format) {
+      mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
+                                    beta_, input);
+      return;
+    }
+
+    if (input_in_mkl_format) {
+      // MKL supports normalization over channel dimension only
+      if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
+          MklDims::C) {
+        mkl_context.lt_input =
+            static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
+        workspace_enabled_ = true;
+      } else {
+        mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
+                                      beta_, input);
+        return;
+      }
+    }
+
+    int kernel_size = 2 * depth_radius_ + 1;
+
+    CHECK_EQ(dnnLRNCreateForward_F32(
+                 &mkl_context.lrn_fwd, NULL, mkl_context.lt_input, kernel_size,
+                 static_cast<float>(alpha_ * kernel_size), beta_, bias_),
+             E_SUCCESS);
+
+    // Allocate output tensor and shape
+    Tensor* output = nullptr;
+    Tensor* workspace = nullptr;
+
+    // Convert Inputs if needed
+    Tensor mkl_tmp_input_buf_tensor;
+    mkl_context.MklPrepareLRNInputs(context, &mkl_tmp_input_buf_tensor);
+
+    // Allocate Layer Outputs
+    mkl_context.MklAllocateOutputs(context, &output, &workspace,
+                                   workspace_enabled_);
+
+    Tensor mkl_tmp_workspace_buf_tensor;
+    mkl_context.MklPrepareLRNOutputs(context, output, workspace,
+                                     &mkl_tmp_workspace_buf_tensor,
+                                     workspace_enabled_);
+
+    // Execute LRN.
+    CHECK_EQ(dnnExecute_F32(mkl_context.lrn_fwd, mkl_context.lrn_res),
+             E_SUCCESS);
+
+    // Release MKL resources.
+    mkl_context.MklCleanup();
+  }
+
+ private:
+  typedef struct {
+    size_t in_dims;
+    size_t in_sizes[4];
+    size_t in_strides[4];
+    size_t out_sizes[4];
+    size_t out_strides[4];
+    MklShape input_shape;
+    dnnPrimitive_t lrn_fwd = nullptr;
+    dnnPrimitive_t convert_input = nullptr;
+    /* dnnPrimitive_t convert_output; */
+    dnnLayout_t lt_input = nullptr;
+    /* dnnLayout_t lt_output; */
+    dnnLayout_t lt_internal_input = nullptr;
+    dnnLayout_t lt_internal_workspace = nullptr;
+    dnnLayout_t lt_internal_output = nullptr;
+    void* lrn_res[dnnResourceNumber];
+
+    // Convert Inputs if needed
+    void MklPrepareLRNInputs(OpKernelContext* context,
+                             Tensor* mkl_tmp_input_buf_tensor) {
+      const Tensor& input = MklGetInput(context, 0);
+      void* mkl_buf_input =
+          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
+
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_input, lrn_fwd,
+                                                dnnResourceSrc),
+               E_SUCCESS);
+
+      void* mkl_buf_convert_input = nullptr;
+      bool mkl_convert_input = false;
+      mkl_convert_input = !dnnLayoutCompare_F32(lt_internal_input, lt_input);
+
+      if (mkl_convert_input) {
+        CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input,
+                                         lt_internal_input),
+                 E_SUCCESS);
+        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_internal_input,
+                       &mkl_buf_convert_input);
+        CHECK_EQ(dnnConversionExecute_F32(convert_input, mkl_buf_input,
+                                          mkl_buf_convert_input),
+                 E_SUCCESS);
+        dnnDelete_F32(convert_input);
+      }
+
+      lrn_res[dnnResourceSrc] =
+          (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
+    }
+
+    // Allocate Layer Outputs
+    void MklAllocateOutputs(OpKernelContext* context, Tensor** output,
+                            Tensor** workspace, bool workspace_enabled_) {
+      TensorShape mkl_output_tf_shape; /* First tensor */
+      MklShape mkl_output_mkl_shape;   /* Second tensor */
+
+      mkl_output_mkl_shape.SetMklTensor(true);
+      mkl_output_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceDst);
+      mkl_output_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
+                                       input_shape.GetStrides());
+      mkl_output_mkl_shape.SetTfDimOrder(in_dims,
+                                         input_shape.GetTfToMklDimMap());
+      mkl_output_tf_shape.AddDim(
+          dnnLayoutGetMemorySize_F32(
+              static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
+          sizeof(T));
+      AllocateOutputSetMklShape(context, 0, output,
+                                mkl_output_tf_shape /* First tensor */,
+                                mkl_output_mkl_shape /* Second Tensor */);
+
+      if (workspace_enabled_) {
+        TensorShape mkl_workspace_tf_shape; /* First tensor */
+        MklShape mkl_workspace_mkl_shape;   /* Second tensor */
+        mkl_workspace_mkl_shape.SetMklTensor(false);
+        mkl_workspace_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceWorkspace);
+        // Assumes workspace has same TF layout and TF dim order as input
+        mkl_workspace_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
+                                            input_shape.GetStrides());
+        mkl_workspace_mkl_shape.SetTfDimOrder(in_dims,
+                                              input_shape.GetTfToMklDimMap());
+        mkl_workspace_tf_shape.AddDim(
+            dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+                mkl_workspace_mkl_shape.GetMklLayout())) /
+            sizeof(T));
+        AllocateOutputSetMklShape(context, 1, workspace,
+                                  mkl_workspace_tf_shape /* First tensor */,
+                                  mkl_workspace_mkl_shape /* Second Tensor */);
+      }
+    }
+
+    void MklPrepareLRNOutputs(OpKernelContext* context, Tensor* output,
+                              Tensor* workspace,
+                              Tensor* mkl_tmp_workspace_buf_tensor,
+                              bool workspace_enabled_) {
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_workspace, lrn_fwd,
+                                                dnnResourceWorkspace),
+               E_SUCCESS);
+
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_output, lrn_fwd,
+                                                dnnResourceDst),
+               E_SUCCESS);
+
+      void* mkl_buf_output =
+          const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
+      lrn_res[dnnResourceDst] = mkl_buf_output;
+
+      void* mkl_buf_workspace = nullptr;
+      if (workspace_enabled_) {
+        mkl_buf_workspace = const_cast<void*>(
+            static_cast<const void*>(workspace->flat<T>().data()));
+      } else {
+        AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor,
+                       lt_internal_workspace, &mkl_buf_workspace);
+      }
+      lrn_res[dnnResourceWorkspace] = mkl_buf_workspace;
+    }
+
+    // Fallback implementation - Taken from lrn_op.cc
+    // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+    // copy.
+    void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
+                           float bias_, float alpha_, float beta_,
+                           const Tensor& input) {
+      const int batch = static_cast<int>(input.dim_size(0));
+      const int rows = static_cast<int>(input.dim_size(1));
+      const int cols = static_cast<int>(input.dim_size(2));
+      const int depth = static_cast<int>(input.dim_size(3));
+      const int nodes = cols * rows;
+
+      auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
+      // Multiplying the input with the band matrix has the effect of reducing
+      // the
+      // correct patch along the depth.
+      Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
+      GetBandMatrix<T>(depth, depth_radius_, &multiplier);
+
+      Tensor *output, *workspace;
+      MklShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
+      mkl_output_mkl_shape.SetMklTensor(false);
+      mkl_output_mkl_shape.SetDimensions(4);
+      AllocateOutputSetMklShape(context, 0, &output, input.shape(),
+                                mkl_output_mkl_shape);
+
+      mkl_workspace_mkl_shape.SetMklTensor(false);
+      mkl_workspace_mkl_shape.SetDimensions(4);
+      AllocateOutputSetMklShape(context, 1, &workspace, input.shape(),
+                                mkl_workspace_mkl_shape);
+
+      auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
+      Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+      auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
+      if (beta_ == T(1)) {
+        out_shaped.device(context->eigen_cpu_device()) =
+            in_shaped * tmp.inverse();
+      } else if (beta_ == T(0.5)) {
+        out_shaped.device(context->eigen_cpu_device()) =
+            in_shaped * tmp.rsqrt();
+      } else {
+        out_shaped.device(context->eigen_cpu_device()) =
+            in_shaped * (tmp.log() * -beta_).exp();
+      }
+    }
+
+    // Release MKL resources.
+    void MklCleanup() {
+      dnnDelete_F32(lrn_fwd);
+      dnnLayoutDelete_F32(lt_internal_input);
+      dnnLayoutDelete_F32(lt_internal_workspace);
+      dnnLayoutDelete_F32(lt_internal_output);
+    }
+  } MklLRNOpContext;
+
+  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+
+  bool workspace_enabled_;
+  int depth_radius_;
+  float bias_;
+  float alpha_;
+  float beta_;
+};
+
+template <typename T>
+class MklLRNGradOp : public OpKernel {
+ public:
+  explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+    int64 depth_radius64;
+    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
+        errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                " larger than int max"));
+    depth_radius_ = static_cast<int>(depth_radius64);
+    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+    workspace_enabled_ = false;
+    context->GetAttr("workspace_enabled", &workspace_enabled_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    MklLRNGradOpContext mkl_context;
+    mkl_context.depth_radius_ = depth_radius_;
+    mkl_context.bias_ = bias_;
+    mkl_context.alpha_ = alpha_;
+    mkl_context.beta_ = beta_;
+
+    const Tensor& in_grads = MklGetInput(context, 0);
+    const Tensor& in_image = MklGetInput(context, 1);
+    const Tensor& out_image = MklGetInput(context, 2);
+
+    GetMklShape(context, 0, &mkl_context.ingrad_shape);
+    GetMklShape(context, 1, &mkl_context.inimage_shape);
+    GetMklShape(context, 2, &mkl_context.outimage_shape);
+
+    bool ingrad_in_mkl_format = mkl_context.ingrad_shape.IsMklTensor();
+    bool inimage_in_mkl_format = mkl_context.inimage_shape.IsMklTensor();
+    bool outimage_in_mkl_format = mkl_context.outimage_shape.IsMklTensor();
+
+    mkl_context.in_dims = inimage_in_mkl_format
+                              ? mkl_context.inimage_shape.GetDimension()
+                              : in_image.dims();
+    OP_REQUIRES(context, mkl_context.in_dims == 4,
+                errors::InvalidArgument("input images must be 4-dimensional"));
+
+    if (!workspace_enabled_) {
+      mkl_context.MklDefaultToEigen(context);
+      return;
+    }
+    if (ingrad_in_mkl_format || inimage_in_mkl_format) {
+      const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
+                                          ? &mkl_context.ingrad_shape
+                                          : &mkl_context.inimage_shape;
+      if (tmp_mkl_shape->tf_dim_idx(mkl_context.in_dims - 1) != MklDims::C) {
+        // Fallback to eigen
+        mkl_context.MklDefaultToEigen(context);
+        return;
+      } else {  // MKL supports normalization over channel dimension only
+        for (int i = 0; i < mkl_context.in_dims; i++) {
+          mkl_context.in_sizes[i] = mkl_context.out_sizes[i] =
+              tmp_mkl_shape->GetSizes()[i];
+          mkl_context.in_strides[i] = mkl_context.out_strides[i] =
+              tmp_mkl_shape->GetStrides()[i];
+        }
+      }
+    } else {
+      // Fallback to eigen
+      mkl_context.MklDefaultToEigen(context);
+      return;
+    }
+
+    // Dimensions check for sanity purpose
+    if (ingrad_in_mkl_format) {
+      OP_REQUIRES(
+          context, mkl_context.ingrad_shape.GetDimension() == 4,
+          errors::InvalidArgument("input gradient must be 4-dimensional"));
+    } else {
+      OP_REQUIRES(
+          context, in_grads.dims() == 4,
+          errors::InvalidArgument("input gradient must be 4-dimensional"));
+    }
+
+    if (outimage_in_mkl_format) {
+      OP_REQUIRES(
+          context, mkl_context.outimage_shape.GetDimension() == 4,
+          errors::InvalidArgument("Output image must be 4-dimensional"));
+    } else {
+      OP_REQUIRES(
+          context, out_image.dims() == 4,
+          errors::InvalidArgument("Output image must be 4-dimensional"));
+    }
+
+    // Prepare mkl input layout
+    mkl_context.MklPrepareLRNInputsLayouts(context);
+    int ksize = 2 * depth_radius_ + 1;
+
+    CHECK_EQ(dnnLRNCreateBackward_F32(
+                 &mkl_context.lrn_bwd, NULL, mkl_context.lt_input,
+                 mkl_context.lt_output, ksize,
+                 static_cast<float>(alpha_ * ksize), beta_, bias_),
+             E_SUCCESS);
+
+    // Allocate output tensor and shape.
+    TensorShape mkl_output_tf_shape; /* First tensor */
+    MklShape mkl_output_mkl_shape;   /* Second tensor */
+    mkl_output_mkl_shape.SetMklTensor(true);
+    CHECK_NE(mkl_context.lrn_bwd, nullptr);
+    mkl_output_mkl_shape.SetMklLayout(mkl_context.lrn_bwd, dnnResourceDiffSrc);
+    mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
+                                     mkl_context.out_strides);
+    if (ingrad_in_mkl_format) {
+      mkl_output_mkl_shape.SetTfDimOrder(
+          mkl_context.in_dims, mkl_context.ingrad_shape.GetTfToMklDimMap());
+    } else {
+      mkl_output_mkl_shape.SetTfDimOrder(
+          mkl_context.in_dims, mkl_context.inimage_shape.GetTfToMklDimMap());
+    }
+    mkl_output_tf_shape.AddDim(
+        dnnLayoutGetMemorySize_F32(
+            static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
+        sizeof(T));
+    Tensor* output = nullptr;
+    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
+                              mkl_output_mkl_shape);
+
+    // Get pointers to output data.
+    void* user_output =
+        const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
+
+    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor,
+        mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor;
+    // Convert Inputs if needed
+    mkl_context.MklPrepareLRNGradInput(
+        context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor,
+        &mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor);
+
+    // We do not do any conversion for output. But we simply emit it
+    // in MKL format.
+    mkl_context.res_lrn_bwd[dnnResourceDiffSrc] = user_output;
+    // Execute LRN backward using dnnExecute
+    CHECK_EQ(dnnExecute_F32(mkl_context.lrn_bwd, mkl_context.res_lrn_bwd),
+             E_SUCCESS);
+    // Release MKL resources.
+    mkl_context.Mklcleanup();
+  }
+
+ private:
+  typedef struct {
+    int depth_radius_;
+    float bias_;
+    float alpha_;
+    float beta_;
+    size_t in_dims;
+    size_t in_sizes[4];
+    size_t in_strides[4];
+    size_t out_sizes[4];
+    size_t out_strides[4];
+    MklShape ingrad_shape, inimage_shape, outimage_shape;
+    dnnPrimitive_t lrn_bwd = nullptr;
+    dnnPrimitive_t convert_input = nullptr;
+    /* dnnPrimitive_t convert_output; */
+    dnnLayout_t lt_input = nullptr;
+    dnnLayout_t lt_output = nullptr;
+    dnnLayout_t lt_bdw_input = nullptr;
+    dnnLayout_t lt_workspace = nullptr;
+    dnnLayout_t lt_internal_input = nullptr;
+    /* dnnLayout_t lt_internal_workspace;
+    dnnLayout_t lt_internal_output; */
+    void* res_lrn_bwd[dnnResourceNumber];
+
+    // prepare mkl input
+    void MklPrepareLRNInputsLayouts(OpKernelContext* context) {
+      bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
+      bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
+      if (!ingrad_in_mkl_format) {
+        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
+                 E_SUCCESS);
+      } else {
+        lt_input = static_cast<dnnLayout_t>(ingrad_shape.GetCurLayout());
+      }
+
+      if (!inimage_in_mkl_format) {
+        CHECK_EQ(
+            dnnLayoutCreate_F32(&lt_output, in_dims, out_sizes, out_strides),
+            E_SUCCESS);
+      } else {
+        lt_output = static_cast<dnnLayout_t>(inimage_shape.GetCurLayout());
+      }
+    }
+
+    // convert input if needed
+    void MklPrepareLRNGradInput(OpKernelContext* context,
+                                Tensor* mkl_tmp_input_buf_tensor,
+                                Tensor* mkl_tmp_image_buf_tensor,
+                                Tensor* mkl_tmp_outimage_buf_tensor,
+                                Tensor* mkl_tmp_workspace_buf_tensor) {
+      const Tensor& in_grads = MklGetInput(context, 0);
+      const Tensor& in_image = MklGetInput(context, 1);
+      const Tensor& out_image = MklGetInput(context, 2);
+
+      void* user_input = const_cast<void*>(
+          static_cast<const void*>(in_grads.flat<T>().data()));
+      void* user_fwd_input = const_cast<void*>(
+          static_cast<const void*>(in_image.flat<T>().data()));
+      void* user_fwd_output = const_cast<void*>(
+          static_cast<const void*>(out_image.flat<T>().data()));
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, lrn_bwd,
+                                                dnnResourceWorkspace),
+               E_SUCCESS);
+      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_bdw_input, lrn_bwd,
+                                                dnnResourceDiffDst),
+               E_SUCCESS);
+
+      bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
+      if (ingrad_in_mkl_format) {
+        if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
+          AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
+                         &res_lrn_bwd[dnnResourceDiffDst]);
+          ingrad_shape.GetConvertedFlatData(lt_bdw_input, user_input,
+                                            res_lrn_bwd[dnnResourceDiffDst]);
+        } else {
+          res_lrn_bwd[dnnResourceDiffDst] = user_input;
+        }
+      } else {
+        if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
+          CHECK_EQ(
+              dnnConversionCreate_F32(&convert_input, lt_input, lt_bdw_input),
+              E_SUCCESS);
+
+          AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
+                         &res_lrn_bwd[dnnResourceDiffDst]);
+          CHECK_EQ(dnnConversionExecute_F32(convert_input, user_input,
+                                            res_lrn_bwd[dnnResourceDiffDst]),
+                   E_SUCCESS);
+          dnnDelete_F32(convert_input);
+        } else {
+          res_lrn_bwd[dnnResourceDiffDst] = user_input;
+        }
+      }
+
+// Although MKL documentation for LRN does not specify setting/getting
+// of dnnResourceSrc and dnnResourceDst, Caffe code sets dnnResourceSrc.
+// So we set dnnResourceSrc here. But we do not know why we are setting
+// dnnResourceDst.
+#if 0
+    // NOTE: The code below is kept just so that we know how we should handle
+    // dnnResourceSrc if the primitive layout for dnnResourceSrc was supported.
+
+    if (!dnnLayoutCompare_F32(lt_internal_input,
+         static_cast<dnnLayout_t>inimage_shape.GetCurLayout())) {
+      AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
+                     &res_lrn_bwd[dnnResourceSrc]);
+      inimage_shape.GetConvertedFlatData(lt_internal_input,
+                                           user_fwd_input,
+                                           res_lrn_bwd[dnnResourceSrc]);
+    } else {
+      res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+    }
+#endif
+
+      // Since we cannot get expected layout for dnnResourceSrc, we construct
+      // buffer using
+      // MKL format if input is in MKL format.
+      if (inimage_shape.IsMklTensor()) {
+        AllocTmpBuffer(context, mkl_tmp_image_buf_tensor,
+                       (dnnLayout_t)inimage_shape.GetCurLayout(),
+                       &res_lrn_bwd[dnnResourceSrc]);
+      } else {
+        res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+      }
+
+      // Same comment as above.
+      if (outimage_shape.IsMklTensor()) {
+        AllocTmpBuffer(context, mkl_tmp_outimage_buf_tensor,
+                       (dnnLayout_t)outimage_shape.GetCurLayout(),
+                       &res_lrn_bwd[dnnResourceDst]);
+      } else {
+        res_lrn_bwd[dnnResourceDst] = user_fwd_output;
+      }
+
+      // Allocate buffer for workspace.
+      AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace,
+                     &res_lrn_bwd[dnnResourceWorkspace]);
+    }
+
+    // Fallback implementation - Taken from lrn_op.cc
+    // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+    // copy.
+    void MklDefaultToEigen(OpKernelContext* context) {
+      // CHECK(false);
+      Tensor in_grads = MklGetInput(context, 0);
+      Tensor in_image = MklGetInput(context, 1);
+      Tensor out_image = MklGetInput(context, 2);
+
+      GetMklShape(context, 0, &ingrad_shape);
+      GetMklShape(context, 1, &inimage_shape);
+      GetMklShape(context, 2, &outimage_shape);
+
+      const int64 batch = static_cast<int64>(in_grads.dim_size(0));
+      const int64 rows = static_cast<int64>(in_grads.dim_size(1));
+      const int64 cols = static_cast<int64>(in_grads.dim_size(2));
+      const int64 depth = static_cast<int64>(in_grads.dim_size(3));
+      const auto nodes = cols * rows;
+
+      auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
+      auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
+      auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
+
+      Tensor* output;
+      MklShape mkl_output_mkl_shape;
+      mkl_output_mkl_shape.SetMklTensor(false);
+      mkl_output_mkl_shape.SetDimensions(4);
+      AllocateOutputSetMklShape(context, 0, &output, in_grads.shape(),
+                                mkl_output_mkl_shape);
+
+      auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
+      out_shaped.setZero();
+      auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
+                    depth](int64 begin, int64 end) {
+        for (int64 i = begin; i < end; ++i) {
+          for (int64 j = 0; j < depth; ++j) {
+            int64 depth_begin = std::max<int64>(0, j - depth_radius_);
+            int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
+
+            T norm(0);
+            for (int64 k = depth_begin; k < depth_end; ++k) {
+              norm += in_shaped(i, k) * in_shaped(i, k);
+            }
+            norm = alpha_ * norm + bias_;
+            DCHECK_GT(norm, T(1e-6));
+            for (int64 k = depth_begin; k < depth_end; ++k) {
+              T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
+                      activations(i, j) / norm;
+              if (k == j) {
+                dyi += Eigen::numext::pow(norm, -beta_);
+              }
+              dyi *= grads_shaped(i, j);
+              const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
+                  dyi;
+            }
+          }
+        }
+      };
+      auto worker_threads =
+          *(context->device()->tensorflow_cpu_worker_threads());
+      Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
+            depth * depth, shard);
+    }
+
+    // release mkl resources
+    void Mklcleanup() {
+      bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
+      bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
+      if (!ingrad_in_mkl_format) {
+        CHECK_EQ(dnnLayoutDelete_F32(lt_input), E_SUCCESS);
+      }
+
+      if (!inimage_in_mkl_format) {
+        CHECK_EQ(dnnLayoutDelete_F32(lt_output), E_SUCCESS);
+      }
+      dnnDelete_F32(lrn_bwd);
+      dnnLayoutDelete_F32(lt_bdw_input);
+      dnnLayoutDelete_F32(lt_workspace);
+    }
+  } MklLRNGradOpContext;
+
+  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+  bool workspace_enabled_;
+  int depth_radius_;
+  float bias_;
+  float alpha_;
+  float beta_;
+};
+
+#define REGISTER_MKL_LRN_CPU(T)                                     \
+  REGISTER_KERNEL_BUILDER(Name("_MklLRN")                            \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklLRNOp<T>);                             \
+  REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad")                        \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklLRNGradOp<T>);
+
+TF_CALL_float(REGISTER_MKL_LRN_CPU);
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index 9d6cfb0c97d..e27881f882d 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -83,10 +83,11 @@ class MklMaxPoolingOp : public OpKernel {
     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
 
     mkl_context.MklCreateLayoutsAndPrimitives(context);
+    OP_REQUIRES_OK(context, context->status());
 
     // Declare output tensor
     TensorShape tensor_out_shape;
-    MklShape mkl_out_shape;
+    MklShape mkl_out_shape, mkl_workspace_shape;
     mkl_out_shape.SetMklTensor(true);
     mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
     mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
@@ -98,31 +99,22 @@ class MklMaxPoolingOp : public OpKernel {
     tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                                 mkl_out_shape.GetMklLayout())) /
                             sizeof(T));
-    AllocateOutputSetMklshape(context, 0, &output_tensor, tensor_out_shape,
+    AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
                               mkl_out_shape);
 
-    if (!workspace_enabled_) {
-      mkl_out_shape.SetMklTensor(false);
-    }
-
     Tensor* workspace_tensor;
     void* workspace_buf = nullptr;
-    if (workspace_enabled_) {
-      TensorShape workspace_shape;
-      workspace_shape.AddDim(
-          dnnLayoutGetMemorySize_F32(
-              static_cast<dnnLayout_t>(mkl_context.lt_workspace)) /
-          sizeof(T));
-      AllocateOutputSetMklshape(context, 1, &workspace_tensor, workspace_shape,
-                                mkl_out_shape);
-      mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
-          static_cast<const void*>(workspace_tensor->flat<T>().data()));
-    } else {
-      AllocTmpBuffer(context, workspace_tensor, mkl_context.lt_workspace,
-                     &workspace_buf);
-      mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
-    }
 
+    TensorShape workspace_shape;
+    mkl_workspace_shape.SetMklTensor(false);
+    workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
+                               mkl_context.lt_workspace)) /
+                           sizeof(T));
+    AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
+                              mkl_workspace_shape);
+
+    mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
+        static_cast<const void*>(workspace_tensor->flat<T>().data()));
     mkl_context.pooling_res[dnnResourceSrc] =
         const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
     mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
@@ -140,8 +132,8 @@ class MklMaxPoolingOp : public OpKernel {
     MklPoolingOpParams params;
     MklShape input_shape;
     void* pooling_res[dnnResourceNumber];
-    dnnPrimitive_t prim_pooling_fwd;
-    dnnLayout_t lt_user_input, lt_workspace;
+    dnnPrimitive_t prim_pooling_fwd = nullptr;
+    dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
 
     void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
       bool input_in_mkl_format = input_shape.IsMklTensor();
@@ -256,8 +248,13 @@ class MklMaxPoolingGradOp : public OpKernel {
     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
 
     mkl_context.MklCreateLayouts(context);
+    OP_REQUIRES_OK(context, context->status());
+
     mkl_context.MklCreatePrimitives(context, workspace_enabled_);
+    OP_REQUIRES_OK(context, context->status());
+
     mkl_context.MklPrepareInputs(context, workspace_enabled_);
+    OP_REQUIRES_OK(context, context->status());
 
     // Create shape for the input back prop output
     TensorShape mkl_input_backprop;
@@ -274,7 +271,7 @@ class MklMaxPoolingGradOp : public OpKernel {
         dnnLayoutGetMemorySize_F32(
             static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
         sizeof(T));
-    AllocateOutputSetMklshape(context, 0, &output_tensor, mkl_input_backprop,
+    AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
                               mkl_output_shape);
     mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
         static_cast<const void*>(output_tensor->flat<T>().data()));
@@ -297,12 +294,15 @@ class MklMaxPoolingGradOp : public OpKernel {
     MklShape input_shape, output_backprop_shape;
     void* pooling_resfwd[dnnResourceNumber];
     void* pooling_res[dnnResourceNumber];
-    dnnPrimitive_t prim_pooling_fwd, prim_pooling_bwd, convert_input,
-        convert_outbackprop;
-    dnnLayout_t lt_outbackprop_user, lt_outbackprop_prim, lt_input_user,
-        lt_input_prim;
+    dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
+                   convert_input = nullptr, convert_outbackprop = nullptr;
+    dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
+                lt_input_user = nullptr, lt_input_prim = nullptr;
     void* input_buf;
     void* outbackprop_buf;
+    Tensor tmp_output_buf_tensor;
+    Tensor workspace_buf_tensor;
+    Tensor input_buf_tensor, outbackprop_buf_tensor;
 
     void MklCreateLayouts(OpKernelContext* context) {
       bool input_in_mkl_format = input_shape.IsMklTensor();
@@ -351,9 +351,6 @@ class MklMaxPoolingGradOp : public OpKernel {
                    &lt_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
                E_SUCCESS);
 
-      // Tensors needed to create temporary buffers
-      Tensor input_buf_tensor, outbackprop_buf_tensor;
-
       if (workspace_enabled == false) {
         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                      &lt_input_prim, prim_pooling_fwd, dnnResourceSrc),
@@ -384,11 +381,8 @@ class MklMaxPoolingGradOp : public OpKernel {
       bool input_in_mkl_format = input_shape.IsMklTensor();
       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
 
-      void* tmp_output_buf;
-      Tensor tmp_output_buf_tensor;
-
-      void* workspace_buf;
-      Tensor workspace_buf_tensor;
+      void* tmp_output_buf = nullptr;
+      void* workspace_buf = nullptr;
 
       if (workspace_enabled == false) {
         if (convert_input != nullptr) {
@@ -490,16 +484,16 @@ class MklMaxPoolingGradOp : public OpKernel {
   bool workspace_enabled_;
 };
 
-REGISTER_KERNEL_BUILDER(Name("MklMaxPool")
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
                             .Device(DEVICE_CPU)
                             .TypeConstraint<float>("T")
-                            .Label(mkl_layer_registry::kMklLayerLabel),
+                            .Label(mkl_op_registry::kMklOpLabel),
                         MklMaxPoolingOp<CPUDevice, float>);
 
-REGISTER_KERNEL_BUILDER(Name("MklMaxPoolGrad")
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
                             .Device(DEVICE_CPU)
                             .TypeConstraint<float>("T")
-                            .Label(mkl_layer_registry::kMklLayerLabel),
+                            .Label(mkl_op_registry::kMklOpLabel),
                         MklMaxPoolingGradOp<CPUDevice, float>);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 7809711524c..25c8359cc53 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -63,7 +63,7 @@ class MklReluOp : public OpKernel {
       const TensorShape& o_shape = input.shape();
       Tensor* out_tensor = nullptr;
       mkl_context.output_shape.SetMklTensor(false);
-      AllocateOutputSetMklshape(context, 0, &out_tensor, o_shape,
+      AllocateOutputSetMklShape(context, 0, &out_tensor, o_shape,
                                 mkl_context.output_shape);
       void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
       (static_cast<T*>(out_o))[0] =
@@ -114,12 +114,12 @@ class MklReluOp : public OpKernel {
       tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                           mkl_context.output_shape.GetMklLayout())) /
                       sizeof(T));
-      AllocateOutputSetMklshape(context, 0, &output, tf_shape,
+      AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                                 mkl_context.output_shape);
     } else {
       const TensorShape& o_shape = input.shape();
       mkl_context.output_shape.SetMklTensor(false);
-      AllocateOutputSetMklshape(context, 0, &output, o_shape,
+      AllocateOutputSetMklShape(context, 0, &output, o_shape,
                                 mkl_context.output_shape);
     }
 
@@ -293,7 +293,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
     // Allocate space for g and
     const TensorShape& g_shape = g.shape();
     mkl_context.output_shape.SetMklTensor(false);
-    AllocateOutputSetMklshape(context, 0, &output, g_shape,
+    AllocateOutputSetMklShape(context, 0, &output, g_shape,
                               mkl_context.output_shape);
     void* out_o = static_cast<void*>(output->flat<T>().data());
     (static_cast<T*>(out_o))[0] =
@@ -359,13 +359,13 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
     tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                         mkl_context.output_shape.GetMklLayout())) /
                     sizeof(T));
-    AllocateOutputSetMklshape(context, 0, &output, tf_shape,
+    AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                               mkl_context.output_shape);
 
   } else {
     const TensorShape& o_shape = g.shape();
     mkl_context.output_shape.SetMklTensor(false);
-    AllocateOutputSetMklshape(context, 0, &output, o_shape,
+    AllocateOutputSetMklShape(context, 0, &output, o_shape,
                               mkl_context.output_shape);
   }
 
@@ -379,16 +379,16 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
 
 /* Register DNN kernels for supported operations and supported types - right now
  * it is only Relu and f32*/
-#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)                   \
-  REGISTER_KERNEL_BUILDER(Name("MklRelu")                                 \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<type>("T")                  \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
-                          MklReluOp<CPUDevice, type>);                    \
-  REGISTER_KERNEL_BUILDER(Name("MklReluGrad")                             \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<type>("T")                  \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)             \
+  REGISTER_KERNEL_BUILDER(Name("_MklRelu")                           \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklReluOp<CPUDevice, type>);              \
+  REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")                       \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<type>("T")            \
+                              .Label(mkl_op_registry::kMklOpLabel), \
                           MklReluGradOp<CPUDevice, type>);
 TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
 
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
new file mode 100644
index 00000000000..753a8b52b42
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -0,0 +1,149 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include <memory>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
+
+namespace tensorflow {
+using CPUDevice = Eigen::ThreadPoolDevice;
+template <typename Device, typename T>
+class MklReshapeOp : public OpKernel {
+ public:
+  explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input = MklGetInput(context, 0);
+    const Tensor& sizes = MklGetInput(context, 1);
+
+    // Preliminary validation of sizes.
+    OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
+                errors::InvalidArgument("sizes input must be 1-D, not shape ",
+                                        sizes.shape().DebugString()));
+    const int64 num_dims = sizes.NumElements();
+
+    // Compute the output shape.  Determine product of specified
+    // dimensions, and find the index of the unspecified one.
+    TensorShape shape;
+    int64 product = 1;
+    int unknown_index = -1;
+    auto vec_size = sizes.flat<int32>();
+    for (int d = 0; d < num_dims; ++d) {
+      const int32 size = vec_size(d);
+      if (size == -1) {
+        OP_REQUIRES(
+            context, unknown_index == -1,
+            errors::InvalidArgument("only one input size may be -1, not both ",
+                                    unknown_index, " and ", d));
+        unknown_index = d;
+        shape.AddDim(1);
+      } else {
+        OP_REQUIRES(context, size >= 0,
+                    errors::InvalidArgument(
+                        "size ", d, " must be non-negative, not ", size));
+        shape.AddDim(size);
+        product *= size;
+      }
+    }
+    if (unknown_index != -1) {
+      OP_REQUIRES(
+          context, product > 0,
+          errors::InvalidArgument("Reshape cannot infer the missing input size "
+                                  "for an empty tensor unless all specified "
+                                  "input sizes are non-zero"));
+      const int64 missing = input.NumElements() / product;
+      OP_REQUIRES(
+          context, product * missing == input.NumElements(),
+          errors::InvalidArgument(
+              "Input to reshape is a tensor with ", input.NumElements(),
+              " values, but the requested shape requires a multiple of ",
+              product));
+      shape.set_dim(unknown_index, missing);
+    }
+    OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
+                errors::InvalidArgument("Input to reshape is a tensor with ",
+                                        input.NumElements(),
+                                        " values, but the requested shape has ",
+                                        shape.num_elements()));
+
+    MklShape mkl_shape_input;
+    GetMklShape(context, 0, &mkl_shape_input);
+    bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
+    if (input_in_mkl_format) {
+      TensorShape& shape_to = shape;
+      TensorShape shape_from;
+      for (size_t i = 0; i < mkl_shape_input.GetDimension(); i++) {
+        // Outermost to innermost dimension
+        shape_from.AddDim(
+            mkl_shape_input.GetSizes()[mkl_shape_input.tf_dim_idx(i)]);
+      }
+
+      if (shape_from == shape_to) {
+        CopyMklTensorInToOut(context, 0, 0);
+        return;
+      } else {
+        // Allocate output tensor.
+        Tensor* output_tensor = NULL;
+        MklShape mkl_shape_output;
+        mkl_shape_output.SetMklTensor(false);
+        AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+                                  mkl_shape_output);
+
+        // Get output layout pointer.
+        dnnLayout_t output_layout =
+            static_cast<dnnLayout_t>(mkl_shape_input.GetTfLayout());
+
+        // Execute DNNConversion.
+        // Note: we  assume an MKL tensor always have float as its data type.
+        void* input_buffer =
+            static_cast<void*>(const_cast<float*>(input.flat<float>().data()));
+        void* output_buffer = static_cast<void*>(
+            const_cast<float*>(output_tensor->flat<float>().data()));
+        mkl_shape_input.GetConvertedFlatData(output_layout, input_buffer,
+                                             output_buffer);
+
+        VLOG(1) << "MKLToTFConversion complete successfully.";
+        return;
+      }
+    } else {
+      CopyTFTensorInToOut(context, 0, 0, shape);
+    }
+  }
+};
+
+#define REGISTER_MKL_CPU(T)                                         \
+  REGISTER_KERNEL_BUILDER(Name("_MklReshape")                       \
+                              .Device(DEVICE_CPU)                   \
+                              .HostMemory("shape")                  \
+                              .TypeConstraint<T>("T")               \
+                              .TypeConstraint<int32>("Tshape")      \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklReshapeOp<CPUDevice, T>);
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc
index 51f90b3f901..c31ef5c2554 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.cc
+++ b/tensorflow/core/kernels/mkl_tfconv_op.cc
@@ -105,11 +105,11 @@ class MklToTfOp : public OpKernel {
 //               Register kernel
 ///////////////////////////////////////////////////////////
 
-#define REGISTER_CPU(T)                                                   \
-  REGISTER_KERNEL_BUILDER(Name("MklToTf")                                 \
-                              .Device(DEVICE_CPU)                         \
-                              .TypeConstraint<T>("T")                     \
-                              .Label(mkl_layer_registry::kMklLayerLabel), \
+#define REGISTER_CPU(T)                                             \
+  REGISTER_KERNEL_BUILDER(Name("MklToTf")                           \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
                           MklToTfOp<CPUDevice, T>);
 
 TF_CALL_float(REGISTER_CPU);
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index afa1f65aefa..56a7e161df4 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -233,9 +233,9 @@ class Im2ColConvFunctor {
     int filter_top_offset;
     if (padding == VALID) {
       filter_left_offset =
-          ((output_width - 1) * stride + filter_width - input_width) / 2;
+          ((output_width - 1) * stride + filter_width - input_width + 1) / 2;
       filter_top_offset =
-          ((output_height - 1) * stride + filter_height - input_height) / 2;
+          ((output_height - 1) * stride + filter_height - input_height + 1) / 2;
     } else {
       filter_left_offset =
           ((output_width - 1) * stride + filter_width - input_width) / 2;
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
index 2bcc7f407d4..30026f222a6 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
@@ -29,7 +29,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename Tindices>
 class SparseTensorDenseMatMulOp : public OpKernel {
  public:
   explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx)
@@ -139,15 +139,14 @@ class SparseTensorDenseMatMulOp : public OpKernel {
                                              TensorShape({0}), &scratch));
     }
 
-#define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                           \
-  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                           \
-    Status functor_status = functor::SparseTensorDenseMatMulFunctor<          \
-        Device, T, ADJ_A, ADJ_B>::Compute(ctx->eigen_device<Device>(),        \
-                                          out->matrix<T>(),                   \
-                                          a_indices->matrix<int64>(),         \
-                                          a_values->vec<T>(), b->matrix<T>(), \
-                                          scratch.vec<T>());                  \
-    OP_REQUIRES_OK(ctx, functor_status);                                      \
+#define MAYBE_ADJOINT(ADJ_A, ADJ_B)                                        \
+  if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                        \
+    Status functor_status = functor::SparseTensorDenseMatMulFunctor<       \
+        Device, T, Tindices, ADJ_A,                                        \
+        ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(),     \
+                        a_indices->matrix<Tindices>(), a_values->vec<T>(), \
+                        b->matrix<T>(), scratch.vec<T>());                 \
+    OP_REQUIRES_OK(ctx, functor_status);                                   \
   }
 
     MAYBE_ADJOINT(false, false);
@@ -163,53 +162,73 @@ class SparseTensorDenseMatMulOp : public OpKernel {
   bool adjoint_b_;
 };
 
-#define REGISTER_CPU(T)                                   \
-  REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
-                              .Device(DEVICE_CPU)         \
-                              .TypeConstraint<T>("T")     \
-                              .HostMemory("a_shape"),     \
-                          SparseTensorDenseMatMulOp<CPUDevice, T>);
+#define REGISTER_CPU(TypeT, TypeIndex)           \
+  REGISTER_KERNEL_BUILDER(                       \
+      Name("SparseTensorDenseMatMul")            \
+          .Device(DEVICE_CPU)                    \
+          .TypeConstraint<TypeT>("T")            \
+          .TypeConstraint<TypeIndex>("Tindices") \
+          .HostMemory("a_shape"),                \
+      SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>);
 
-REGISTER_CPU(float);
-REGISTER_CPU(double);
-REGISTER_CPU(int32);
-REGISTER_CPU(complex64);
-REGISTER_CPU(complex128);
+#define REGISTER_KERNELS_CPU(T) \
+  REGISTER_CPU(T, int64);       \
+  REGISTER_CPU(T, int32)
+
+REGISTER_KERNELS_CPU(float);
+REGISTER_KERNELS_CPU(double);
+REGISTER_KERNELS_CPU(int32);
+REGISTER_KERNELS_CPU(complex64);
+REGISTER_KERNELS_CPU(complex128);
 
 #if GOOGLE_CUDA
 
 namespace functor {
-#define DECLARE_GPU_SPEC(T, ADJ_A, ADJ_B)                                     \
-  template <>                                                                 \
-  Status SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B>::Compute( \
-      const GPUDevice& d, typename TTypes<T>::Matrix out,                     \
-      TTypes<int64>::ConstMatrix a_indices,                                   \
-      typename TTypes<T>::ConstVec a_values,                                  \
-      typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);    \
-  extern template struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A,  \
-                                                        ADJ_B>;
+#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B)                       \
+  template <>                                                             \
+  Status SparseTensorDenseMatMulFunctor<                                  \
+      GPUDevice, T, Tindices, ADJ_A,                                      \
+      ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
+                      typename TTypes<Tindices>::ConstMatrix a_indices,   \
+                      typename TTypes<T>::ConstVec a_values,              \
+                      typename TTypes<T>::ConstMatrix b,                  \
+                      typename TTypes<T>::Vec scratch);                   \
+  extern template struct SparseTensorDenseMatMulFunctor<                  \
+      GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
 
-#define DECLARE_ADJOINT_GPU_SPEC(T) \
-  DECLARE_GPU_SPEC(T, false, false) \
-  DECLARE_GPU_SPEC(T, false, true)  \
-  DECLARE_GPU_SPEC(T, true, false)  \
-  DECLARE_GPU_SPEC(T, true, true)
+#define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B)  \
+  DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \
+  DECLARE_GPU_SPEC(T, int64, ADJ_A, ADJ_B)
+
+#define DECLARE_ADJOINT_GPU_SPEC(T)  \
+  REGISTER_GPU_SPEC(T, false, false) \
+  REGISTER_GPU_SPEC(T, false, true)  \
+  REGISTER_GPU_SPEC(T, true, false)  \
+  REGISTER_GPU_SPEC(T, true, true)
 
 DECLARE_ADJOINT_GPU_SPEC(float);
 #undef DECLARE_ADJOINT_GPU_SPEC
 #undef DECLARE_GPU_SPEC
+#undef REGISTER_GPU_SPEC
 
 }  // namespace functor
 
-#define REGISTER_GPU(T)                                   \
-  REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
-                              .Device(DEVICE_GPU)         \
-                              .TypeConstraint<T>("T")     \
-                              .HostMemory("a_shape"),     \
-                          SparseTensorDenseMatMulOp<GPUDevice, T>);
+#define REGISTER_GPU(TypeT, TypeIndex)           \
+  REGISTER_KERNEL_BUILDER(                       \
+      Name("SparseTensorDenseMatMul")            \
+          .Device(DEVICE_GPU)                    \
+          .TypeConstraint<TypeT>("T")            \
+          .TypeConstraint<TypeIndex>("Tindices") \
+          .HostMemory("a_shape"),                \
+      SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>);
 
-REGISTER_GPU(float);
+#define REGISTER_KERNELS_GPU(T) \
+  REGISTER_GPU(T, int64);       \
+  REGISTER_GPU(T, int32)
+
+REGISTER_KERNELS_GPU(float);
 #undef REGISTER_GPU
+#undef REGISTER_KERNELS_GPU
 #endif  // GOOGLE_CUDA
 
 namespace functor {
@@ -228,13 +247,13 @@ Status MOutOfBoundsError(int64 m, std::size_t i, int lhs_index_a,
 }
 }  // namespace
 
-template <typename T, bool ADJ_A, bool ADJ_B>
-struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   // Vectorize certain operations above this size.
   static const std::size_t kNumVectorize = 32;
 
   static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
-                        TTypes<int64>::ConstMatrix a_indices,
+                        typename TTypes<Tindices>::ConstMatrix a_indices,
                         typename TTypes<T>::ConstVec a_values,
                         typename TTypes<T>::ConstMatrix b,
                         typename TTypes<T>::Vec scratch) {
@@ -255,8 +274,8 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
       auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
 
       for (std::size_t i = 0; i < nnz; ++i) {
-        const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
-        const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
+        const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
+        const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
         if (!FastBoundsCheck(k, lhs_right)) {
           return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
         }
@@ -273,19 +292,19 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
       // Vectorization via Eigen.
       const int b_chip_index = ADJ_B ? 1 : 0;
 
-#define LOOP_NNZ(b_passed)                                               \
-  for (std::size_t i = 0; i < nnz; ++i) {                                \
-    const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
-    const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
-    const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);    \
-    if (!FastBoundsCheck(k, lhs_right)) {                                \
-      return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);            \
-    }                                                                    \
-    if (!FastBoundsCheck(m, out.dimension(0))) {                         \
-      return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));     \
-    }                                                                    \
-    out.template chip<0>(m) +=                                           \
-        b_passed.template chip<b_chip_index>(k) * a_value;               \
+#define LOOP_NNZ(b_passed)                                                  \
+  for (std::size_t i = 0; i < nnz; ++i) {                                   \
+    const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
+    const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
+    const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i);       \
+    if (!FastBoundsCheck(k, lhs_right)) {                                   \
+      return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);               \
+    }                                                                       \
+    if (!FastBoundsCheck(m, out.dimension(0))) {                            \
+      return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0));        \
+    }                                                                       \
+    out.template chip<0>(m) +=                                              \
+        b_passed.template chip<b_chip_index>(k) * a_value;                  \
   }
 
       if (ADJ_B) {
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index bcb836367b9..e707743f782 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -25,11 +25,12 @@ namespace tensorflow {
 
 namespace functor {
 
-template <typename Device, typename T, bool ADJ_A, bool ADJ_B>
+template <typename Device, typename T, typename Tindices, bool ADJ_A,
+          bool ADJ_B>
 struct SparseTensorDenseMatMulFunctor {
   static EIGEN_ALWAYS_INLINE Status
   Compute(const Device& d, typename TTypes<T>::Matrix out,
-          TTypes<int64>::ConstMatrix a_indices,
+          typename TTypes<Tindices>::ConstMatrix a_indices,
           typename TTypes<T>::ConstVec a_values,
           typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
 };
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
index 07d218311eb..7266e0cf812 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
@@ -27,12 +27,12 @@ typedef Eigen::GpuDevice GPUDevice;
 
 namespace generator {
 
-template <typename T, bool ADJ_A, bool ADJ_B>
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
 class SparseTensorDenseMatMulGPUGenerator {
  public:
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
       typename TTypes<T, 2>::Tensor32Bit out,
-      TTypes<const int64, 2>::Tensor32Bit a_indices,
+      typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
       typename TTypes<const T, 1>::Tensor32Bit a_values,
       typename TTypes<const T, 2>::Tensor32Bit b)
       : out_(out),
@@ -77,7 +77,7 @@ class SparseTensorDenseMatMulGPUGenerator {
   mutable typename TTypes<T, 2>::Tensor32Bit out_;
   const int lhs_index_a_;
   const int rhs_index_a_;
-  TTypes<const int64, 2>::Tensor32Bit a_indices_;
+  typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
   typename TTypes<const T, 1>::Tensor32Bit a_values_;
   const int lhs_right_size;
   functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
@@ -88,14 +88,14 @@ class SparseTensorDenseMatMulGPUGenerator {
 
 namespace functor {
 
-template <typename T, bool ADJ_A, bool ADJ_B>
-struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B> {
+template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   static EIGEN_ALWAYS_INLINE Status
   Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
-          TTypes<int64>::ConstMatrix a_indices,
+          typename TTypes<Tindices>::ConstMatrix a_indices,
           typename TTypes<T>::ConstVec a_values,
           typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
-    generator::SparseTensorDenseMatMulGPUGenerator<T, ADJ_A, ADJ_B>
+    generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
         sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
                                              To32Bit(a_values), To32Bit(b));
     To32Bit(out).device(d) = To32Bit(out).constant(T(0));
@@ -146,17 +146,18 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B> {
 
 }  // namespace functor
 
-#define DEFINE(T)                                                              \
-  template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, false, \
-                                                          false>;              \
-  template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, false, \
-                                                          true>;               \
-  template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, true,  \
-                                                          false>;              \
-  template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, true,  \
-                                                          true>;
+#define DEFINE(T, Tindices)                                \
+  template struct functor::SparseTensorDenseMatMulFunctor< \
+      GPUDevice, T, Tindices, false, false>;               \
+  template struct functor::SparseTensorDenseMatMulFunctor< \
+      GPUDevice, T, Tindices, false, true>;                \
+  template struct functor::SparseTensorDenseMatMulFunctor< \
+      GPUDevice, T, Tindices, true, false>;                \
+  template struct functor::SparseTensorDenseMatMulFunctor< \
+      GPUDevice, T, Tindices, true, true>;
 
-DEFINE(float);
+DEFINE(float, int32);
+DEFINE(float, int64);
 #undef DEFINE
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 11df3c43c7a..e540ecfa8d9 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -394,6 +394,28 @@ output: A `Tensor` with the concatenation of values stacked along the
   in `concat_dim` where it has the sum of the sizes.
 )doc");
 
+// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
+// are not to be made user-accessible.
+#ifdef INTEL_MKL
+REGISTER_OP("_MklConcatV2")
+    .Input("values: N * T")
+    .Input("axis: Tidx")
+    .Input("mkl_values: N * uint8")
+    .Input("mkl_axis: uint8")
+    .Output("output: T")
+    .Output("mkl_output: uint8")
+    .Attr("N: int >= 2")
+    .Attr("T: type")
+    .Attr("Tidx: {int32, int64} = DT_INT32")
+    .SetShapeFn(shape_inference::ConcatV2Shape)
+    .Doc(R"doc(
+MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+#endif
+
 REGISTER_OP("ConcatOffset")
     .Input("concat_dim: int32")
     .Input("shape: N * int32")
@@ -1638,6 +1660,21 @@ reshape(t, []) ==> 7
 shape: Defines the shape of the output tensor.
 )Doc");
 
+#ifdef INTEL_MKL
+REGISTER_OP("_MklReshape")
+    .Input("tensor: T")
+    .Input("shape: Tshape")
+    .Input("mkl_tensor: uint8")
+    .Input("mkl_shape: uint8")
+    .Output("output: T")
+    .Output("mkl_output: uint8")
+    .Attr("T: type")
+    .Attr("Tshape: {int32, int64} = DT_INT32")
+    .SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
+    .Doc(R"Doc( MKL implementation of ReshapeOp.
+)Doc");
+#endif  // INTEL_MKL
+
 // --------------------------------------------------------------------------
 REGISTER_OP("InvertPermutation")
     .Input("x: T")
@@ -4965,6 +5002,27 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
   `sum_per_d(gradients * (inputs > max))`.
 )doc");
 
+#ifdef INTEL_MKL
+REGISTER_OP("_MklConcat")
+    .Input("concat_dim: int32")
+    .Input("values: N * T")
+    .Input("mkl_concat_dim: uint8")
+    .Input("mkl_values: N * uint8")
+    .Output("output: T")
+    .Output("mkl_output: uint8")
+    .Attr("N: int >= 2")
+    .Attr("T: type")
+    .SetShapeFn([](InferenceContext* c) {
+      return shape_inference::ConcatShape(c, c->num_inputs() - 3);
+    })
+    .Doc(R"doc(
+MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+#endif
+
 // Deprecated op registrations:
 
 // The following can be deleted after 10mar2017.
diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc
index 3e2583f7060..0bce6fc0ea8 100644
--- a/tensorflow/core/ops/io_ops.cc
+++ b/tensorflow/core/ops/io_ops.cc
@@ -440,6 +440,7 @@ REGISTER_OP("FixedLengthRecordReader")
     .Attr("header_bytes: int = 0")
     .Attr("record_bytes: int")
     .Attr("footer_bytes: int = 0")
+    .Attr("hop_bytes: int = 0")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
     .SetIsStateful()
@@ -448,6 +449,11 @@ REGISTER_OP("FixedLengthRecordReader")
 A Reader that outputs fixed-length records from a file.
 
 reader_handle: The handle to reference the Reader.
+header_bytes: Number of bytes in the header, defaults to 0.
+record_bytes: Number of bytes in the record.
+footer_bytes: Number of bytes in the footer, defaults to 0.
+hop_bytes: Number of bytes to hop before each read. Default of 0 means using
+        record_bytes.
 container: If non-empty, this reader is placed in the given container.
         Otherwise, a default container is used.
 shared_name: If non-empty, this reader is named in the given bucket
@@ -459,6 +465,7 @@ REGISTER_OP("FixedLengthRecordReaderV2")
     .Attr("header_bytes: int = 0")
     .Attr("record_bytes: int")
     .Attr("footer_bytes: int = 0")
+    .Attr("hop_bytes: int = 0")
     .Attr("container: string = ''")
     .Attr("shared_name: string = ''")
     .SetIsStateful()
@@ -467,6 +474,11 @@ REGISTER_OP("FixedLengthRecordReaderV2")
 A Reader that outputs fixed-length records from a file.
 
 reader_handle: The handle to reference the Reader.
+header_bytes: Number of bytes in the header, defaults to 0.
+record_bytes: Number of bytes in the record.
+footer_bytes: Number of bytes in the footer, defaults to 0.
+hop_bytes: Number of bytes to hop before each read. Default of 0 means using
+        record_bytes.
 container: If non-empty, this reader is placed in the given container.
         Otherwise, a default container is used.
 shared_name: If non-empty, this reader is named in the given bucket
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e9d5897af04..932113bf2c4 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -2612,10 +2612,10 @@ scale_after_normalization: A bool indicating whether the resulted tensor
 )doc");
 
 #ifdef INTEL_MKL
-REGISTER_OP("MklConv2D")
+REGISTER_OP("_MklConv2D")
     .Input("input: T")
-    .Input("mkl_input: uint8")
     .Input("filter: T")
+    .Input("mkl_input: uint8")
     .Input("mkl_filter: uint8")
     .Output("output: T")
     .Output("mkl_output: uint8")
@@ -2632,12 +2632,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklConv2DWithBias")
+REGISTER_OP("_MklConv2DWithBias")
     .Input("input: T")
-    .Input("mkl_input: uint8")
     .Input("filter: T")
-    .Input("mkl_filter: uint8")
     .Input("bias: T")
+    .Input("mkl_input: uint8")
+    .Input("mkl_filter: uint8")
     .Input("mkl_bias: uint8")
     .Output("output: T")
     .Output("mkl_output: uint8")
@@ -2654,12 +2654,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklConv2DBackpropFilter")
+REGISTER_OP("_MklConv2DBackpropFilter")
     .Input("input: T")
-    .Input("mkl_input: uint8")
     .Input("filter_sizes: int32")
-    .Input("mkl_filter_size: uint8")
     .Input("out_backprop: T")
+    .Input("mkl_input: uint8")
+    .Input("mkl_filter_size: uint8")
     .Input("mkl_out_backprop: uint8")
     .Output("output: T")
     .Output("mkl_output: uint8")
@@ -2669,7 +2669,7 @@ REGISTER_OP("MklConv2DBackpropFilter")
     .Attr(GetPaddingAttrString())
     .Attr(GetConvnetDataFormatAttrString())
     .SetShapeFn([](InferenceContext* c) {
-      return InputTensorShapeOrUnknown(c, 2 /* input_idx */, 4 /* ndims */);
+      return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
     })
     .Doc(R"doc(
 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
@@ -2679,7 +2679,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklConv2DWithBiasBackpropBias")
+REGISTER_OP("_MklConv2DWithBiasBackpropBias")
     .Input("out_backprop: T")
     .Input("mkl_out_backprop: uint8")
     .Output("output: T")
@@ -2695,12 +2695,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklConv2DBackpropInput")
+REGISTER_OP("_MklConv2DBackpropInput")
     .Input("input_sizes: int32")
-    .Input("mkl_input_sizes: uint8")
     .Input("filter: T")
-    .Input("mkl_filter: uint8")
     .Input("out_backprop: T")
+    .Input("mkl_input_sizes: uint8")
+    .Input("mkl_filter: uint8")
     .Input("mkl_out_backprop: uint8")
     .Output("output: T")
     .Output("mkl_output: uint8")
@@ -2720,7 +2720,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklRelu")
+REGISTER_OP("_MklRelu")
     .Input("features: T")
     .Input("mkl_features: uint8")
     .Output("activations: T")
@@ -2734,10 +2734,10 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklReluGrad")
+REGISTER_OP("_MklReluGrad")
     .Input("gradients: T")
-    .Input("mkl_gradients: uint8")
     .Input("features: T")
+    .Input("mkl_gradients: uint8")
     .Input("mkl_features: uint8")
     .Output("backprops: T")
     .Output("mkl_backprops: uint8")
@@ -2751,7 +2751,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklMaxPool")
+REGISTER_OP("_MklMaxPool")
     .Attr("T: {float, half} = DT_FLOAT")
     .Attr("ksize: list(int) >= 4")
     .Attr("strides: list(int) >= 4")
@@ -2761,8 +2761,8 @@ REGISTER_OP("MklMaxPool")
     .Input("input: T")
     .Input("mkl_input: uint8")
     .Output("output: T")
-    .Output("mkl_output: uint8")
     .Output("workspace: T")
+    .Output("mkl_output: uint8")
     .Output("mkl_workspace: uint8")
     .SetShapeFn(shape_inference::MaxPoolShape)
     .Doc(R"doc(
@@ -2773,7 +2773,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklMaxPoolGrad")
+REGISTER_OP("_MklMaxPoolGrad")
     .Attr("T: {float, half} = DT_FLOAT")
     .Attr("ksize: list(int) >= 4")
     .Attr("strides: list(int) >= 4")
@@ -2781,12 +2781,12 @@ REGISTER_OP("MklMaxPoolGrad")
     .Attr(GetPaddingAttrString())
     .Attr(GetConvnetDataFormatAttrString())
     .Input("orig_input: T")
-    .Input("mkl_orig_input: uint8")
     .Input("orig_output: T")
-    .Input("mkl_orig_output: uint8")
     .Input("grad: T")
-    .Input("mkl_grad: uint8")
     .Input("workspace: T")
+    .Input("mkl_orig_input: uint8")
+    .Input("mkl_orig_output: uint8")
+    .Input("mkl_grad: uint8")
     .Input("mkl_workspace: uint8")
     .Output("output: T")
     .Output("mkl_output: uint8")
@@ -2801,7 +2801,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklAvgPool")
+REGISTER_OP("_MklAvgPool")
     .Input("value: T")
     .Input("mkl_input: uint8")
     .Output("output: T")
@@ -2820,10 +2820,10 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklAvgPoolGrad")
+REGISTER_OP("_MklAvgPoolGrad")
     .Input("orig_input_shape: int32")
-    .Input("mkl_orig_input: uint8")
     .Input("grad: T")
+    .Input("mkl_orig_input: uint8")
     .Input("mkl_grad: uint8")
     .Output("output: T")
     .Output("mkl_output: uint8")
@@ -2843,7 +2843,212 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
-REGISTER_OP("MklToTf")
+REGISTER_OP("_MklLRN")
+    .Input("input: T")
+    .Input("mkl_input: uint8")
+    .Output("output: T")
+    .Output("workspace: T")
+    .Output("mkl_output: uint8")
+    .Output("mkl_workspace: uint8")
+    .Attr("depth_radius: int = 5")
+    .Attr("bias: float = 1.0")
+    .Attr("alpha: float = 1.0")
+    .Attr("beta: float = 0.5")
+    .Attr("workspace_enabled: bool = false")
+    .Attr("T: {float, half} = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      return UnchangedShapeWithRank(c, 4);
+    })
+    .Doc(R"doc(
+MKL version of LRN operator. Uses MKL DNN APIs to perform local response
+normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklLRNGrad")
+    .Input("input_grads: T")
+    .Input("input_image: T")
+    .Input("output_image: T")
+    .Input("workspace: T")
+    .Input("mkl_input_grads: uint8")
+    .Input("mkl_input_image: uint8")
+    .Input("mkl_output_image: uint8")
+    .Input("mkl_workspace: uint8")
+    .Output("output: T")
+    .Output("mkl_output: uint8")
+    .Attr("depth_radius: int = 5")
+    .Attr("bias: float = 1.0")
+    .Attr("alpha: float = 1.0")
+    .Attr("beta: float = 0.5")
+    .Attr("workspace_enabled: bool = false")
+    .Attr("T: {float, half} = DT_FLOAT")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle s;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
+      TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
+      TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
+      c->set_output(0, s);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
+local response normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklFusedBatchNorm")
+    .Input("x: T")
+    .Input("scale: T")
+    .Input("offset: T")
+    .Input("mean: T")
+    .Input("variance: T")
+    .Input("mkl_x: uint8")
+    .Input("mkl_scale: uint8")
+    .Input("mkl_offset: uint8")
+    .Input("mkl_mean: uint8")
+    .Input("mkl_variance: uint8")
+    .Output("y: T")
+    .Output("batch_mean: T")
+    .Output("batch_variance: T")
+    .Output("reserve_space_1: T")
+    .Output("reserve_space_2: T")
+    .Output("mkl_y: uint8")
+    .Output("mkl_batch_mean: uint8")
+    .Output("mkl_batch_variance: uint8")
+    .Output("mkl_reserve_space_1: uint8")
+    .Output("mkl_reserve_space_2: uint8")
+    .Attr("T: numbertype")
+    .Attr("epsilon: float = 0.0001")
+    .Attr("data_format: string = 'NHWC'")
+    .Attr("is_training: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle x;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
+
+      bool is_training;
+      c->GetAttr("is_training", &is_training);
+      int number_inputs = (is_training) ? 3 : 5;
+      string data_format;
+      c->GetAttr("data_format", &data_format);
+      DimensionHandle channel_dim =
+          (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
+
+      // covers scale, offset, and if is_training is false, mean, variance
+      for (int i = 1; i < number_inputs; ++i) {
+        ShapeHandle vec;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+        TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
+      }
+
+      ShapeHandle y;
+      if (data_format == "NHWC") {
+        TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
+      } else {
+        TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
+      }
+      c->set_output(0, y);
+      ShapeHandle vector_shape = c->Vector(channel_dim);
+      c->set_output(1, vector_shape);
+      c->set_output(2, vector_shape);
+      c->set_output(3, vector_shape);
+      c->set_output(4, vector_shape);
+      return Status::OK();
+    })
+    .Doc(R"doc(
+MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused
+batch normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklFusedBatchNormGrad")
+    .Input("y_backprop: T")
+    .Input("x: T")
+    .Input("scale: T")
+    .Input("reserve_space_1: T")
+    .Input("reserve_space_2: T")
+    .Input("mkl_y_backprop: uint8")
+    .Input("mkl_x: uint8")
+    .Input("mkl_scale: uint8")
+    .Input("mkl_reserve_space_1: uint8")
+    .Input("mkl_reserve_space_2: uint8")
+    .Output("x_backprop: T")
+    .Output("scale_backprop: T")
+    .Output("offset_backprop: T")
+    .Output("reserve_space_3: T")
+    .Output("reserve_space_4: T")
+    .Output("mkl_x_backprop: uint8")
+    .Output("mkl_scale_backprop: uint8")
+    .Output("mkl_offset_backprop: uint8")
+    .Output("mkl_reserve_space_3: uint8")
+    .Output("mkl_reserve_space_4: uint8")
+    .Attr("T: numbertype")
+    .Attr("epsilon: float = 0.0001")
+    .Attr("data_format: string = 'NHWC'")
+    .Attr("is_training: bool = true")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle y_backprop;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
+      ShapeHandle x;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
+
+      bool is_training;
+      string data_format;
+      c->GetAttr("is_training", &is_training);
+      c->GetAttr("data_format", &data_format);
+      DimensionHandle channel_dim = (data_format == "NHWC")
+                                        ? c->Dim(y_backprop, 3)
+                                        : c->Dim(y_backprop, 1);
+      if (data_format == "NHWC") {
+        TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
+      } else {
+        TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
+      }
+
+      // covers scale, mean (reserve_space_1), variance (reserve_space_2)
+      for (int i = 2; i < 5; ++i) {
+        ShapeHandle vec;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
+        TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
+      }
+
+      ShapeHandle x_backprop;
+      if (data_format == "NHWC") {
+        TF_RETURN_IF_ERROR(
+            c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
+      } else {
+        TF_RETURN_IF_ERROR(
+            c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
+      }
+      c->set_output(0, x_backprop);
+      c->set_output(1, c->Vector(channel_dim));
+      c->set_output(2, c->Vector(channel_dim));
+      // Set the correct shapes for reserve_spaces
+      // so that gradients can be performed when
+      // the op is in a symbolic condition.
+      if (is_training) {
+        c->set_output(3, c->Vector(0));
+        c->set_output(4, c->Vector(0));
+      } else {
+        c->set_output(3, c->Vector(channel_dim));
+        c->set_output(4, c->Vector(channel_dim));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute
+gradients for fused batch normalization.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklToTf")
     .Input("input: T")
     .Input("mkl_input: uint8")
     .Output("output: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 6d28cb7e840..cbbabe0b876 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -26416,6 +26416,59 @@ op {
   summary: "Computes the sum along segments of a tensor."
   description: "Read @{$math_ops#segmentation$the section on segmentation} for an explanation of\nsegments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`.  Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/UnsortedSegmentSum.png\" alt>\n</div>"
 }
+op {
+  name: "UnsortedSegmentSum"
+  input_arg {
+    name: "data"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "segment_ids"
+    description: "A tensor whose shape is a prefix of `data.shape`."
+    type_attr: "Tindices"
+  }
+  input_arg {
+    name: "num_segments"
+    type: DT_INT32
+  }
+  output_arg {
+    name: "output"
+    description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+        type: DT_INT64
+        type: DT_INT32
+        type: DT_UINT8
+        type: DT_UINT16
+        type: DT_INT16
+        type: DT_INT8
+        type: DT_QINT8
+        type: DT_QUINT8
+        type: DT_QINT32
+        type: DT_HALF
+      }
+    }
+  }
+  attr {
+    name: "Tindices"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_INT32
+        type: DT_INT64
+      }
+    }
+  }
+  summary: "Computes the max along segments of a tensor."
+  description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n  range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
+}
 op {
   name: "Unstage"
   output_arg {
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 860b3475e93..b90f7a5dfb8 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -128,12 +128,13 @@ pair takes space.
 )doc");
 
 REGISTER_OP("SparseTensorDenseMatMul")
-    .Input("a_indices: int64")
+    .Input("a_indices: Tindices")
     .Input("a_values: T")
     .Input("a_shape: int64")
     .Input("b: T")
     .Output("product: T")
     .Attr("T: type")
+    .Attr("Tindices: {int32,int64} = DT_INT64")
     .Attr("adjoint_a: bool = false")
     .Attr("adjoint_b: bool = false")
     .SetShapeFn([](InferenceContext* c) {
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 58ccda5c9bb..10414cbca26 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -194,13 +194,15 @@ def tf_kernel_tests_linkstatic():
 
 def tf_additional_lib_defines():
   return select({
-      "//tensorflow:with_jemalloc": ["TENSORFLOW_USE_JEMALLOC"],
+      "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
+      "//tensorflow:with_jemalloc_linux_ppc64le":["TENSORFLOW_USE_JEMALLOC"],
       "//conditions:default": [],
   })
 
 def tf_additional_lib_deps():
   return select({
-      "//tensorflow:with_jemalloc": ["@jemalloc"],
+      "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc"],
+      "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc"],
       "//conditions:default": [],
   })
 
@@ -246,3 +248,9 @@ def tf_lib_proto_parsing_deps():
       ":protos_all_cc",
       "//tensorflow/core/platform/default/build_config:proto_parsing",
   ]
+
+def tf_additional_verbs_lib_defines():
+  return select({
+      "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
+      "//conditions:default": [],
+  })
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 79f97c12347..eb804bfc786 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -22,3 +22,11 @@ def tf_additional_license_deps():
       "//tensorflow:with_xla_support": ["@llvm//:LICENSE.TXT"],
       "//conditions:default": [],
   })
+
+def tf_additional_verbs_deps():
+  return select({
+      "//tensorflow:with_verbs_support": [
+      "//tensorflow/contrib/verbs:verbs_server_lib",
+      "//tensorflow/contrib/verbs:grpc_verbs_client"], 
+      "//conditions:default": [],
+  })
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index dfffbfa396c..df33cf38c97 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -20,11 +20,11 @@ limitations under the License.
 
 #define TF_MAJOR_VERSION 1
 #define TF_MINOR_VERSION 1
-#define TF_PATCH_VERSION 0-rc1
+#define TF_PATCH_VERSION 0
 
 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
 // "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc2"
 
 #define TF_STR_HELPER(x) #x
 #define TF_STR(x) TF_STR_HELPER(x)
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index ebbe195bbc9..897b174eff2 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -75,7 +75,6 @@ class MklShape {
   void SetTfLayout(const size_t dimension, const size_t* sizes,
                    const size_t* strides) {
     dimension_ = dimension;
-
     if (dimension > 0) {  // MKl doesn't support zero dimension tensors
       sizes_ = new size_t[dimension];
       strides_ = new size_t[dimension];
@@ -140,6 +139,39 @@ class MklShape {
   const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
   size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
 
+  // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+  // corresponds to MKL's Channel dimension.
+  bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
+  // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+  // corresponds to MKL's Batch dimension.
+  bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
+  // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+  // corresponds to MKL's Width dimension.
+  bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
+  // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+  // corresponds to MKL's Height dimension.
+  bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
+
+  // Check if the TF-Mkl dimension ordering map specifies if the input
+  // tensor is in NCHW format.
+  bool IsTensorInNCHWFormat() const {
+    TensorFormat data_format = FORMAT_NCHW;
+    return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
+            IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
+            IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
+            IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
+  }
+
+  // Check if the TF-Mkl dimension ordering map specifies if the input
+  // tensor is in NHWC format.
+  bool IsTensorInNHWCFormat() const {
+    TensorFormat data_format = FORMAT_NHWC;
+    return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
+            IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
+            IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
+            IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
+  }
+
   void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
                             void* output) const {
     dnnLayout_t curLayout;
@@ -194,9 +226,9 @@ class MklShape {
   (STRIDES_OFFSET(dims) + dims * sizeof(size_t))  // Location of mklLayout_
 #define TF_LAYOUT_OFFSET(dims) \
   (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)  // Location of tfLayout_
-// Location of tf_to_mkl_dim_map_
 #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
-  (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
+  (TF_LAYOUT_OFFSET(dims) +            \
+   SIZE_OF_MKL_DNN_BUF)  // Location of tf_to_mkl_dim_map_
 
   // TODO(agramesh1) make sure to create a const to share with rewrite pass
   // for min size of MKL metadata tensor.
@@ -265,45 +297,166 @@ class MklShape {
   size_t dimension_ = 0;
   size_t* sizes_ = nullptr;    // Required by MKL for conversions
   size_t* strides_ = nullptr;  // Required by MKL for conversions
-  // TF dimension corresponding to this MKL dimension
-  size_t* tf_to_mkl_dim_map_ = nullptr;
+  size_t* tf_to_mkl_dim_map_ =
+      nullptr;  // TF dimension corresponding to this MKL dimension
 };
 
-int inline GetTensorDataIndex(int n) {
-  return 2 * n;  // index corresponding to nth input/output tensor
+// List of MklShape objects. Used in Concat/Split layers.
+typedef std::vector<MklShape> MklShapeList;
+
+// Check if all tensors specified by MklShapes are MKL tensors.
+inline bool AreAllMklTensors(const MklShapeList& shapes) {
+  for (auto& s : shapes) {
+    if (!s.IsMklTensor()) {
+      return false;
+    }
+  }
+  return true;
 }
 
-int inline GetTensorMetaDataIndex(int n) {
-  // index corresponding to meta data of nth input/output tensor
-  return 2 * n + 1;
+template <typename T>
+inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
+                             const MklShape& mkl_shape) {
+  Tensor output_tensor;
+  TensorShape output_shape;
+
+  for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
+    // Outermost to innermost dimension
+    output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
+  }
+
+  // Allocate output tensor.
+  context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
+
+  dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
+  void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
+  void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
+
+  if (mkl_tensor.NumElements() != 0) {
+    mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
+  }
+
+  return output_tensor;
 }
+
+// Since our ops are going to produce and also consume N addition tensors
+// (Mkl) for N Tensorflow tensors, we can have following different
+// orderings among these 2N tensors.
+//
+// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
+// consume A_m, B_m, and C_m additionally.
+//
+// INTERLEAVED: in this case 2N tensors are interleaved. So for above
+//              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
+//
+// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
+//             by N Mkl tensors. So for above example, the ordering looks
+//             like: A, B, C, A_m, B_m, C_m
+//
+// Following APIs map index of original Tensorflow tensors to their appropriate
+// position based on selected ordering. For contiguous ordering, we need to know
+// the total number of tensors (parameter total).
+//
+typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
+// NOTE: Currently, we use contiguous ordering. If you change this, then you
+// would need to change Mkl op definitions in nn_ops.cc.
+static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+
+// Get index of MetaData tensor from index 'n' of Data tensor.
+inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
+  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+    // For interleaved ordering, Mkl tensor follows immediately after
+    // Tensorflow tensor.
+    return n + 1;
+  } else {
+    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+    // For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
+    return n + total_tensors / 2;
+  }
+}
+
+int inline GetTensorDataIndex(int n, int total_tensors) {
+  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
+    return 2 * n;  // index corresponding to nth input/output tensor
+  } else {
+    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
+    return n;
+  }
+}
+
+int inline GetTensorMetaDataIndex(int n, int total_tensors) {
+  // Get index for TensorData first and then use mapping function
+  // to get TensorMetaData index from TensorData index.
+  int tidx = GetTensorDataIndex(n, total_tensors);
+  return DataIndexToMetaDataIndex(tidx, total_tensors);
+}
+
 // Get the MKL shape from the second string tensor
 inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
   mklshape->DeSerializeMklShape(
-      ctext->input(GetTensorMetaDataIndex(n)).flat<uint8>().data(),
-      ctext->input(GetTensorMetaDataIndex(n)).flat<uint8>().size() *
+      ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+          .flat<uint8>()
+          .data(),
+      ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+              .flat<uint8>()
+              .size() *
           sizeof(uint8));
 }
 
 // Gets the actual input
 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
-  return ctext->input(GetTensorDataIndex(n));
+  return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
+}
+
+inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
+                            OpInputList* input_tensors) {
+  CHECK_NOTNULL(input_tensors);
+  ctext->input_list(name, input_tensors);
+}
+
+inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
+                            MklShapeList* mkl_shapes) {
+  OpInputList input_mkl_tensors;
+  GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
+
+  for (int i = 0; i < input_mkl_tensors.size(); i++) {
+    (*mkl_shapes)[i].DeSerializeMklShape(
+        input_mkl_tensors[i].flat<uint8>().data(),
+        input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
+  }
+}
+
+// Allocate the second output tensor that will contain
+// the MKL shape serialized
+inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
+                                      const MklShape& mkl_shape) {
+  Tensor* second_tensor = nullptr;
+  TensorShape second_shape;
+  second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
+  OP_REQUIRES_OK(ctext, ctext->allocate_output(
+                            GetTensorMetaDataIndex(n, ctext->num_outputs()),
+                            second_shape, &second_tensor));
+  mkl_shape.SerializeMklShape(
+      second_tensor->flat<uint8>().data(),
+      second_tensor->flat<uint8>().size() * sizeof(uint8));
 }
 
 // Allocate the output tensor, create a second output tensor that will contain
 // the MKL shape serialized
-inline void AllocateOutputSetMklshape(OpKernelContext* ctext, int n,
+inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
                                       Tensor** output,
-                                      const TensorShape& tfshape,
-                                      const MklShape& mklshape) {
+                                      const TensorShape& tf_shape,
+                                      const MklShape& mkl_shape) {
   Tensor* second_tensor = nullptr;
   TensorShape second_shape;
-  second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mklshape.GetDimension()));
+  second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
   OP_REQUIRES_OK(
-      ctext, ctext->allocate_output(GetTensorDataIndex(n), tfshape, output));
-  OP_REQUIRES_OK(ctext, ctext->allocate_output(GetTensorMetaDataIndex(n),
-                                               second_shape, &second_tensor));
-  mklshape.SerializeMklShape(
+      ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
+                                    tf_shape, output));
+  OP_REQUIRES_OK(ctext, ctext->allocate_output(
+                            GetTensorMetaDataIndex(n, ctext->num_outputs()),
+                            second_shape, &second_tensor));
+  mkl_shape.SerializeMklShape(
       second_tensor->flat<uint8>().data(),
       second_tensor->flat<uint8>().size() * sizeof(uint8));
 }
@@ -342,12 +495,11 @@ inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
 
 inline void MklSizesToTFSizes(OpKernelContext* context,
                               TensorFormat data_format_,
-                              const MklShape& mklshape, TensorShape* tfshape) {
-  size_t tf_dim = mklshape.GetDimension();
-  const size_t* tf_sizes = mklshape.GetSizes();
+                              const MklShape& mkl_shape,
+                              TensorShape* tf_shape) {
+  size_t tf_dim = mkl_shape.GetDimension();
+  const size_t* tf_sizes = mkl_shape.GetSizes();
 
-  // TODO(agramesh1): check if this constraint is applicable in other cases
-  // (besides BackpropInput, BackpropFilter).
   OP_REQUIRES(context, tf_dim == 4,
               errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
   std::vector<int32> sizes;
@@ -364,7 +516,7 @@ inline void MklSizesToTFSizes(OpKernelContext* context,
     sizes.push_back(tf_sizes[0]);
   }
 
-  OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tfshape));
+  OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
 }
 
 inline int32 GetMklTensorDimIndex(char dimension) {
@@ -383,38 +535,71 @@ inline int32 GetMklTensorDimIndex(char dimension) {
   }
 }
 
-inline int64 GetMklTensorDim(const MklShape& mklshape, char dimension) {
+inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
   int index = GetMklTensorDimIndex(dimension);
-  CHECK(index >= 0 && index < mklshape.GetDimension())
+  CHECK(index >= 0 && index < mkl_shape.GetDimension())
       << "Invalid index from the dimension: " << index << ", " << dimension;
-  return mklshape.dim_size(index);
+  return mkl_shape.dim_size(index);
 }
 
-namespace mkl_layer_registry {
+inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
+                                 int idx_out) {
+  int num_inputs = context->num_inputs();
+  int num_outputs = context->num_outputs();
+  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+  int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
+  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+  int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
 
-static const char* kMklLayerLabel = "MklLayer";
-static const char* kMklLayerLabelPattern = "label='MklLayer'";
+  const Tensor& data = context->input(idx_data_in);
+  const Tensor& meta = context->input(idx_meta_in);
+  Tensor output(data.dtype());
+  Tensor meta_output(meta.dtype());
+
+  // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
+  CHECK(output.CopyFrom(data, data.shape()));
+  CHECK(meta_output.CopyFrom(meta, meta.shape()));
+  context->set_output(idx_data_out, output);
+  context->set_output(idx_meta_out, meta_output);
+}
+
+inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
+                                int idx_out, const TensorShape& shape) {
+  int num_inputs = context->num_inputs();
+  int num_outputs = context->num_outputs();
+  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+
+  const Tensor& data = context->input(idx_data_in);
+  MklShape mkl_shape_output;
+  mkl_shape_output.SetMklTensor(false);
+  AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
+  Tensor output(data.dtype());
+  // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
+  CHECK(output.CopyFrom(data, shape));
+  context->set_output(idx_data_out, output);
+}
+
+namespace mkl_op_registry {
+static const char* kMklOpLabel = "MklOp";
+static const char* kMklOpLabelPattern = "label='MklOp'";
 
 // Check whether opname with type T is registered as MKL-compliant.
 //
 // @input: name of the op
 // @input: T datatype to be used for checking op
-// @return: true if opname is registered as Mkl layer op
-static inline bool IsMklLayer(const std::string& op_name, DataType T) {
+// @return: true if opname is registered as Mkl op
+static inline bool IsMklOp(const std::string& op_name, DataType T) {
   string kernel = KernelsRegisteredForOp(op_name);
-  // Currently, MKL only supports float type for ops. So we check if
-  // the type is float. Actually, we should query kernel registration and
-  // find out if op is supported for type T. But there is no API to query
-  // kernel registration using name and type.
   bool result =
-      (kernel.find(kMklLayerLabelPattern) != string::npos) && (T == DT_FLOAT);
-  if (result == true) {
-    VLOG(1) << "mkl_layer_registry::" << op_name << " is " << kMklLayerLabel;
+      kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
+  if (result) {
+    VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
   }
   return result;
 }
 
-}  // namespace mkl_layer_registry
+}  // namespace mkl_op_registry
 
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/docs_src/community/style_guide.md b/tensorflow/docs_src/community/style_guide.md
index a2df61bc809..767e33c3d07 100644
--- a/tensorflow/docs_src/community/style_guide.md
+++ b/tensorflow/docs_src/community/style_guide.md
@@ -115,31 +115,31 @@ Example:
 
     def my_op(tensor_in, other_tensor_in, my_param, other_param=0.5,
               output_collections=(), name=None):
-    """My operation that adds two tensors with given coefficients.
+      """My operation that adds two tensors with given coefficients.
 
-    Args:
-      tensor_in: `Tensor`, input tensor.
-      other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
-      my_param: `float`, coefficient for `tensor_in`.
-      other_param: `float`, coefficient for `other_tensor_in`.
-      output_collections: `tuple` of `string`s, name of the collection to
-                          collect result of this op.
-      name: `string`, name of the operation.
+      Args:
+        tensor_in: `Tensor`, input tensor.
+        other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
+        my_param: `float`, coefficient for `tensor_in`.
+        other_param: `float`, coefficient for `other_tensor_in`.
+        output_collections: `tuple` of `string`s, name of the collection to
+                            collect result of this op.
+        name: `string`, name of the operation.
 
-    Returns:
-      `Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
+      Returns:
+        `Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
 
-    Example:
-      >>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
-                output_collections=['MY_OPS'], name='add_t1t2')
-      [2.3, 3.4]
-    """
-    with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
-      tensor_in = tf.convert_to_tensor(tensor_in)
-      other_tensor_in = tf.convert_to_tensor(other_tensor_in)
-      result = my_param * tensor_in + other_param * other_tensor_in
-      tf.add_to_collections(output_collections, result)
-      return result
+      Example:
+        >>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
+                  output_collections=['MY_OPS'], name='add_t1t2')
+        [2.3, 3.4]
+      """
+      with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
+        tensor_in = tf.convert_to_tensor(tensor_in)
+        other_tensor_in = tf.convert_to_tensor(other_tensor_in)
+        result = my_param * tensor_in + other_param * other_tensor_in
+        tf.add_to_collection(output_collections, result)
+        return result
 
 Usage:
 
diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md
index c75c7f111da..f54f79cbf4a 100644
--- a/tensorflow/docs_src/extend/adding_an_op.md
+++ b/tensorflow/docs_src/extend/adding_an_op.md
@@ -121,16 +121,16 @@ class ZeroOutOp : public OpKernel {
     Tensor* output_tensor = NULL;
     OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                      &output_tensor));
-    auto output = output_tensor->flat<int32>();
+    auto output_flat = output_tensor->flat<int32>();
 
     // Set all but the first element of the output tensor to 0.
     const int N = input.size();
     for (int i = 1; i < N; i++) {
-      output(i) = 0;
+      output_flat(i) = 0;
     }
 
     // Preserve the first input value if possible.
-    if (N > 0) output(0) = input(0);
+    if (N > 0) output_flat(0) = input(0);
   }
 };
 ```
diff --git a/tensorflow/docs_src/get_started/get_started.md b/tensorflow/docs_src/get_started/get_started.md
index 33ff1d87d52..6bee7529d0a 100644
--- a/tensorflow/docs_src/get_started/get_started.md
+++ b/tensorflow/docs_src/get_started/get_started.md
@@ -323,7 +323,7 @@ for i in range(1000):
   sess.run(train, {x:x_train, y:y_train})
 
 # evaluate training accuracy
-curr_W, curr_b, curr_loss  = sess.run([W, b, loss], {x:x_train, y:y_train})
+curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x:x_train, y:y_train})
 print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
 ```
 When run, it produces
diff --git a/tensorflow/docs_src/get_started/monitors.md b/tensorflow/docs_src/get_started/monitors.md
index 99d583b23dc..7db88c89812 100644
--- a/tensorflow/docs_src/get_started/monitors.md
+++ b/tensorflow/docs_src/get_started/monitors.md
@@ -282,18 +282,15 @@ validation_metrics = {
     "accuracy":
         tf.contrib.learn.MetricSpec(
             metric_fn=tf.contrib.metrics.streaming_accuracy,
-            prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
-            CLASSES),
+            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
     "precision":
         tf.contrib.learn.MetricSpec(
             metric_fn=tf.contrib.metrics.streaming_precision,
-            prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
-            CLASSES),
+            prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
     "recall":
         tf.contrib.learn.MetricSpec(
             metric_fn=tf.contrib.metrics.streaming_recall,
-            prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
-            CLASSES)
+            prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
 }
 ```
 
diff --git a/tensorflow/docs_src/get_started/tflearn.md b/tensorflow/docs_src/get_started/tflearn.md
index 0912c7a5b4a..4a893e4a45b 100644
--- a/tensorflow/docs_src/get_started/tflearn.md
+++ b/tensorflow/docs_src/get_started/tflearn.md
@@ -282,7 +282,7 @@ enough that it can be stored in @{tf.constant TensorFlow constants}. The
 following code produces the simplest possible input pipeline:
 
 ```python
-# Define the test inputs
+# Define the training inputs
 def get_train_inputs():
   x = tf.constant(training_set.data)
   y = tf.constant(training_set.target)
diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md
index 0f3914d52d4..c1581efb4f3 100644
--- a/tensorflow/docs_src/install/install_c.md
+++ b/tensorflow/docs_src/install/install_c.md
@@ -35,7 +35,7 @@ enable TensorFlow for C:
          OS="linux" # Change to "darwin" for Mac OS
          TARGET_DIRECTORY="/usr/local"
          curl -L \
-           "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
+           "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" |
            sudo tar -C $TARGET_DIRECTORY -xz
 
      The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md
index 6874a1f03f5..dd713e4786e 100644
--- a/tensorflow/docs_src/install/install_go.md
+++ b/tensorflow/docs_src/install/install_go.md
@@ -35,7 +35,7 @@ steps to install this library and enable TensorFlow for Go:
          TF_TYPE="cpu" # Change to "gpu" for GPU support
          TARGET_DIRECTORY='/usr/local'
          curl -L \
-           "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc1.tar.gz" |
+           "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc2.tar.gz" |
          sudo tar -C $TARGET_DIRECTORY -xz
 
      The `tar` command extracts the TensorFlow C library into the `lib`
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 127d8fd0292..1abf3b69f5e 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -34,7 +34,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
 <dependency>
   <groupId>org.tensorflow</groupId>
   <artifactId>tensorflow</artifactId>
-  <version>1.1.0-rc1</version>
+  <version>1.1.0-rc2</version>
 </dependency>
 ```
 
@@ -63,7 +63,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
                <dependency>
                  <groupId>org.tensorflow</groupId>
                  <artifactId>tensorflow</artifactId>
-                 <version>1.1.0-rc1</version>
+                 <version>1.1.0-rc2</version>
                </dependency>
              </dependencies>
          </project>
@@ -122,7 +122,7 @@ refer to the simpler instructions above instead.
 Take the following steps to install TensorFlow for Java on Linux or Mac OS:
 
   1. Download
-     [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar),
+     [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar),
      which is the TensorFlow Java Archive (JAR).
 
   2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -141,7 +141,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
          OS=$(uname -s | tr '[:upper:]' '[:lower:]')
          mkdir -p ./jni
          curl -L \
-           "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
+           "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" |
            tar -xz -C ./jni
 
 ### Install on Windows
@@ -149,10 +149,10 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
 Take the following steps to install TensorFlow for Java on Windows:
 
   1. Download
-     [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar),
+     [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar),
      which is the TensorFlow Java Archive (JAR).
   2. Download the following Java Native Interface (JNI) file appropriate for
-     [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc1.zip).
+     [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc2.zip).
   3. Extract this .zip file.
 
 
@@ -200,7 +200,7 @@ must be part of your `classpath`. For example, you can include the
 downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
 as follows:
 
-<pre><b>javac -cp libtensorflow-1.1.0-rc1.jar HelloTF.java</b></pre>
+<pre><b>javac -cp libtensorflow-1.1.0-rc2.jar HelloTF.java</b></pre>
 
 
 ### Running
@@ -213,7 +213,7 @@ two files are available to the JVM:
 
 For example, the following command line executes the `HelloTF` program:
 
-<pre><b>java -cp libtensorflow-1.1.0-rc1.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+<pre><b>java -cp libtensorflow-1.1.0-rc2.jar:. -Djava.library.path=./jni HelloTF</b></pre>
 
 If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
 installed TensorFlow for Java and are ready to use the API.  If the program
diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md
index 4a5d63f337a..8ee31fe6922 100644
--- a/tensorflow/docs_src/install/install_linux.md
+++ b/tensorflow/docs_src/install/install_linux.md
@@ -165,8 +165,8 @@ Take the following steps to install TensorFlow with Virtualenv:
      issue the following command to install TensorFlow in the active
      virtualenv environment:
 
-     <pre> (tensorflow)$ <b>pip install --upgrade \\
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
+     <pre>(tensorflow)$ <b>pip3 install --upgrade \
+     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
 
 If you encounter installation problems, see
 [Common Installation Problems](#common_installation_problems).
@@ -269,8 +269,10 @@ take the following steps:
      install TensorFlow for Linux, Python 2.7, and CPU-only support, issue
      the following command:
 
-     <pre> $ <b>sudo pip install --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
+     <pre>
+     $ <b>sudo pip3 install --upgrade \
+     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b>
+     </pre>
 
      If this step fails, see
      [Common Installation Problems](#common_installation_problems).
@@ -456,7 +458,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
 
      <pre>
      (tensorflow)$ <b>pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
+     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
 
 
 <a name="ValidateYourInstallation"></a>
@@ -624,14 +626,14 @@ This section documents the relevant values for Linux installations.
 CPU only:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp27-none-linux_x86_64.whl
 </pre>
 
 
 GPU support:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp27-none-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp27-none-linux_x86_64.whl
 </pre>
 
 Note that GPU support requires the NVIDIA hardware and software described in
@@ -643,14 +645,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
 CPU only:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl
 </pre>
 
 
 GPU support:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp34-cp34m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp34-cp34m-linux_x86_64.whl
 </pre>
 
 Note that GPU support requires the NVIDIA hardware and software described in
@@ -662,14 +664,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
 CPU only:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp35-cp35m-linux_x86_64.whl
 </pre>
 
 
 GPU support:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp35-cp35m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp35-cp35m-linux_x86_64.whl
 </pre>
 
 
@@ -681,14 +683,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
 CPU only:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp36-cp36m-linux_x86_64.whl
 </pre>
 
 
 GPU support:
 
 <pre>
-https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp36-cp36m-linux_x86_64.whl
+https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp36-cp36m-linux_x86_64.whl
 </pre>
 
 
diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md
index ccfe9ada6d0..0882422e4dd 100644
--- a/tensorflow/docs_src/install/install_mac.md
+++ b/tensorflow/docs_src/install/install_mac.md
@@ -163,7 +163,7 @@ Take the following steps to install TensorFlow with Virtualenv:
      TensorFlow in the active Virtualenv is as follows:
 
      <pre> $ <b>pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b></pre>
+     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b></pre>
 
 If you encounter installation problems, see
 [Common Installation Problems](#CommonInstallationProblems).
@@ -286,7 +286,7 @@ take the following steps:
      support, issue the following command:
 
      <pre> $ <b>sudo pip3 install --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b> </pre>
+     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b> </pre>
 
      If the preceding command fails, see
      [Common installation problems](#CommonInstallationProblems).
@@ -398,7 +398,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
      TensorFlow for Python 2.7:
 
      <pre> (tensorflow)$ <b>pip install --ignore-installed --upgrade \
-     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b></pre>
+     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b></pre>
 
 
 <a name="ValidateYourInstallation"></a>
@@ -604,13 +604,13 @@ This section documents the relevant values for Mac OS installations.
 CPU only:
 
 <pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl
 </pre>
 
 GPU support:
 
 <pre>
-https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc1-py2-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc2-py2-none-any.whl
 </pre>
 
 Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
@@ -622,13 +622,13 @@ Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
 CPU only:
 
 <pre>
-https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py3-none-any.whl
 </pre>
 
 GPU support:
 
 <pre>
-https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc1-py3-none-any.whl
+https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc2-py3-none-any.whl
 </pre>
 
 Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index 5f351c40b44..88268ba62f8 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -319,10 +319,11 @@ $ <b>bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pk
 Invoke `pip install` to install that pip package.
 The filename of the `.whl` file depends on your platform.
 For example, the following command will install the pip package
-for TensorFlow 1.1.0rc1 on Linux:
+
+for TensorFlow 1.1.0rc2 on Linux:
 
 <pre>
-$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.1.0rc1-py2-none-any.whl</b>
+$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.1.0rc2-py2-none-any.whl</b>
 </pre>
 
 ## Validate your installation
diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md
index 7d3d13c34a0..5f7c27c0282 100644
--- a/tensorflow/docs_src/install/install_windows.md
+++ b/tensorflow/docs_src/install/install_windows.md
@@ -114,12 +114,12 @@ Take the following steps to install TensorFlow in an Anaconda environment:
      environment. To install the CPU-only version of TensorFlow, enter the
      following command:
 
-     <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.1.0rc1-cp35-cp35m-win_amd64.whl</b> </pre>
+     <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.1.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
 
      To install the GPU version of TensorFlow, enter the following command
      (on a single line):
 
-     <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.1.0rc1-cp35-cp35m-win_amd64.whl</b> </pre>
+     <pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.1.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
 
 ## Validate your installation
 
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
index 8cd296d7520..e29387ab9d0 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -88,7 +88,8 @@ def run_training():
     saver = tf.train.Saver()
 
     # Create the op for initializing variables.
-    init_op = tf.global_variables_initializer()
+    init_op = tf.group(tf.global_variables_initializer(),
+                       tf.local_variables_initializer())
 
     # Create a session for running Ops on the Graph.
     sess = tf.Session()
diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
index 698c97ca1d7..dc0d8703158 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
@@ -135,7 +135,8 @@ def train():
       accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
   tf.summary.scalar('accuracy', accuracy)
 
-  # Merge all the summaries and write them out to /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
+  # Merge all the summaries and write them out to
+  # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
   merged = tf.summary.merge_all()
   train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
   test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
@@ -196,9 +197,15 @@ if __name__ == '__main__':
                       help='Initial learning rate')
   parser.add_argument('--dropout', type=float, default=0.9,
                       help='Keep probability for training dropout.')
-  parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
-                      help='Directory for storing input data')
-  parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
-                      help='Summaries log directory')
+  parser.add_argument(
+      '--data_dir',
+      type=str,
+      default='/tmp/tensorflow/mnist/input_data',
+      help='Directory for storing input data')
+  parser.add_argument(
+      '--log_dir',
+      type=str,
+      default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
+      help='Summaries log directory')
   FLAGS, unparsed = parser.parse_known_args()
   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index cad8ccaaad6..c367d20f816 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -25,6 +25,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
 load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
 load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_verbs_deps")
 
 py_library(
     name = "python",
@@ -2610,7 +2611,9 @@ tf_py_wrap_cc(
         "//tensorflow/tools/graph_transforms:transform_graph_lib",
         "//tensorflow/tools/tfprof/internal:print_model_analysis",
         "//util/python:python_headers",
-    ] + tf_additional_lib_deps() + tf_additional_plugin_deps(),
+    ] + (tf_additional_lib_deps() +
+         tf_additional_plugin_deps() +
+         tf_additional_verbs_deps()),
 )
 
 py_library(
diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py
index fac2cf4def9..f04f67ffedd 100644
--- a/tensorflow/python/framework/dtypes_test.py
+++ b/tensorflow/python/framework/dtypes_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tensorflow.python.framework.importer."""
+"""Tests for tensorflow.python.framework.dtypes."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py
index 5e8f8e8673a..10f34751d0b 100644
--- a/tensorflow/python/kernel_tests/reader_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_ops_test.py
@@ -352,9 +352,20 @@ class FixedLengthRecordReaderTest(test.TestCase):
     self._record_bytes = 3
     self._footer_bytes = 2
 
+    self._hop_bytes = 2
+    self._num_overlapped_records = 3
+
   def _Record(self, f, r):
     return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
 
+  def _OverlappedRecord(self, f, r):
+    record_str = "".join([
+        str(i)[0]
+        for i in range(r * self._hop_bytes,
+                       r * self._hop_bytes + self._record_bytes)
+    ])
+    return compat.as_bytes(record_str)
+
   def _CreateFiles(self):
     filenames = []
     for i in range(self._num_files):
@@ -367,6 +378,23 @@ class FixedLengthRecordReaderTest(test.TestCase):
         f.write(b"F" * self._footer_bytes)
     return filenames
 
+  def _CreateOverlappedRecordFiles(self):
+    filenames = []
+    for i in range(self._num_files):
+      fn = os.path.join(self.get_temp_dir(),
+                        "fixed_length_overlapped_record.%d.txt" % i)
+      filenames.append(fn)
+      with open(fn, "wb") as f:
+        f.write(b"H" * self._header_bytes)
+        all_records_str = "".join([
+            str(i)[0]
+            for i in range(self._record_bytes + self._hop_bytes *
+                           (self._num_overlapped_records - 1))
+        ])
+        f.write(compat.as_bytes(all_records_str))
+        f.write(b"F" * self._footer_bytes)
+    return filenames
+
   def testOneEpoch(self):
     files = self._CreateFiles()
     with self.test_session() as sess:
@@ -374,6 +402,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
           header_bytes=self._header_bytes,
           record_bytes=self._record_bytes,
           footer_bytes=self._footer_bytes,
+          hop_bytes=0,
           name="test_reader")
       queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
       key, value = reader.read(queue)
@@ -390,6 +419,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
                                     "\\(requested 1, current size 0\\)"):
         k, v = sess.run([key, value])
 
+  def testOneEpochWithHopBytes(self):
+    files = self._CreateOverlappedRecordFiles()
+    with self.test_session() as sess:
+      reader = io_ops.FixedLengthRecordReader(
+          header_bytes=self._header_bytes,
+          record_bytes=self._record_bytes,
+          footer_bytes=self._footer_bytes,
+          hop_bytes=self._hop_bytes,
+          name="test_reader")
+      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
+      key, value = reader.read(queue)
+
+      queue.enqueue_many([files]).run()
+      queue.close().run()
+      for i in range(self._num_files):
+        for j in range(self._num_overlapped_records):
+          k, v = sess.run([key, value])
+          print(v)
+          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
+          self.assertAllEqual(self._OverlappedRecord(i, j), v)
+
+      with self.assertRaisesOpError("is closed and has insufficient elements "
+                                    "\\(requested 1, current size 0\\)"):
+        k, v = sess.run([key, value])
+
 
 class TFRecordReaderTest(test.TestCase):
 
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
index df5462dd2d0..e8b94294b1b 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
@@ -30,34 +30,44 @@ from tensorflow.python.platform import test
 
 class SparseTensorDenseMatMulGradientTest(test.TestCase):
 
-  def _sparsify(self, x):
+  def _sparsify(self, x, indices_dtype=np.int64):
     x[x < 0.5] = 0
 
     non_zero = np.where(x)
-    x_indices = np.vstack(non_zero).astype(np.int64).T
+    x_indices = np.vstack(non_zero).astype(indices_dtype).T
     x_values = x[non_zero]
     x_shape = x.shape
 
     return sparse_tensor.SparseTensor(
         indices=x_indices, values=x_values, dense_shape=x_shape), len(x_values)
 
-  def _randomTensor(self, size, np_dtype, adjoint=False, sparse=False):
+  def _randomTensor(self,
+                    size,
+                    values_dtype,
+                    adjoint=False,
+                    sparse=False,
+                    indices_dtype=np.int64):
     n, m = size
-    x = np.random.randn(n, m).astype(np_dtype)
+    x = np.random.randn(n, m).astype(values_dtype)
 
     if adjoint:
       x = x.transpose()
 
     if sparse:
-      return self._sparsify(x)
+      return self._sparsify(x, indices_dtype=indices_dtype)
     else:
-      return constant_op.constant(x, dtype=np_dtype)
+      return constant_op.constant(x, dtype=values_dtype)
 
-  def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype):
+  def _testGradients(self, adjoint_a, adjoint_b, name, values_dtype,
+                     indices_dtype):
     n, k, m = np.random.randint(1, 10, size=3)
     sp_t, nnz = self._randomTensor(
-        [n, k], np_dtype, adjoint=adjoint_a, sparse=True)
-    dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)
+        [n, k],
+        values_dtype,
+        adjoint=adjoint_a,
+        sparse=True,
+        indices_dtype=indices_dtype)
+    dense_t = self._randomTensor([k, m], values_dtype, adjoint=adjoint_b)
 
     matmul = sparse_ops.sparse_tensor_dense_matmul(
         sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)
@@ -71,17 +81,19 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase):
       print("%s gradient err = %s" % (name, err))
       self.assertLess(err, 1e-3)
 
-  def _testGradientsType(self, np_dtype):
+  def _testGradientsType(self, values_dtype, indices_dtype):
     for adjoint_a in [True, False]:
       for adjoint_b in [True, False]:
-        name = "sparse_tensor_dense_matmul_%s_%s_%s" % (adjoint_a, adjoint_b,
-                                                        np_dtype.__name__)
-        self._testGradients(adjoint_a, adjoint_b, name, np_dtype)
+        name = "sparse_tensor_dense_matmul_%s_%s_%s_%s" % (
+            adjoint_a, adjoint_b, values_dtype.__name__, indices_dtype.__name__)
+        self._testGradients(adjoint_a, adjoint_b, name, values_dtype,
+                            indices_dtype)
 
   def testGradients(self):
     np.random.seed(5)  # Fix seed to avoid flakiness
-    self._testGradientsType(np.float32)
-    self._testGradientsType(np.float64)
+    self._testGradientsType(np.float32, np.int64)
+    self._testGradientsType(np.float64, np.int64)
+    self._testGradientsType(np.float32, np.int32)
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
index da72803ee7b..80991751860 100644
--- a/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
@@ -45,7 +45,12 @@ def _maybe_complex(x):
 
 class SparseTensorDenseMatMulTest(test.TestCase):
 
-  def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False):
+  def _testMatmul(self,
+                  x,
+                  y,
+                  adjoint_a=False,
+                  adjoint_b=False,
+                  indices_dtype=np.int64):
     x_mat = np.matrix(x)
     if adjoint_a:
       x_mat = x_mat.H
@@ -55,7 +60,7 @@ class SparseTensorDenseMatMulTest(test.TestCase):
 
     np_ans = x_mat * y_mat
 
-    x_indices = np.vstack(np.where(x)).astype(np.int64).T
+    x_indices = np.vstack(np.where(x)).astype(indices_dtype).T
     x_values = x[np.where(x)]
     x_shape = x.shape
 
@@ -82,13 +87,13 @@ class SparseTensorDenseMatMulTest(test.TestCase):
         else:
           self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
 
-  def _testBasic(self, np_dtype):
-    x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype))
+  def _testBasic(self, value_dtype, indices_dtype=np.int64):
+    x = _maybe_complex(np.random.rand(10, 10).astype(value_dtype))
     x[np.abs(x) < 0.5] = 0  # Make it sparse
 
-    y = _maybe_complex(np.random.randn(10, 20).astype(np_dtype))
+    y = _maybe_complex(np.random.randn(10, 20).astype(value_dtype))
 
-    self._testMatmul(x, y)
+    self._testMatmul(x, y, indices_dtype=indices_dtype)
 
   def testBasic(self):
     np.random.seed(127)  # Repeatable results
@@ -97,6 +102,8 @@ class SparseTensorDenseMatMulTest(test.TestCase):
     self._testBasic(np.float64)
     self._testBasic(np.complex64)
     self._testBasic(np.complex128)
+    self._testBasic(np.int32, indices_dtype=np.int32)
+    self._testBasic(np.float32, indices_dtype=np.int32)
 
   def testShapeInference(self):
     x = np.random.rand(10, 10)
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index c3e133d08b2..da962b2f99e 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tf.layers.core."""
+"""Tests for tf.layers.convolutional."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py
index 0f82f73ea48..933f196e011 100644
--- a/tensorflow/python/layers/normalization_test.py
+++ b/tensorflow/python/layers/normalization_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tf.layers.core."""
+"""Tests for tf.layers.normalization."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/layers/utils_test.py b/tensorflow/python/layers/utils_test.py
index ace8046a0bb..54e757c112b 100644
--- a/tensorflow/python/layers/utils_test.py
+++ b/tensorflow/python/layers/utils_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for tf.layers.core."""
+"""Tests for tf.layers.utils."""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/python/ops/batch_norm_benchmark.py b/tensorflow/python/ops/batch_norm_benchmark.py
index 397ed91078b..c2ee2b38323 100644
--- a/tensorflow/python/ops/batch_norm_benchmark.py
+++ b/tensorflow/python/ops/batch_norm_benchmark.py
@@ -198,7 +198,7 @@ class BatchNormBenchmark(test.Benchmark):
     if FLAGS.use_gpu:
       t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
       t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
-      t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
+      t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
       print_difference("op vs py", t1, t2)
       print_difference("py vs slow", t2, t3)
     print("Forward convolution (higher layers).")
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index ae45c40aec4..68ecc219e4f 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -391,7 +391,11 @@ class FixedLengthRecordReader(ReaderBase):
   """
   # TODO(josh11b): Support serializing and restoring state.
 
-  def __init__(self, record_bytes, header_bytes=None, footer_bytes=None,
+  def __init__(self,
+               record_bytes,
+               header_bytes=None,
+               footer_bytes=None,
+               hop_bytes=None,
                name=None):
     """Create a FixedLengthRecordReader.
 
@@ -399,11 +403,15 @@ class FixedLengthRecordReader(ReaderBase):
       record_bytes: An int.
       header_bytes: An optional int. Defaults to 0.
       footer_bytes: An optional int. Defaults to 0.
+      hop_bytes: An optional int. Defaults to 0.
       name: A name for the operation (optional).
     """
     rr = gen_io_ops._fixed_length_record_reader_v2(
-        record_bytes=record_bytes, header_bytes=header_bytes,
-        footer_bytes=footer_bytes, name=name)
+        record_bytes=record_bytes,
+        header_bytes=header_bytes,
+        footer_bytes=footer_bytes,
+        hop_bytes=hop_bytes,
+        name=name)
     super(FixedLengthRecordReader, self).__init__(rr)
 
 
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 118ade45ec3..7c17cf2cb61 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -639,18 +639,22 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
           math_ops.reduce_mean(y, axes, keep_dims=True))
     else:
       shift = math_ops.cast(shift, y.dtype)
-    counts, m_ss, v_ss, shift = sufficient_statistics(
-        y, axes, shift=shift, keep_dims=keep_dims, name=name)
-    # Reshape shift as needed.
-    shift = array_ops.reshape(shift, array_ops.shape(m_ss))
-    shift.set_shape(m_ss.get_shape())
-    with ops.control_dependencies([counts, m_ss, v_ss]):
-      mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
-      if x.dtype == dtypes.float16:
-        return (math_ops.cast(mean, dtypes.float16),
-                math_ops.cast(variance, dtypes.float16))
-      else:
-        return (mean, variance)
+    shifted_mean = math_ops.reduce_mean(
+        math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean")
+    variance = math_ops.subtract(
+        math_ops.reduce_mean(
+            math_ops.squared_difference(y, shift), axes, keep_dims=True),
+        math_ops.square(shifted_mean),
+        name="variance")
+    mean = math_ops.add(shifted_mean, shift, name="mean")
+    if not keep_dims:
+      mean = array_ops.squeeze(mean, axes)
+      variance = array_ops.squeeze(variance, axes)
+    if x.dtype == dtypes.float16:
+      return (math_ops.cast(mean, dtypes.float16), math_ops.cast(
+          variance, dtypes.float16))
+    else:
+      return (mean, variance)
 
 
 def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py
index fa015856cec..b8e356c78cc 100644
--- a/tensorflow/python/ops/sparse_grad.py
+++ b/tensorflow/python/ops/sparse_grad.py
@@ -136,12 +136,13 @@ def _SparseTensorDenseMatMulGrad(op, grad):
   Raises:
     TypeError: When the two operands don't have the same type.
   """
-  sp_t = sparse_tensor.SparseTensor(*op.inputs[:3])
+  a_indices, a_values, a_shape = op.inputs[:3]
+  b = op.inputs[3]
   adj_a = op.get_attr("adjoint_a")
   adj_b = op.get_attr("adjoint_b")
 
-  a_type = sp_t.values.dtype.base_dtype
-  b_type = op.inputs[3].dtype.base_dtype
+  a_type = a_values.dtype.base_dtype
+  b_type = b.dtype.base_dtype
   if a_type != b_type:
     raise TypeError("SparseTensorDenseMatMul op received operands with "
                     "different types: ", a_type, " and ", b_type)
@@ -150,15 +151,12 @@ def _SparseTensorDenseMatMulGrad(op, grad):
                               "complex gradients.")
 
   # gradient w.r.t. dense
-  b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
-                                                 adjoint_a=not adj_a)
+  b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul(  # pylint: disable=protected-access
+      a_indices, a_values, a_shape, grad, adjoint_a=not adj_a)
   if adj_b:
     b_grad = array_ops.transpose(b_grad)
 
   # gradient w.r.t. sparse values
-  a_indices = op.inputs[0]
-  b = op.inputs[3]
-
   rows = a_indices[:, 0]
   cols = a_indices[:, 1]
 
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 9f4e6607d10..af7abf52511 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -1239,7 +1239,7 @@ def sparse_tensor_dense_matmul(sp_a,
     A should be sorted in order of increasing dimension 1 (i.e., "column major"
     order instead of "row major" order).
 
-  Deciding when to use sparse_tensor_dense_matmul vs. matmul(sp_a=True):
+  Deciding when to use sparse_tensor_dense_matmul vs. matmul(a_is_sparse=True):
 
   There are a number of questions to ask in the decision process, including:
 
@@ -1249,14 +1249,14 @@ def sparse_tensor_dense_matmul(sp_a,
 
   If the answer to several of these questions is yes, consider
   converting the `SparseTensor` to a dense one and using `tf.matmul` with
-  `sp_a=True`.
+  `a_is_sparse=True`.
 
   This operation tends to perform well when A is more sparse, if the column size
   of the product is small (e.g. matrix-vector multiplication), if
   `sp_a.dense_shape` takes on large values.
 
   Below is a rough speed comparison between sparse_tensor_dense_matmul,
-  labelled 'sparse', and matmul(sp_a=True), labelled 'dense'.  For purposes of
+  labelled 'sparse', and matmul(a_is_sparse=True), labelled 'dense'.  For purposes of
   the comparison, the time spent converting from a SparseTensor to a dense
   Tensor is not included, so it is overly conservative with respect to
   the time ratio.
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 751ccd3d0ef..9d3ac4ed9ed 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -319,7 +319,7 @@ class StreamExecutorInterface {
   // Creates a new DnnSupport object, ownership is transferred to the caller.
   // If SupportsDnn() is false, this will always return null.
   //
-  // If SupportsDnn() is true, this may return null, for example, if the RNG
+  // If SupportsDnn() is true, this may return null, for example, if the DNN
   // initialization fails.
   virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
 
diff --git a/tensorflow/tensorboard/DEVELOPMENT.md b/tensorflow/tensorboard/DEVELOPMENT.md
index 0a35dec42fb..3ff2c87dab7 100644
--- a/tensorflow/tensorboard/DEVELOPMENT.md
+++ b/tensorflow/tensorboard/DEVELOPMENT.md
@@ -21,7 +21,7 @@ Then, cd into the TensorBoard directory:
 
 and install dependencies:
 
-`npm run prepare`
+`npm run prep`
 
 Then, run gulp: `gulp`
 
diff --git a/tensorflow/tensorboard/dist/tf-tensorboard.html b/tensorflow/tensorboard/dist/tf-tensorboard.html
index 2a0a029ffa9..8610940ac3c 100644
--- a/tensorflow/tensorboard/dist/tf-tensorboard.html
+++ b/tensorflow/tensorboard/dist/tf-tensorboard.html
@@ -3325,6 +3325,7 @@ var Categorizer;
       // if undefined, default value (enable for first k runs, disable after).
         type: Object,
         value: TF.URIStorage.getObjectInitializer('runSelectionState', {}),
+        observer: "_storeRunToIsCheckedMapping",
       },
       // (Allows state to persist across regex filtering)
       outSelected: {
@@ -3373,24 +3374,7 @@ var Categorizer;
     },
     observers: [
       "_setIsolatorIcon(runSelectionState, names)",
-      "_storeRunToIsCheckedMappingWithDefault(runSelectionState, namesMatchingRegex)",
     ],
-    _storeRunToIsCheckedMappingWithDefault() {
-      var runSelectionStateIsDefault = Object.keys(this.runSelectionState).length == 0;
-      if (runSelectionStateIsDefault || this.namesMatchingRegex == null) {
-        return;
-      }
-      var _this = this;
-      var allToggledOn = this.namesMatchingRegex
-              .every(function(n) {return _this.runSelectionState[n]});
-      var allToggledOff = this.namesMatchingRegex
-              .every(function(n) {return !_this.runSelectionState[n]});
-      var defaultOff = this.namesMatchingRegex.length > this.maxRunsToEnableByDefault;
-      if (defaultOff && allToggledOff || !defaultOff && allToggledOn) {
-        this.runSelectionState = {};
-      }
-      this._storeRunToIsCheckedMapping(this.runSelectionState);
-    },
     _storeRunToIsCheckedMapping: TF.URIStorage.getObjectObserver('runSelectionState', {}),
     _makeRegex: function(regex) {
       try {
@@ -27156,4 +27140,4 @@ arguments[4][8][0].apply(exports,arguments)
 },{"dup":8}]},{},[35,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34]);
 </script>
 </dom-module>
-</body></html>
\ No newline at end of file
+</body></html>
diff --git a/tensorflow/tensorboard/gulp_tasks/bower.js b/tensorflow/tensorboard/gulp_tasks/bower.js
index 7c0e515c6c9..8f4666a8c15 100644
--- a/tensorflow/tensorboard/gulp_tasks/bower.js
+++ b/tensorflow/tensorboard/gulp_tasks/bower.js
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-var gulp = require('gulp');
-var bower = require('gulp-bower');
+const gulp = require('gulp');
+const bower = require('gulp-bower');
 
 module.exports = function() {
   return function() {
diff --git a/tensorflow/tensorboard/gulp_tasks/compile.js b/tensorflow/tensorboard/gulp_tasks/compile.js
index 3d0d725cfb2..01af60eba77 100644
--- a/tensorflow/tensorboard/gulp_tasks/compile.js
+++ b/tensorflow/tensorboard/gulp_tasks/compile.js
@@ -13,25 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-var gulp = require('gulp');
-var ts = require('gulp-typescript');
-var typescript = require('typescript');
-var gutil = require('gulp-util');
-var filter = require('gulp-filter');
-var merge = require('merge2');
-var browserify = require('browserify');
-var tsify = require('tsify');
-var source = require('vinyl-source-stream');
-var glob = require('glob').sync;
-var concat = require('gulp-concat');
+const gulp = require('gulp');
+const ts = require('gulp-typescript');
+const typescript = require('typescript');
+const gutil = require('gulp-util');
+const filter = require('gulp-filter');
+const merge = require('merge2');
+const browserify = require('browserify');
+const tsify = require('tsify');
+const source = require('vinyl-source-stream');
+const glob = require('glob').sync;
+const concat = require('gulp-concat');
 
-var tsProject = ts.createProject('./tsconfig.json', {
+const tsProject = ts.createProject('./tsconfig.json', {
   typescript: typescript,
-  noExternalResolve: true, // opt-in for faster compilation!
+  noExternalResolve: true,  // opt-in for faster compilation!
 });
 
 /** List of components (and their external deps) that are using es6 modules. */
-var ES6_COMPONENTS = [{
+const ES6_COMPONENTS = [{
   name: 'vz_projector',
   deps: [
     'd3/d3.min.js', 'weblas/dist/weblas.js', 'three.js/build/three.min.js',
@@ -44,8 +44,8 @@ module.exports = function(includeDeps) {
   return function() {
     // Compile all components that are using ES6 modules into a bundle.js
     // using browserify.
-    var entries = ['typings/index.d.ts'];
-    var deps = {};
+    const entries = ['typings/index.d.ts'];
+    const deps = {};
     ES6_COMPONENTS.forEach(function(component) {
       // Collect all the typescript files across the components.
       entries = entries.concat(glob(
@@ -79,7 +79,7 @@ module.exports = function(includeDeps) {
 
     // Compile components that are using global namespaces producing 1 js file
     // for each ts file.
-    var isComponent = filter([
+    const isComponent = filter([
       'components/tf_*/**/*.ts', 'components/vz_*/**/*.ts', 'typings/**/*.ts',
       'components/plottable/plottable.d.ts'
       // Ignore components that use es6 modules.
diff --git a/tensorflow/tensorboard/gulp_tasks/test.js b/tensorflow/tensorboard/gulp_tasks/test.js
index ffa8122c7b5..0c8b14a4cda 100644
--- a/tensorflow/tensorboard/gulp_tasks/test.js
+++ b/tensorflow/tensorboard/gulp_tasks/test.js
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-var gulp = require('gulp');
-var tester = require('web-component-tester').test;
+const gulp = require('gulp');
+const tester = require('web-component-tester').test;
 
 module.exports = function(done) {
   tester({}, function(error) {
diff --git a/tensorflow/tensorboard/gulp_tasks/util.js b/tensorflow/tensorboard/gulp_tasks/util.js
index 7a1d2a58ab6..0d73f69c73a 100644
--- a/tensorflow/tensorboard/gulp_tasks/util.js
+++ b/tensorflow/tensorboard/gulp_tasks/util.js
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-var fs = require('fs');
-var path = require('path');
+const fs = require('fs');
+const path = require('path');
 
 /**
  * Returns a list of web components inside the components directory for which
@@ -34,6 +34,6 @@ exports.getComponents = function(namePredicate) {
  * directory.
  */
 exports.tbComponents = exports.getComponents(function(name) {
-  var prefix = name.slice(0, 3);
+  const prefix = name.slice(0, 3);
   return prefix == 'tf_' || prefix == 'vz_';
 });
diff --git a/tensorflow/tensorboard/gulp_tasks/vulcanize.js b/tensorflow/tensorboard/gulp_tasks/vulcanize.js
index 89700e1d4cc..d2286f1d6c5 100644
--- a/tensorflow/tensorboard/gulp_tasks/vulcanize.js
+++ b/tensorflow/tensorboard/gulp_tasks/vulcanize.js
@@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-var gulp = require('gulp');
-var path = require('path');
-var util = require('./util');
-var vulcanize = require('gulp-vulcanize');
-var replace = require('gulp-replace');
-var rename = require('gulp-rename');
-var header = require('gulp-header');
+const gulp = require('gulp');
+const path = require('path');
+const util = require('./util');
+const vulcanize = require('gulp-vulcanize');
+const replace = require('gulp-replace');
+const rename = require('gulp-rename');
+const header = require('gulp-header');
 
-var HEADER_STR = '<!-- Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n\
+const HEADER_STR =
+    '<!-- Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n\
 \n\
 Licensed under the Apache License, Version 2.0 (the "License");\n\
 you may not use this file except in compliance with the License.\n\
@@ -40,16 +41,16 @@ This file is generated by `gulp` & `vulcanize`. Do not directly change it.\n\
 Instead, use `gulp regenerate` to create a new version with your changes.\n\
 -->\n\n'
 
-var base = path.join(__dirname, '../components');
+const base = path.join(__dirname, '../components');
 // List of redirects of the form path1|path2 for every tensorboard component
 // in order to replace dashes with underscores.
 // E.g. .../tf-tensorboard|.../tf_tensorboard
-var redirects = util.tbComponents.map(function(dir) {
+const redirects = util.tbComponents.map(function(dir) {
   return path.join(base, dir.replace(/_/g, '-')) + '|' + path.join(base, dir);
 });
 
-var nonTBComponents = util.getComponents(function(name) {
-  var prefix = name.slice(0, 3);
+const nonTBComponents = util.getComponents(function(name) {
+  const prefix = name.slice(0, 3);
   return prefix !== 'tf_'  && prefix !== 'vz_';
 });
 
@@ -65,7 +66,7 @@ nonTBComponents.push('/tf-imports/plottable.js');
 
 module.exports = function(overwrite) {
   return function() {
-    var suffix = overwrite ? '' : '.OPENSOURCE';
+    const suffix = overwrite ? '' : '.OPENSOURCE';
     // Vulcanize TensorBoard without external libraries.
     gulp.src('components/tf_tensorboard/tf-tensorboard.html')
         .pipe(vulcanize({
diff --git a/tensorflow/tensorboard/gulpfile.js b/tensorflow/tensorboard/gulpfile.js
index 257ee0ab83d..c03c4faebcc 100644
--- a/tensorflow/tensorboard/gulpfile.js
+++ b/tensorflow/tensorboard/gulpfile.js
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-var gulp = require('gulp');
-var server = require('gulp-server-livereload');
-var minimist = require('minimist');
-var util = require('./gulp_tasks/util');
+const gulp = require('gulp');
+const server = require('gulp-server-livereload');
+const minimist = require('minimist');
+const util = require('./gulp_tasks/util');
 
-var options = minimist(process.argv.slice(2), {
+const options = minimist(process.argv.slice(2), {
   default: {
-    p: 8000,  // port for gulp server
-    h: '0.0.0.0', // host to serve on
+    p: 8000,       // port for gulp server
+    h: '0.0.0.0',  // host to serve on
   }
 });
 
@@ -43,8 +43,8 @@ gulp.task('watch', [], function() {
       {ignoreInitial: true}, ['compile']);
 });
 
-var httpPrefix = 'http://' + options.h + ':' + options.p + '/components';
-var proxies = util.tbComponents.map(function(component) {
+const httpPrefix = 'http://' + options.h + ':' + options.p + '/components';
+const proxies = util.tbComponents.map(function(component) {
   return {
     source: '/components' + component.replace(/_/g, '-'),
     target: httpPrefix + component
@@ -84,7 +84,7 @@ gulp.task(
 gulp.task('default', ['watch', 'server']);
 
 // Clean all compiled JS files.
-var cleanCompiledTypeScript = require('gulp-clean-compiled-typescript');
+const cleanCompiledTypeScript = require('gulp-clean-compiled-typescript');
 gulp.task('clean', function () {
   return gulp.src(['./components/**/*.ts', '!./components/**/deps.d.ts'])
       .pipe(cleanCompiledTypeScript());
diff --git a/tensorflow/tensorboard/package.json b/tensorflow/tensorboard/package.json
index 5dcf2f21e97..69f08495a30 100644
--- a/tensorflow/tensorboard/package.json
+++ b/tensorflow/tensorboard/package.json
@@ -4,7 +4,7 @@
   "description": "Visualizers for TensorFlow",
   "scripts": {
     "test": "gulp test",
-    "prepare": "npm install && bower install && typings install",
+    "prep": "npm install && bower install && typings install",
     "compile": "gulp compile"
   },
   "keywords": [
diff --git a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt
index e7e36e2bb35..5c77b3dd5cc 100644
--- a/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-fixed-length-record-reader.pbtxt
@@ -13,7 +13,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'record_bytes\', \'header_bytes\', \'footer_bytes\', \'hop_bytes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "num_records_produced"
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 7bf7fd5719b..bfac54c6019 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -71,8 +71,8 @@ ENV BAZEL_VERSION 0.4.5
 WORKDIR /
 RUN mkdir /bazel && \
     cd /bazel && \
-    curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
-    curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \
+    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
     chmod +x bazel-*.sh && \
     ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
     cd / && \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index 769731974a2..7726cbdfbf8 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -71,8 +71,8 @@ ENV BAZEL_VERSION 0.4.5
 WORKDIR /
 RUN mkdir /bazel && \
     cd /bazel && \
-    curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
-    curl -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE.txt && \
+    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
+    curl -H "User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/57.0.2987.133 Safari/537.36" -fSsL -o /bazel/LICENSE.txt https://raw.githubusercontent.com/bazelbuild/bazel/master/LICENSE && \
     chmod +x bazel-*.sh && \
     ./bazel-$BAZEL_VERSION-installer-linux-x86_64.sh && \
     cd / && \
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index e20d74fd4a2..3ee99d5d315 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -29,7 +29,7 @@ from setuptools.dist import Distribution
 # This version string is semver compatible, but incompatible with pip.
 # For pip, we will remove all '-' characters from this string, and use the
 # result for pip.
-_VERSION = '1.1.0-rc1'
+_VERSION = '1.1.0-rc2'
 
 REQUIRED_PACKAGES = [
     'numpy >= 1.11.0',
diff --git a/third_party/jemalloc.BUILD b/third_party/jemalloc.BUILD
index aabff39d7b2..8ed13c51a5d 100644
--- a/third_party/jemalloc.BUILD
+++ b/third_party/jemalloc.BUILD
@@ -89,6 +89,14 @@ cc_library(
         "-D_REENTRANT",
     ],
     includes = ["include"],
+    # pthread_atfork() is called for PPC.
+    linkopts = select({
+        "@%ws%//tensorflow:linux_ppc64le": [
+            "-lpthread",
+        ],
+        "//conditions:default": [
+        ],
+    }),
     visibility = ["//visibility:public"],
 )
 
@@ -183,12 +191,17 @@ sh_binary(
     srcs = ["include/jemalloc/internal/size_classes.sh"],
 )
 
-# Size classes for Linux x86_64. Update if adding builds for other
+# Size classes for Linux x86_64 and ppc64le. Update if adding builds for other
 # architectures. See size_classes.sh for details on the arguments.
+# For default case, kept the arguments same as that of  x86_64 for now.
 genrule(
     name = "size_classes_h",
     outs = ["include/jemalloc/internal/size_classes.h"],
-    cmd = "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+    cmd = select({
+        "@%ws%//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@",
+        "@%ws%//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+        "//conditions:default": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+    }),
     tools = [":size_classes_sh"],
 )
 
@@ -210,7 +223,13 @@ template_rule(
         "#undef JEMALLOC_PREFIX": "#define JEMALLOC_PREFIX \"jemalloc_\"",
         "#undef JEMALLOC_CPREFIX": "#define JEMALLOC_CPREFIX \"JEMALLOC_\"",
         "#undef JEMALLOC_PRIVATE_NAMESPACE": "#define JEMALLOC_PRIVATE_NAMESPACE je_",
-        "#undef CPU_SPINWAIT": "#define CPU_SPINWAIT __asm__ volatile(\"pause\")",
+        "#undef CPU_SPINWAIT": "\n".join([
+            "#if defined(__powerpc64__) || defined(__powerpc__)",
+            "#define CPU_SPINWAIT __asm__ volatile(\"or 27,27,27\")",
+            "#else",
+            "#define CPU_SPINWAIT __asm__ volatile(\"pause\")",
+            "#endif",
+        ]),
         "#undef JEMALLOC_HAVE_BUILTIN_CLZ": "#define JEMALLOC_HAVE_BUILTIN_CLZ",
         "#undef JEMALLOC_USE_SYSCALL": "#define JEMALLOC_USE_SYSCALL",
         "#undef JEMALLOC_HAVE_SECURE_GETENV": "#define JEMALLOC_HAVE_SECURE_GETENV",
@@ -226,7 +245,13 @@ template_rule(
         "#undef JEMALLOC_DSS": "#define JEMALLOC_DSS",
         "#undef JEMALLOC_FILL": "#define JEMALLOC_FILL",
         "#undef LG_TINY_MIN": "#define LG_TINY_MIN 3",
-        "#undef LG_PAGE": "#define LG_PAGE 12",
+        "#undef LG_PAGE": "\n".join([
+            "#if defined(__powerpc64__) || defined(__powerpc__)",
+            "#define LG_PAGE 16",
+            "#else",
+            "#define LG_PAGE 12",
+            "#endif",
+        ]),
         "#undef JEMALLOC_MAPS_COALESCE": "#define JEMALLOC_MAPS_COALESCE",
         "#undef JEMALLOC_TLS": "#define JEMALLOC_TLS",
         "#undef JEMALLOC_INTERNAL_UNREACHABLE": "#define JEMALLOC_INTERNAL_UNREACHABLE __builtin_unreachable",
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 819807220b1..d5ab3262835 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -152,6 +152,11 @@ all_cmake_vars = select({
         cmake_vars + llvm_target_cmake_vars("X86", "x86_64-apple-darwin") +
         darwin_cmake_vars,
     ),
+    "@%ws%//tensorflow:linux_ppc64le": cmake_var_string(
+        cmake_vars +
+        llvm_target_cmake_vars("PowerPC", "powerpc64le-unknown-linux_gnu") +
+        linux_cmake_vars,
+    ),
     "//conditions:default": cmake_var_string(
         cmake_vars +
         llvm_target_cmake_vars("X86", "x86_64-unknown-linux_gnu") +