diff --git a/configure b/configure
index fa07b5a4c38..030300e23a2 100755
--- a/configure
+++ b/configure
@@ -76,6 +76,70 @@ CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
 CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
 EOF
 
+function UnofficialSetting() {
+  echo -e "\nWARNING: You are configuring unofficial settings in TensorFlow. Because some external libraries are not backward compatible, these settings are largely untested and unsupported. \n"
+
+  # Configure the compute capabilities that TensorFlow builds for.
+  # Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
+  while true; do
+    fromuser=""
+    if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
+cat << EOF
+Please specify a list of comma-separated Cuda compute capabilities you want to build with.
+You can find the compute capability of your device at: https://developer.nvidia.com/cuda-gpus.
+Please note that each additional compute capability significantly increases your build time and binary size.
+EOF
+      read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES
+      fromuser=1
+    fi
+    # Check whether all capabilities from the input is valid
+    COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
+    ALL_VALID=1
+    for CAPABILITY in $COMPUTE_CAPABILITIES; do
+      if [[ ! "$CAPABILITY" =~ [0-9]+.[0-9]+ ]]; then
+        echo "Invalid compute capability: " $CAPABILITY
+        ALL_VALID=0
+        break
+      fi
+    done
+    if [ "$ALL_VALID" == "0" ]; then
+      if [ -z "$fromuser" ]; then
+        exit 1
+      fi
+    else
+      break
+    fi
+    TF_CUDA_COMPUTE_CAPABILITIES=""
+  done
+
+  if [ ! -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
+    export WARNING="Unofficial setting. DO NOT"" SUBMIT!!!"
+    function CudaGenCodeOpts() {
+      OUTPUT=""
+      for CAPABILITY in $@; do
+        OUTPUT=${OUTPUT}"   \"${CAPABILITY}\",     "
+      done
+      echo $OUTPUT
+    }
+    export CUDA_GEN_CODES_OPTS=$(CudaGenCodeOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
+    perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\[).*?(\]),\n\1# $ENV{WARNING}\n\1\2$ENV{CUDA_GEN_CODES_OPTS}\3,s' third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
+    function CudaVersionOpts() {
+      OUTPUT=""
+      for CAPABILITY in $@; do
+        OUTPUT=$OUTPUT"CudaVersion(\"${CAPABILITY}\"), "
+      done
+      echo $OUTPUT
+    }
+    export CUDA_VERSION_OPTS=$(CudaVersionOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
+    perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\{).*?(\}),\n\1// $ENV{WARNING}\n\1\2$ENV{CUDA_VERSION_OPTS}\3,s' tensorflow/core/common_runtime/gpu/gpu_device.cc
+  fi
+}
+
+# Only run the unofficial settings when users explicitly choose to.
+if [ "$TF_UNOFFICIAL_SETTING" == "1" ]; then
+  UnofficialSetting
+fi
+
 # Invoke the cuda_config.sh and set up the TensorFlow's canonical view of the Cuda libraries
 (cd third_party/gpus/cuda; ./cuda_config.sh;) || exit -1
 
diff --git a/six.BUILD b/six.BUILD
index 0a507257bf1..5047a452e41 100644
--- a/six.BUILD
+++ b/six.BUILD
@@ -9,4 +9,5 @@ py_library(
     name = "six",
     srcs = ["six.py"],
     visibility = ["//visibility:public"],
+    srcs_version = "PY2AND3",
 )
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 7f2473f93b1..2d5a63ac925 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -294,6 +294,31 @@ Status ExecutorImpl::InferAllocAttr(
     const DeviceNameUtils::ParsedName& local_dev_name,
     AllocatorAttributes* attr) {
   Status s;
+  // Note that it's possible for *n to be a Recv and *dst to be a Send,
+  // so these two cases are not mutually exclusive.
+  if (IsRecv(n)) {
+    string src_name;
+    s = GetNodeAttr(n->def(), "send_device", &src_name);
+    if (!s.ok()) return s;
+    DeviceNameUtils::ParsedName parsed_src_name;
+    if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
+      s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
+                           n->name());
+      return s;
+    }
+    if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
+      // Value is going to be the sink of an RPC.
+      attr->set_nic_compatible(true);
+      VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
+    } else if (local_dev_name.type == "CPU" && parsed_src_name.type == "GPU") {
+      // Value is going to be the sink of a local DMA from GPU to CPU.
+      attr->set_gpu_compatible(true);
+      VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
+    } else {
+      VLOG(2) << "default alloc case local type " << local_dev_name.type
+              << " remote type " << parsed_src_name.type;
+    }
+  }
   if (IsSend(dst)) {
     string dst_name;
     s = GetNodeAttr(dst->def(), "recv_device", &dst_name);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 65174135d8a..b6bae7c0f8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -8,6 +8,7 @@
 
 #include <stdlib.h>
 #include <string.h>
+#include <algorithm>
 
 //#include "base/commandlineflags.h"
 #include "tensorflow/stream_executor/cuda/cuda_activation.h"
@@ -590,10 +591,50 @@ static int GetMinGPUMultiprocessorCount() {
   return kDefaultMinGPUMultiprocessorCount;
 }
 
+namespace {
+
+struct CudaVersion {
+  // Initialize from version_name in the form of "3.5"
+  explicit CudaVersion(const std::string& version_name) {
+    size_t dot_pos = version_name.find('.');
+    CHECK(dot_pos != string::npos);
+    string major_str = version_name.substr(0, dot_pos);
+    CHECK(strings::safe_strto32(major_str.c_str(), &major_part));
+    string minor_str = version_name.substr(dot_pos + 1);
+    CHECK(strings::safe_strto32(minor_str.c_str(), &minor_part));
+  }
+  CudaVersion() {}
+  bool operator<(const CudaVersion& other) const {
+    if (this->major_part != other.major_part) {
+      return this->major_part < other.major_part;
+    }
+    return this->minor_part < other.minor_part;
+  }
+  friend std::ostream& operator<<(std::ostream& os,
+                                  const CudaVersion& version) {
+    os << version.major_part << "." << version.minor_part;
+    return os;
+  }
+  int major_part = -1;
+  int minor_part = -1;
+};
+
+// "configure" uses the specific name to substitute the following string.
+// If you change it, make sure you modify "configure" as well.
+std::vector<CudaVersion> supported_cuda_compute_capabilities = {
+    CudaVersion("3.5"), CudaVersion("5.2")};
+
+}  // namespace
+
 void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
   auto gpu_manager = GPUMachineManager();
   int min_gpu_core_count = GetMinGPUMultiprocessorCount();
   if (gpu_manager) {
+    CHECK(!supported_cuda_compute_capabilities.empty());
+    CudaVersion min_supported_capability =
+        *std::min_element(supported_cuda_compute_capabilities.begin(),
+                          supported_cuda_compute_capabilities.end());
+
     auto visible_device_count = gpu_manager->VisibleDeviceCount();
     for (int i = 0; i < gpu_manager->VisibleDeviceCount(); ++i) {
       auto exec_status = gpu_manager->ExecutorForDevice(i);
@@ -602,17 +643,19 @@ void BaseGPUDeviceFactory::GetValidDeviceIds(std::vector<int>* ids) {
       }
       gpu::StreamExecutor* se = exec_status.ValueOrDie();
       const gpu::DeviceDescription& desc = se->GetDeviceDescription();
-      int major, minor;
-      if (!desc.cuda_compute_capability(&major, &minor)) {
+      CudaVersion device_capability;
+      if (!desc.cuda_compute_capability(&device_capability.major_part,
+                                        &device_capability.minor_part)) {
         continue;
       }
-      // Only consider GPUs with compute capability >= 3.5 (Kepler or
-      // higher)
-      if (major < 3 || (major == 3 && minor < 5)) {
+      // Only GPUs with no less than the minimum supported compute capability is
+      // accepted.
+      if (device_capability < min_supported_capability) {
         LOG(INFO) << "Ignoring gpu device "
                   << "(" << GetShortDeviceDescription(i, desc) << ") "
-                  << "with Cuda compute capability " << major << "." << minor
-                  << ". The minimum required Cuda capability is 3.5.";
+                  << "with Cuda compute capability " << device_capability
+                  << ". The minimum required Cuda capability is "
+                  << min_supported_capability << ".";
         continue;
       }
 
diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc
index 7f551ea65f9..205098d58f5 100644
--- a/tensorflow/core/framework/rendezvous.cc
+++ b/tensorflow/core/framework/rendezvous.cc
@@ -188,9 +188,9 @@ class LocalRendezvousImpl : public Rendezvous {
     // message arrives.
     Item* item = new Item;
     item->waiter = done;
+    item->recv_alloc_attrs = recv_args.alloc_attrs;
     if (recv_args.device_context) {
       item->recv_dev_context = recv_args.device_context;
-      item->recv_alloc_attrs = recv_args.alloc_attrs;
       item->recv_dev_context->Ref();
     }
     CHECK(table_.insert({key, item}).second);
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
index 8e2f108c3f0..62e15437897 100644
--- a/tensorflow/core/framework/tensor_slice.h
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -98,9 +98,10 @@ class TensorSlice {
   // We allow NDIMS to be greater than dims(), in which case we will pad the
   // higher dimensions with trivial dimensions.
   template <int NDIMS>
-  void FillIndicesAndSizes(const TensorShape& shape,
-                           Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
-                           Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
+  void FillIndicesAndSizes(
+      const TensorShape& shape,
+      Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
+      Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const;
 
   // Interaction with other TensorSlices.
 
@@ -162,8 +163,8 @@ class TensorSlice {
 
 template <int NDIMS>
 void TensorSlice::FillIndicesAndSizes(
-    const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
-    Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+    const TensorShape& shape, Eigen::DSizes<Eigen::DenseIndex, NDIMS>* indices,
+    Eigen::DSizes<Eigen::DenseIndex, NDIMS>* sizes) const {
   CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
                                  << "slices: shape = " << shape.DebugString()
                                  << ", slice = " << DebugString();
diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc
index d8ce6bd85d2..aed36dccef7 100644
--- a/tensorflow/core/kernels/concat_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -18,9 +18,9 @@ void ConcatGPU(const GPUDevice& d,
                const std::vector<
                    std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& inputs,
                typename TTypes<T, 2>::Matrix* output) {
-  Eigen::array<ptrdiff_t, 2> offset(0, 0);
+  Eigen::array<Eigen::DenseIndex, 2> offset(0, 0);
   for (int i = 0; i < inputs.size(); ++i) {
-    Eigen::array<ptrdiff_t, 2> size = inputs[i]->dimensions();
+    Eigen::array<Eigen::DenseIndex, 2> size = inputs[i]->dimensions();
     output->slice(offset, size).device(d) = *inputs[i];
     offset[1] += size[1];
   }
diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc
index 122c3c1c816..9828b460da3 100644
--- a/tensorflow/core/kernels/fifo_queue.cc
+++ b/tensorflow/core/kernels/fifo_queue.cc
@@ -17,74 +17,7 @@ namespace tensorflow {
 FIFOQueue::FIFOQueue(int capacity, const DataTypeVector& component_dtypes,
                      const std::vector<TensorShape>& component_shapes,
                      const string& name)
-    : QueueBase(component_dtypes, component_shapes, name),
-      capacity_(capacity),
-      closed_(false) {}
-
-Status FIFOQueue::Initialize() {
-  if (component_dtypes_.empty()) {
-    return errors::InvalidArgument("Empty component types for queue ", name_);
-  }
-  if (!component_shapes_.empty() &&
-      component_dtypes_.size() != component_shapes_.size()) {
-    return errors::InvalidArgument("Different number of component types (",
-                                   component_dtypes_.size(), ") vs. shapes (",
-                                   component_shapes_.size(), ").");
-  }
-
-  mutex_lock lock(mu_);
-  queues_.reserve(num_components());
-  for (int i = 0; i < num_components(); ++i) {
-    queues_.push_back(SubQueue());
-  }
-  return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-//   reduce the number of times that they are called.
-Status FIFOQueue::ValidateTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            component_shapes_[i].ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  }
-  return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-//   reduce the number of times that they are called.
-Status FIFOQueue::ValidateManyTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  const int64 batch_size = tuple[0].dim_size(0);
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      // Expected shape is [batch_size] + component_shapes_[i]
-      const TensorShape expected_shape = ManyOutShape(i, batch_size);
-      if (!tuple[i].shape().IsSameSize(expected_shape)) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            expected_shape.ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  } else {
-    for (size_t i = 1; i < tuple.size(); ++i) {
-      if (tuple[i].dim_size(0) != batch_size) {
-        return errors::InvalidArgument(
-            "All input tensors must have the same size in the 0th ",
-            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
-            ", and should have ", batch_size);
-      }
-    }
-  }
-  return Status::OK();
-}
+    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}
 
 void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   DCHECK_GT(queues_[0].size(), 0);
@@ -95,113 +28,6 @@ void FIFOQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   }
 }
 
-void FIFOQueue::Cancel(Action action, CancellationToken token) {
-  DoneCallback callback = nullptr;
-  {
-    mutex_lock lock(mu_);
-    std::deque<Attempt>* attempts =
-        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-    for (Attempt& attempt : *attempts) {
-      if (attempt.cancellation_token == token) {
-        attempt.is_cancelled = true;
-        if (action == kEnqueue) {
-          attempt.context->SetStatus(
-              errors::Cancelled("Enqueue operation was cancelled"));
-        } else {
-          attempt.context->SetStatus(
-              errors::Cancelled("Dequeue operation was cancelled"));
-        }
-        std::swap(callback, attempt.done_callback);
-        break;
-      }
-    }
-  }
-  if (callback) {
-    callback();
-    FlushUnlocked();
-  }
-}
-
-void FIFOQueue::CloseAndCancel() {
-  std::vector<DoneCallback> callbacks;
-  {
-    mutex_lock lock(mu_);
-    closed_ = true;
-    for (Attempt& attempt : enqueue_attempts_) {
-      attempt.is_cancelled = true;
-      attempt.context->SetStatus(
-          errors::Cancelled("Enqueue operation was cancelled"));
-      callbacks.emplace_back(std::move(attempt.done_callback));
-    }
-  }
-  for (const DoneCallback& callback : callbacks) {
-    callback();
-  }
-  FlushUnlocked();
-}
-
-bool FIFOQueue::TryAttemptLocked(Action action,
-                                 std::vector<CleanUp>* clean_up) {
-  std::deque<Attempt>* attempts =
-      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-  bool progress = false;
-  bool done = false;
-  while (!done && !attempts->empty()) {
-    if (attempts->front().is_cancelled) {
-      if (action == kEnqueue) {
-        LOG(INFO) << "Skipping cancelled enqueue attempt";
-      } else {
-        LOG(INFO) << "Skipping cancelled dequeue attempt";
-      }
-      attempts->pop_front();
-    } else {
-      Attempt* cur_attempt = &attempts->front();
-      switch (cur_attempt->run_callback(cur_attempt)) {
-        case kNoProgress:
-          done = true;
-          break;
-        case kProgress:
-          done = true;
-          progress = true;
-          break;
-        case kComplete:
-          progress = true;
-          clean_up->emplace_back(std::move(cur_attempt->done_callback),
-                                 cur_attempt->cancellation_token,
-                                 cur_attempt->context->cancellation_manager());
-          attempts->pop_front();
-          break;
-      }
-    }
-  }
-  return progress;
-}
-
-void FIFOQueue::FlushUnlocked() {
-  std::vector<CleanUp> clean_up;
-  Ref();
-  {
-    mutex_lock lock(mu_);
-    bool changed;
-    do {
-      changed = TryAttemptLocked(kEnqueue, &clean_up);
-      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
-    } while (changed);
-  }
-  Unref();
-  for (const auto& to_clean : clean_up) {
-    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
-      // NOTE(mrry): We can safely ignore the return value of
-      // DeregisterCallback because the mutex mu_ ensures that the
-      // cleanup action only executes once.
-      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
-    }
-    to_clean.finished();
-  }
-}
-
 void FIFOQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                            DoneCallback callback) {
   CancellationManager* cm = ctx->cancellation_manager();
@@ -484,30 +310,6 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
   }
 }
 
-void FIFOQueue::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
-                      DoneCallback callback) {
-  if (cancel_pending_enqueues) {
-    CloseAndCancel();
-    callback();
-  } else {
-    {
-      mutex_lock lock(mu_);
-      enqueue_attempts_.emplace_back(
-          0, callback, ctx, CancellationManager::kInvalidToken,
-          [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-            if (closed_) {
-              attempt->context->SetStatus(errors::Aborted(
-                  "FIFOQueue '", name_, "' is already closed."));
-            } else {
-              closed_ = true;
-            }
-            return kComplete;
-          });
-    }
-    FlushUnlocked();
-  }
-}
-
 Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
   TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "FIFOQueue"));
   TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h
index e9fe5f34a4e..4fc0ed75d2d 100644
--- a/tensorflow/core/kernels/fifo_queue.h
+++ b/tensorflow/core/kernels/fifo_queue.h
@@ -6,24 +6,21 @@
 
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/typed_queue.h"
 #include "tensorflow/core/platform/port.h"
 #include "tensorflow/core/public/tensor.h"
 #include "tensorflow/core/public/tensor_shape.h"
 
 namespace tensorflow {
 
-class FIFOQueue : public QueueBase {
+class FIFOQueue : public TypedQueue<std::deque<PersistentTensor> > {
  public:
   FIFOQueue(int32 capacity, const DataTypeVector& component_dtypes,
             const std::vector<TensorShape>& component_shapes,
             const string& name);
-  Status Initialize();  // Must be called before any other method.
 
   // Implementations of QueueInterface methods --------------------------------
 
-  Status ValidateTuple(const Tuple& tuple) override;
-  Status ValidateManyTuple(const Tuple& tuple) override;
   void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                   DoneCallback callback) override;
   void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
@@ -31,8 +28,6 @@ class FIFOQueue : public QueueBase {
   void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
   void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                       CallbackWithTuple callback) override;
-  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
-             DoneCallback callback) override;
   Status MatchesNodeDef(const NodeDef& node_def) override;
 
   int32 size() override {
@@ -40,80 +35,13 @@ class FIFOQueue : public QueueBase {
     return queues_[0].size();
   }
 
-  int32 capacity() const { return capacity_; }
-
  private:
-  enum Action { kEnqueue, kDequeue };
-
   ~FIFOQueue() override {}
 
-  TensorShape ManyOutShape(int i, int64 batch_size) {
-    TensorShape shape({batch_size});
-    shape.AppendShape(component_shapes_[i]);
-    return shape;
-  }
-
   // Helper for dequeuing a single element from queues_.
   void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  void Cancel(Action action, CancellationToken token);
-
-  // Helper for cancelling all pending Enqueue(Many) operations when
-  // Close is called with cancel_pending_enqueues.
-  void CloseAndCancel();
-
-  // Tries to enqueue/dequeue (or close) based on whatever is at the
-  // front of enqueue_attempts_/dequeue_attempts_.  Appends to
-  // *finished the callback for any finished attempt (so it may be
-  // called once mu_ is released).  Returns true if any progress was
-  // made.
-  struct CleanUp {
-    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
-        : finished(f), to_deregister(ct), cm(cm) {}
-    DoneCallback finished;
-    CancellationToken to_deregister;
-    CancellationManager* cm;
-  };
-  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Tries to make progress on the enqueues or dequeues at the front
-  // of the *_attempts_ queues.
-  void FlushUnlocked();
-
-  const int32 capacity_;
-
-  mutex mu_;
-  typedef std::deque<PersistentTensor> SubQueue;
-  std::vector<SubQueue> queues_ GUARDED_BY(mu_);
-  bool closed_ GUARDED_BY(mu_);
-
-  enum RunResult { kNoProgress, kProgress, kComplete };
-  struct Attempt;
-  typedef std::function<RunResult(Attempt*)> RunCallback;
-  struct Attempt {
-    int32 elements_requested;
-    DoneCallback done_callback;  // must be run outside mu_
-    OpKernelContext* context;
-    CancellationToken cancellation_token;
-    RunCallback run_callback;  // must be run while holding mu_
-    bool is_cancelled;
-    Tuple tuple;
-
-    Attempt(int32 elements_requested, DoneCallback done_callback,
-            OpKernelContext* context, CancellationToken cancellation_token,
-            RunCallback run_callback)
-        : elements_requested(elements_requested),
-          done_callback(done_callback),
-          context(context),
-          cancellation_token(cancellation_token),
-          run_callback(run_callback),
-          is_cancelled(false) {}
-  };
-  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
-  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
-
   static Status GetElementComponentFromBatch(const Tuple& tuple, int index,
                                              int component,
                                              OpKernelContext* ctx,
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index 4b67304a371..bb2657085e0 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -23,8 +23,8 @@ static void GetBandMatrix(int depth, int64 depth_radius,
   for (int row = 0; row < depth; ++row) {
     const int begin = std::max<int>(0, row - depth_radius);
     const int end = std::min<int64>(depth, row + depth_radius + 1);
-    Eigen::DSizes<ptrdiff_t, 2> start(row, begin);
-    Eigen::DSizes<ptrdiff_t, 2> sizes(1, end - begin);
+    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
+    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
     result->slice(start, sizes).setConstant(1.0f);
   }
 }
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index 5bf44b6e404..d086b6850ec 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -243,7 +243,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
             std::min(wpad / params.col_stride + 1, params.out_width);
         const int in_offset =
             (b * params.tensor_in_rows + h) * params.tensor_in_cols + w;
-        Eigen::DSizes<ptrdiff_t, 2> in_indices(0, in_offset);
+        Eigen::DSizes<Eigen::DenseIndex, 2> in_indices(0, in_offset);
         for (int ph = h_start; ph < h_end; ++ph) {
           for (int pw = w_start; pw < w_end; ++pw) {
             const int out_offset =
diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc
index 4217b9ce864..d0f47505c45 100644
--- a/tensorflow/core/kernels/queue_base.cc
+++ b/tensorflow/core/kernels/queue_base.cc
@@ -46,52 +46,14 @@ Status HandleElementToSlice(const Tensor& element, Tensor* parent, int index) {
 
 }  // namespace
 
-// static
-Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
-                                     int index) {
-#define HANDLE_TYPE(DT)                                                   \
-  if (parent.dtype() == DT) {                                             \
-    TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
-    return Status::OK();                                                  \
-  }
-  HANDLE_TYPE(DT_FLOAT);
-  HANDLE_TYPE(DT_DOUBLE);
-  HANDLE_TYPE(DT_INT32);
-  HANDLE_TYPE(DT_UINT8);
-  HANDLE_TYPE(DT_INT16);
-  HANDLE_TYPE(DT_INT8);
-  HANDLE_TYPE(DT_STRING);
-  HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
-  return errors::Unimplemented("Unhandled data type: ", parent.dtype());
-}
-
-// static
-Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
-                                     int index) {
-#define HANDLE_TYPE(DT)                                                   \
-  if (element.dtype() == DT) {                                            \
-    TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
-    return Status::OK();                                                  \
-  }
-  HANDLE_TYPE(DT_FLOAT);
-  HANDLE_TYPE(DT_DOUBLE);
-  HANDLE_TYPE(DT_INT32);
-  HANDLE_TYPE(DT_UINT8);
-  HANDLE_TYPE(DT_INT16);
-  HANDLE_TYPE(DT_INT8);
-  HANDLE_TYPE(DT_STRING);
-  HANDLE_TYPE(DT_INT64);
-#undef HANDLE_TYPE
-  return errors::Unimplemented("Unhandled data type: ", element.dtype());
-}
-
-QueueBase::QueueBase(const DataTypeVector& component_dtypes,
+QueueBase::QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
                      const std::vector<TensorShape>& component_shapes,
                      const string& name)
-    : component_dtypes_(component_dtypes),
+    : capacity_(capacity),
+      component_dtypes_(component_dtypes),
       component_shapes_(component_shapes),
-      name_(name) {}
+      name_(name),
+      closed_(false) {}
 
 Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const {
   if (tuple.size() != static_cast<size_t>(num_components())) {
@@ -172,4 +134,221 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const {
   return Status::OK();
 }
 
+// TODO(mrry): If these checks become a bottleneck, find a way to
+//   reduce the number of times that they are called.
+Status QueueBase::ValidateTuple(const Tuple& tuple) {
+  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+  if (specified_shapes()) {
+    for (size_t i = 0; i < tuple.size(); ++i) {
+      if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
+        return errors::InvalidArgument(
+            "Shape mismatch in tuple component ", i, ". Expected ",
+            component_shapes_[i].ShortDebugString(), ", got ",
+            tuple[i].shape().ShortDebugString());
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// TODO(mrry): If these checks become a bottleneck, find a way to
+//   reduce the number of times that they are called.
+Status QueueBase::ValidateManyTuple(const Tuple& tuple) {
+  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
+  const int64 batch_size = tuple[0].dim_size(0);
+  if (specified_shapes()) {
+    for (size_t i = 0; i < tuple.size(); ++i) {
+      // Expected shape is [batch_size] + component_shapes_[i]
+      const TensorShape expected_shape = ManyOutShape(i, batch_size);
+      if (!tuple[i].shape().IsSameSize(expected_shape)) {
+        return errors::InvalidArgument(
+            "Shape mismatch in tuple component ", i, ". Expected ",
+            expected_shape.ShortDebugString(), ", got ",
+            tuple[i].shape().ShortDebugString());
+      }
+    }
+  } else {
+    for (size_t i = 1; i < tuple.size(); ++i) {
+      if (tuple[i].dim_size(0) != batch_size) {
+        return errors::InvalidArgument(
+            "All input tensors must have the same size in the 0th ",
+            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
+            ", and should have ", batch_size);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+void QueueBase::Cancel(Action action, CancellationToken token) {
+  DoneCallback callback = nullptr;
+  {
+    mutex_lock lock(mu_);
+    std::deque<Attempt>* attempts =
+        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+    for (Attempt& attempt : *attempts) {
+      if (attempt.cancellation_token == token) {
+        attempt.is_cancelled = true;
+        if (action == kEnqueue) {
+          attempt.context->SetStatus(
+              errors::Cancelled("Enqueue operation was cancelled"));
+        } else {
+          attempt.context->SetStatus(
+              errors::Cancelled("Dequeue operation was cancelled"));
+        }
+        std::swap(callback, attempt.done_callback);
+        break;
+      }
+    }
+  }
+  if (callback) {
+    callback();
+    FlushUnlocked();
+  }
+}
+
+void QueueBase::CloseAndCancel() {
+  std::vector<DoneCallback> callbacks;
+  {
+    mutex_lock lock(mu_);
+    closed_ = true;
+    for (Attempt& attempt : enqueue_attempts_) {
+      attempt.is_cancelled = true;
+      attempt.context->SetStatus(
+          errors::Cancelled("Enqueue operation was cancelled"));
+      callbacks.emplace_back(std::move(attempt.done_callback));
+    }
+  }
+  for (const DoneCallback& callback : callbacks) {
+    callback();
+  }
+  FlushUnlocked();
+}
+
+void QueueBase::Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+                      DoneCallback callback) {
+  if (cancel_pending_enqueues) {
+    CloseAndCancel();
+    callback();
+  } else {
+    {
+      mutex_lock lock(mu_);
+      enqueue_attempts_.emplace_back(
+          0, callback, ctx, CancellationManager::kInvalidToken,
+          [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+            if (closed_) {
+              attempt->context->SetStatus(
+                  errors::Aborted("Queue '", name_, "' is already closed."));
+            } else {
+              closed_ = true;
+            }
+            return kComplete;
+          });
+    }
+    FlushUnlocked();
+  }
+}
+
+bool QueueBase::TryAttemptLocked(Action action,
+                                 std::vector<CleanUp>* clean_up) {
+  std::deque<Attempt>* attempts =
+      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
+
+  bool progress = false;
+  bool done = false;
+  while (!done && !attempts->empty()) {
+    if (attempts->front().is_cancelled) {
+      if (action == kEnqueue) {
+        LOG(INFO) << "Skipping cancelled enqueue attempt";
+      } else {
+        LOG(INFO) << "Skipping cancelled dequeue attempt";
+      }
+      attempts->pop_front();
+    } else {
+      Attempt* cur_attempt = &attempts->front();
+      switch (cur_attempt->run_callback(cur_attempt)) {
+        case kNoProgress:
+          done = true;
+          break;
+        case kProgress:
+          done = true;
+          progress = true;
+          break;
+        case kComplete:
+          progress = true;
+          clean_up->emplace_back(std::move(cur_attempt->done_callback),
+                                 cur_attempt->cancellation_token,
+                                 cur_attempt->context->cancellation_manager());
+          attempts->pop_front();
+          break;
+      }
+    }
+  }
+  return progress;
+}
+
+void QueueBase::FlushUnlocked() {
+  std::vector<CleanUp> clean_up;
+  Ref();
+  {
+    mutex_lock lock(mu_);
+    bool changed;
+    do {
+      changed = TryAttemptLocked(kEnqueue, &clean_up);
+      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
+    } while (changed);
+  }
+  Unref();
+  for (const auto& to_clean : clean_up) {
+    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
+      // NOTE(mrry): We can safely ignore the return value of
+      // DeregisterCallback because the mutex mu_ ensures that the
+      // cleanup action only executes once.
+      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
+    }
+    to_clean.finished();
+  }
+}
+
+// Static method
+Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element,
+                                     int index) {
+#define HANDLE_TYPE(DT)                                                   \
+  if (parent.dtype() == DT) {                                             \
+    TF_RETURN_IF_ERROR(HandleSliceToElement<DT>(parent, element, index)); \
+    return Status::OK();                                                  \
+  }
+  HANDLE_TYPE(DT_FLOAT);
+  HANDLE_TYPE(DT_DOUBLE);
+  HANDLE_TYPE(DT_INT32);
+  HANDLE_TYPE(DT_UINT8);
+  HANDLE_TYPE(DT_INT16);
+  HANDLE_TYPE(DT_INT8);
+  HANDLE_TYPE(DT_STRING);
+  HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+  return errors::Unimplemented("Unhandled data type: ", parent.dtype());
+}
+
+// Static method
+Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent,
+                                     int index) {
+#define HANDLE_TYPE(DT)                                                   \
+  if (element.dtype() == DT) {                                            \
+    TF_RETURN_IF_ERROR(HandleElementToSlice<DT>(element, parent, index)); \
+    return Status::OK();                                                  \
+  }
+  HANDLE_TYPE(DT_FLOAT);
+  HANDLE_TYPE(DT_DOUBLE);
+  HANDLE_TYPE(DT_INT32);
+  HANDLE_TYPE(DT_UINT8);
+  HANDLE_TYPE(DT_INT16);
+  HANDLE_TYPE(DT_INT8);
+  HANDLE_TYPE(DT_STRING);
+  HANDLE_TYPE(DT_INT64);
+#undef HANDLE_TYPE
+  return errors::Unimplemented("Unhandled data type: ", element.dtype());
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 4897102974a..d32d98b7eb9 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -1,6 +1,9 @@
 #ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
 #define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
 
+#include <deque>
+#include <vector>
+
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/queue_interface.h"
 #include "tensorflow/core/framework/types.h"
@@ -11,7 +14,7 @@
 
 namespace tensorflow {
 
-// Functionality common to QueueInterface implementations.
+// Functionality common to asynchronous QueueInterface implementations.
 class QueueBase : public QueueInterface {
  public:
   // As a possible value of 'capacity'.
@@ -23,7 +26,7 @@ class QueueBase : public QueueInterface {
   //     which must either be empty (if the shapes are not specified) or
   //     or have the same size as component_dtypes.
   //   name: A name to use for the queue.
-  QueueBase(const DataTypeVector& component_dtypes,
+  QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
             const std::vector<TensorShape>& component_shapes,
             const string& name);
 
@@ -32,12 +35,36 @@ class QueueBase : public QueueInterface {
     return component_dtypes_;
   }
 
+  Status ValidateTuple(const Tuple& tuple) override;
+  Status ValidateManyTuple(const Tuple& tuple) override;
+
+  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
+             DoneCallback callback) override;
+
   // Other public methods -----------------------------------------------------
   const std::vector<TensorShape>& component_shapes() const {
     return component_shapes_;
   }
 
+  int32 capacity() const { return capacity_; }
+
  protected:
+  enum Action { kEnqueue, kDequeue };
+  enum RunResult { kNoProgress, kProgress, kComplete };
+
+  // Tries to enqueue/dequeue (or close) based on whatever is at the
+  // front of enqueue_attempts_/dequeue_attempts_.  Appends to
+  // *finished the callback for any finished attempt (so it may be
+  // called once mu_ is released).  Returns true if any progress was
+  // made.
+  struct CleanUp {
+    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
+        : finished(f), to_deregister(ct), cm(cm) {}
+    DoneCallback finished;
+    CancellationToken to_deregister;
+    CancellationManager* cm;
+  };
+
   // Returns the number of components in a queue-element tuple.
   int32 num_components() const { return component_dtypes_.size(); }
 
@@ -48,6 +75,12 @@ class QueueBase : public QueueInterface {
   // Code common to Validate*Tuple().
   Status ValidateTupleCommon(const Tuple& tuple) const;
 
+  TensorShape ManyOutShape(int i, int64 batch_size) {
+    TensorShape shape({batch_size});
+    shape.AppendShape(component_shapes_[i]);
+    return shape;
+  }
+
   // Copies the index^th slice (in the first dimension) of parent into element.
   static Status CopySliceToElement(const Tensor& parent, Tensor* element,
                                    int index);
@@ -56,6 +89,19 @@ class QueueBase : public QueueInterface {
   static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
                                    int index);
 
+  void Cancel(Action action, CancellationToken token);
+
+  // Helper for cancelling all pending Enqueue(Many) operations when
+  // Close is called with cancel_pending_enqueues.
+  void CloseAndCancel();
+
+  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+  // Tries to make progress on the enqueues or dequeues at the front
+  // of the *_attempts_ queues.
+  void FlushUnlocked();
+
   ~QueueBase() override {}
 
   // Helpers for implementing MatchesNodeDef().
@@ -65,9 +111,37 @@ class QueueBase : public QueueInterface {
   Status MatchesNodeDefTypes(const NodeDef& node_def) const;
   Status MatchesNodeDefShapes(const NodeDef& node_def) const;
 
+ protected:
+  const int32 capacity_;
   const DataTypeVector component_dtypes_;
   const std::vector<TensorShape> component_shapes_;
   const string name_;
+  mutex mu_;
+  bool closed_ GUARDED_BY(mu_);
+
+  struct Attempt;
+  typedef std::function<RunResult(Attempt*)> RunCallback;
+  struct Attempt {
+    int32 elements_requested;
+    DoneCallback done_callback;  // must be run outside mu_
+    OpKernelContext* context;
+    CancellationToken cancellation_token;
+    RunCallback run_callback;  // must be run while holding mu_
+    bool is_cancelled;
+    Tuple tuple;
+
+    Attempt(int32 elements_requested, DoneCallback done_callback,
+            OpKernelContext* context, CancellationToken cancellation_token,
+            RunCallback run_callback)
+        : elements_requested(elements_requested),
+          done_callback(done_callback),
+          context(context),
+          cancellation_token(cancellation_token),
+          run_callback(run_callback),
+          is_cancelled(false) {}
+  };
+  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
+  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
 };
diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc
index 561ec76e53c..0723e4fc617 100644
--- a/tensorflow/core/kernels/random_shuffle_queue_op.cc
+++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc
@@ -6,7 +6,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/queue_base.h"
+#include "tensorflow/core/kernels/typed_queue.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/random/philox_random.h"
 #include "tensorflow/core/lib/random/random.h"
@@ -19,18 +19,16 @@
 
 namespace tensorflow {
 
-class RandomShuffleQueue : public QueueBase {
+class RandomShuffleQueue : public TypedQueue<std::vector<PersistentTensor> > {
  public:
   RandomShuffleQueue(int32 capacity, int32 min_after_dequeue, int64 seed,
                      int64 seed2, const DataTypeVector& component_dtypes,
                      const std::vector<TensorShape>& component_shapes,
                      const string& name);
-  Status Initialize();  // Must be called before any other method.
+
+  Status Initialize() override;  // Must be called before any other method.
 
   // Implementations of QueueInterface methods --------------------------------
-
-  Status ValidateTuple(const Tuple& tuple) override;
-  Status ValidateManyTuple(const Tuple& tuple) override;
   void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                   DoneCallback callback) override;
   void TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
@@ -38,8 +36,6 @@ class RandomShuffleQueue : public QueueBase {
   void TryDequeue(OpKernelContext* ctx, CallbackWithTuple callback) override;
   void TryDequeueMany(int num_elements, OpKernelContext* ctx,
                       CallbackWithTuple callback) override;
-  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
-             DoneCallback callback) override;
   Status MatchesNodeDef(const NodeDef& node_def) override;
 
   int32 size() override {
@@ -48,95 +44,30 @@ class RandomShuffleQueue : public QueueBase {
   }
 
  private:
-  enum Action { kEnqueue, kDequeue };
-
   ~RandomShuffleQueue() override {}
 
-  TensorShape ManyOutShape(int i, int batch_size) {
-    TensorShape shape({batch_size});
-    shape.AppendShape(component_shapes_[i]);
-    return shape;
-  }
-
   // Helper for dequeuing a single random element from queues_.
   void DequeueLocked(OpKernelContext* ctx, Tuple* tuple)
       EXCLUSIVE_LOCKS_REQUIRED(mu_);
 
-  void Cancel(Action action, CancellationToken token);
-
-  // Helper for cancelling all pending Enqueue(Many) operations when
-  // Close is called with cancel_pending_enqueues.
-  void CloseAndCancel();
-
-  // Tries to enqueue/dequeue (or close) based on whatever is at the
-  // front of enqueue_attempts_/dequeue_attempts_.  Appends to
-  // *finished the callback for any finished attempt (so it may be
-  // called once mu_ is released).  Returns true if any progress was
-  // made.
-  struct CleanUp {
-    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
-        : finished(f), to_deregister(ct), cm(cm) {}
-    DoneCallback finished;
-    CancellationToken to_deregister;
-    CancellationManager* cm;
-  };
-  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
-      EXCLUSIVE_LOCKS_REQUIRED(mu_);
-
-  // Tries to make progress on the enqueues or dequeues at the front
-  // of the *_attempts_ queues.
-  void FlushUnlocked();
-
-  const int32 capacity_;
   const int32 min_after_dequeue_;
   const int64 original_seed_;
   const int64 original_seed2_;
 
-  mutex mu_;
-  typedef std::vector<PersistentTensor> SubQueue;
-  std::vector<SubQueue> queues_ GUARDED_BY(mu_);
-  bool closed_ GUARDED_BY(mu_);
   random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
   random::SingleSampleAdapter<random::PhiloxRandom> generator_ GUARDED_BY(mu_);
 
-  enum RunResult { kNoProgress, kProgress, kComplete };
-  struct Attempt;
-  typedef std::function<RunResult(Attempt*)> RunCallback;
-  struct Attempt {
-    int32 elements_requested;
-    DoneCallback done_callback;  // must be run outside mu_
-    OpKernelContext* context;
-    CancellationToken cancellation_token;
-    RunCallback run_callback;  // must be run while holding mu_
-    bool is_cancelled;
-    Tuple tuple;
-
-    Attempt(int32 elements_requested, DoneCallback done_callback,
-            OpKernelContext* context, CancellationToken cancellation_token,
-            RunCallback run_callback)
-        : elements_requested(elements_requested),
-          done_callback(done_callback),
-          context(context),
-          cancellation_token(cancellation_token),
-          run_callback(run_callback),
-          is_cancelled(false) {}
-  };
-  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
-  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);
-
   TF_DISALLOW_COPY_AND_ASSIGN(RandomShuffleQueue);
 };
 
 RandomShuffleQueue::RandomShuffleQueue(
-    int capacity, int min_after_dequeue, int64 seed, int64 seed2,
+    int32 capacity, int32 min_after_dequeue, int64 seed, int64 seed2,
     const DataTypeVector& component_dtypes,
     const std::vector<TensorShape>& component_shapes, const string& name)
-    : QueueBase(component_dtypes, component_shapes, name),
-      capacity_(capacity),
+    : TypedQueue(capacity, component_dtypes, component_shapes, name),
       min_after_dequeue_(min_after_dequeue),
       original_seed_(seed),
       original_seed2_(seed2),
-      closed_(false),
       generator_(&parent_generator_) {
   if (seed == 0 && seed2 == 0) {
     // If both seeds are unspecified, use completely random seeds.
@@ -147,71 +78,16 @@ RandomShuffleQueue::RandomShuffleQueue(
 }
 
 Status RandomShuffleQueue::Initialize() {
-  if (component_dtypes_.empty()) {
-    return errors::InvalidArgument("Empty component types for queue ", name_);
-  }
-  if (!component_shapes_.empty() &&
-      component_dtypes_.size() != component_shapes_.size()) {
-    return errors::InvalidArgument("Different number of component types (",
-                                   component_dtypes_.size(), ") vs. shapes (",
-                                   component_shapes_.size(), ").");
-  }
+  Status s = TypedQueue::Initialize();
+  if (!s.ok()) return s;
 
   mutex_lock lock(mu_);
-  queues_.reserve(num_components());
   for (int i = 0; i < num_components(); ++i) {
-    queues_.push_back(SubQueue());
     queues_.back().reserve(min_after_dequeue_);
   }
   return Status::OK();
 }
 
-// TODO(mrry): If these checks become a bottleneck, find a way to
-//   reduce the number of times that they are called.
-Status RandomShuffleQueue::ValidateTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      if (!tuple[i].shape().IsSameSize(component_shapes_[i])) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            component_shapes_[i].ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  }
-  return Status::OK();
-}
-
-// TODO(mrry): If these checks become a bottleneck, find a way to
-//   reduce the number of times that they are called.
-Status RandomShuffleQueue::ValidateManyTuple(const Tuple& tuple) {
-  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
-  const int64 batch_size = tuple[0].dim_size(0);
-  if (specified_shapes()) {
-    for (size_t i = 0; i < tuple.size(); ++i) {
-      // Expected shape is [batch_size] + component_shapes_[i]
-      const TensorShape expected_shape = ManyOutShape(i, batch_size);
-      if (!tuple[i].shape().IsSameSize(expected_shape)) {
-        return errors::InvalidArgument(
-            "Shape mismatch in tuple component ", i, ". Expected ",
-            expected_shape.ShortDebugString(), ", got ",
-            tuple[i].shape().ShortDebugString());
-      }
-    }
-  } else {
-    for (size_t i = 1; i < tuple.size(); ++i) {
-      if (tuple[i].dim_size(0) != batch_size) {
-        return errors::InvalidArgument(
-            "All input tensors must have the same size in the 0th ",
-            "dimension. Component ", i, " has ", tuple[i].dim_size(0),
-            ", and should have ", batch_size);
-      }
-    }
-  }
-  return Status::OK();
-}
-
 void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   DCHECK_GT(queues_[0].size(), 0);
   int64 index = generator_() % queues_[0].size();
@@ -223,113 +99,6 @@ void RandomShuffleQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
   }
 }
 
-void RandomShuffleQueue::Cancel(Action action, CancellationToken token) {
-  DoneCallback callback = nullptr;
-  {
-    mutex_lock lock(mu_);
-    std::deque<Attempt>* attempts =
-        action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-    for (Attempt& attempt : *attempts) {
-      if (attempt.cancellation_token == token) {
-        attempt.is_cancelled = true;
-        if (action == kEnqueue) {
-          attempt.context->SetStatus(
-              errors::Cancelled("Enqueue operation was cancelled"));
-        } else {
-          attempt.context->SetStatus(
-              errors::Cancelled("Dequeue operation was cancelled"));
-        }
-        std::swap(callback, attempt.done_callback);
-        break;
-      }
-    }
-  }
-  if (callback) {
-    callback();
-    FlushUnlocked();
-  }
-}
-
-void RandomShuffleQueue::CloseAndCancel() {
-  std::vector<DoneCallback> callbacks;
-  {
-    mutex_lock lock(mu_);
-    closed_ = true;
-    for (Attempt& attempt : enqueue_attempts_) {
-      attempt.is_cancelled = true;
-      attempt.context->SetStatus(
-          errors::Cancelled("Enqueue operation was cancelled"));
-      callbacks.emplace_back(std::move(attempt.done_callback));
-    }
-  }
-  for (const DoneCallback& callback : callbacks) {
-    callback();
-  }
-  FlushUnlocked();
-}
-
-bool RandomShuffleQueue::TryAttemptLocked(
-    Action action, std::vector<CleanUp>* clean_up) {
-  std::deque<Attempt>* attempts =
-      action == kEnqueue ? &enqueue_attempts_ : &dequeue_attempts_;
-
-  bool progress = false;
-  bool done = false;
-  while (!done && !attempts->empty()) {
-    if (attempts->front().is_cancelled) {
-      if (action == kEnqueue) {
-        LOG(INFO) << "Skipping cancelled enqueue attempt";
-      } else {
-        LOG(INFO) << "Skipping cancelled dequeue attempt";
-      }
-      attempts->pop_front();
-    } else {
-      Attempt* cur_attempt = &attempts->front();
-      switch (cur_attempt->run_callback(cur_attempt)) {
-        case kNoProgress:
-          done = true;
-          break;
-        case kProgress:
-          done = true;
-          progress = true;
-          break;
-        case kComplete:
-          progress = true;
-          clean_up->emplace_back(std::move(cur_attempt->done_callback),
-                                 cur_attempt->cancellation_token,
-                                 cur_attempt->context->cancellation_manager());
-          attempts->pop_front();
-          break;
-      }
-    }
-  }
-  return progress;
-}
-
-void RandomShuffleQueue::FlushUnlocked() {
-  std::vector<CleanUp> clean_up;
-  Ref();
-  {
-    mutex_lock lock(mu_);
-    bool changed;
-    do {
-      changed = TryAttemptLocked(kEnqueue, &clean_up);
-      changed = TryAttemptLocked(kDequeue, &clean_up) || changed;
-    } while (changed);
-  }
-  Unref();
-  for (const auto& to_clean : clean_up) {
-    if (to_clean.to_deregister != CancellationManager::kInvalidToken) {
-      // NOTE(mrry): We can safely ignore the return value of
-      // DeregisterCallback because the mutex mu_ ensures that the
-      // cleanup action only executes once.
-      to_clean.cm->DeregisterCallback(to_clean.to_deregister);
-    }
-    to_clean.finished();
-  }
-}
-
 void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                                     DoneCallback callback) {
   CancellationManager* cm = ctx->cancellation_manager();
@@ -583,31 +352,6 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
   }
 }
 
-void RandomShuffleQueue::Close(OpKernelContext* ctx,
-                               bool cancel_pending_enqueues,
-                               DoneCallback callback) {
-  if (cancel_pending_enqueues) {
-    CloseAndCancel();
-    callback();
-  } else {
-    {
-      mutex_lock lock(mu_);
-      enqueue_attempts_.emplace_back(
-          0, callback, ctx, CancellationManager::kInvalidToken,
-          [this](Attempt* attempt) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-            if (closed_) {
-              attempt->context->SetStatus(errors::Aborted(
-                  "RandomShuffleQueue '", name_, "' is already closed."));
-            } else {
-              closed_ = true;
-            }
-            return kComplete;
-          });
-    }
-    FlushUnlocked();
-  }
-}
-
 Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
   TF_RETURN_IF_ERROR(MatchesNodeDefOp(node_def, "RandomShuffleQueue"));
   TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
@@ -640,8 +384,6 @@ Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) {
   return Status::OK();
 }
 
-typedef std::shared_ptr<QueueInterface> QueueInterfacePtr;
-
 // Defines a RandomShuffleQueueOp, which produces a Queue (specifically, one
 // backed by RandomShuffleQueue) that persists across different graph
 // executions, and sessions. Running this op produces a single-element
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 3477266d5da..7e55149cd1e 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -171,8 +171,8 @@ class SliceOp : public OpKernel {
   template <int NDIM>
   void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
                   const gtl::ArraySlice<int64>& size, Tensor* result) {
-    Eigen::DSizes<ptrdiff_t, NDIM> indices;
-    Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
     for (int i = 0; i < NDIM; ++i) {
       indices[i] = begin[i];
       sizes[i] = size[i];
@@ -205,8 +205,8 @@ namespace functor {
   void Slice<GPUDevice, T, NDIM>::operator()(                      \
       const GPUDevice& d, typename TTypes<T, NDIM>::Tensor output, \
       typename TTypes<T, NDIM>::ConstTensor input,                 \
-      const Eigen::DSizes<ptrdiff_t, NDIM>& indices,               \
-      const Eigen::DSizes<ptrdiff_t, NDIM>& sizes);                \
+      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,       \
+      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes);        \
   extern template struct Slice<GPUDevice, T, NDIM>;
 
 #define DECLARE_FOR_N(T)  \
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
index 1b6bd9c1125..89bc8be8acc 100644
--- a/tensorflow/core/kernels/slice_op.h
+++ b/tensorflow/core/kernels/slice_op.h
@@ -13,8 +13,8 @@ template <typename Device, typename T, int NDIMS>
 struct Slice {
   void operator()(const Device& d, typename TTypes<T, NDIMS>::Tensor output,
                   typename TTypes<T, NDIMS>::ConstTensor input,
-                  const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_indices,
-                  const Eigen::DSizes<ptrdiff_t, NDIMS>& slice_sizes) {
+                  const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, NDIMS>& slice_sizes) {
     output.device(d) = input.slice(slice_indices, slice_sizes);
   }
 };
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index f4f9ada0001..e8808c1be24 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -90,17 +90,17 @@ class SplitOp : public OpKernel {
     TensorShape output_shape(input_shape);
     output_shape.set_dim(split_dim, split_dim_output_size);
 
-    Eigen::DSizes<ptrdiff_t, 3> indices{0, 0, 0};
-    Eigen::DSizes<ptrdiff_t, 3> sizes{prefix_dim_size, split_dim_output_size,
-                                      suffix_dim_size};
+    Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
+    Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+        prefix_dim_size, split_dim_output_size, suffix_dim_size};
 
     for (int i = 0; i < num_split; ++i) {
       Tensor* result = nullptr;
       OP_REQUIRES_OK(context,
                      context->allocate_output(i, output_shape, &result));
       if (prefix_dim_size * split_dim_output_size * suffix_dim_size > 0) {
-        Eigen::DSizes<ptrdiff_t, 3> slice_indices;
-        Eigen::DSizes<ptrdiff_t, 3> slice_sizes;
+        Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices;
+        Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes;
         for (int j = 0; j < 3; ++j) {
           slice_indices[j] = indices[j];
           slice_sizes[j] = sizes[j];
diff --git a/tensorflow/core/kernels/split_op.h b/tensorflow/core/kernels/split_op.h
index 2572c77285f..fb81d93a39f 100644
--- a/tensorflow/core/kernels/split_op.h
+++ b/tensorflow/core/kernels/split_op.h
@@ -12,8 +12,8 @@ template <typename Device, typename T>
 struct Split {
   void operator()(const Device& d, typename TTypes<T, 3>::Tensor output,
                   typename TTypes<T, 3>::ConstTensor input,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
 };
 
 template <typename T>
@@ -21,8 +21,8 @@ struct Split<Eigen::ThreadPoolDevice, T> {
   void operator()(const Eigen::ThreadPoolDevice& d,
                   typename TTypes<T, 3>::Tensor output,
                   typename TTypes<T, 3>::ConstTensor input,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-                  const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes);
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes);
 };
 
 }  // namespace functor
diff --git a/tensorflow/core/kernels/split_op_cpu.cc b/tensorflow/core/kernels/split_op_cpu.cc
index b86deeb8fbe..eda432b6f99 100644
--- a/tensorflow/core/kernels/split_op_cpu.cc
+++ b/tensorflow/core/kernels/split_op_cpu.cc
@@ -13,8 +13,8 @@ template <typename T>
 void Split<Eigen::ThreadPoolDevice, T>::operator()(
     const Eigen::ThreadPoolDevice& d, typename TTypes<T, 3>::Tensor output,
     typename TTypes<T, 3>::ConstTensor input,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
   if (output.size() < 131072) {
     output = input.slice(slice_indices, slice_sizes);
   } else {
diff --git a/tensorflow/core/kernels/split_op_gpu.cu.cc b/tensorflow/core/kernels/split_op_gpu.cu.cc
index f8931d6a898..d6a68bf9a5f 100644
--- a/tensorflow/core/kernels/split_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_op_gpu.cu.cc
@@ -16,8 +16,8 @@ template <typename Device, typename T>
 void Split<Device, T>::operator()(
     const Device& d, typename TTypes<T, 3>::Tensor output,
     typename TTypes<T, 3>::ConstTensor input,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_indices,
-    const Eigen::DSizes<ptrdiff_t, 3>& slice_sizes) {
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_indices,
+    const Eigen::DSizes<Eigen::DenseIndex, 3>& slice_sizes) {
   output.device(d) = input.slice(slice_indices, slice_sizes);
 }
 
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index d5e0e89d609..decc3207a15 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -273,8 +273,8 @@ class TileGradientOp : public OpKernel {
 #undef HANDLE_DIM
     }
 
-    Eigen::DSizes<ptrdiff_t, NDIM> indices;
-    Eigen::DSizes<ptrdiff_t, NDIM> sizes;
+    Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
+    Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
 
     // Accumulate slices along the dimensions into the output. The number of
     // slices along dimension 'i' is simply the multiple along dimension 'i'
@@ -309,8 +309,8 @@ class TileGradientOp : public OpKernel {
   void HandleReduce(OpKernelContext* context,
                     const std::vector<int32>& reduce_dim_in, Tensor* result) {
     static_assert(NDIM >= REDUCENDIM, "Too many reduced dimensions");
-    Eigen::DSizes<ptrdiff_t, REDUCENDIM> reduce_dim;
-    Eigen::DSizes<ptrdiff_t, NDIM> reshape_dim;
+    Eigen::DSizes<Eigen::DenseIndex, REDUCENDIM> reduce_dim;
+    Eigen::DSizes<Eigen::DenseIndex, NDIM> reshape_dim;
 
     for (int i = 0; i < REDUCENDIM; ++i) {
       reduce_dim[i] = reduce_dim_in[i];
@@ -392,26 +392,26 @@ REGISTER_KERNEL_BUILDER(Name("TileGrad")
   DEFINE_GPU_DIM(T, 4)     \
   DEFINE_GPU_DIM(T, 5)
 
-#define DEFINE_GPU_DIM(T, NDIM)                                       \
-  template <>                                                         \
-  void Tile<GPUDevice, T, NDIM>::operator()(                          \
-      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out,       \
-      typename TTypes<T, NDIM>::ConstTensor in,                       \
-      const Eigen::array<int32, NDIM>& broadcast_array) const;        \
-  extern template struct Tile<GPUDevice, T, NDIM>;                    \
-  template <>                                                         \
-  void TileGrad<GPUDevice, T, NDIM>::operator()(                      \
-      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out,       \
-      typename TTypes<T, NDIM>::ConstTensor in,                       \
-      const Eigen::DSizes<ptrdiff_t, NDIM>& indices,                  \
-      const Eigen::DSizes<ptrdiff_t, NDIM>& sizes, bool first) const; \
-  extern template struct TileGrad<GPUDevice, T, NDIM>;                \
-  template <>                                                         \
-  void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()(           \
-      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out,       \
-      typename TTypes<T, NDIM>::ConstTensor in,                       \
-      const Eigen::DSizes<ptrdiff_t, 1>& reduce_dim,                  \
-      const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const;       \
+#define DEFINE_GPU_DIM(T, NDIM)                                               \
+  template <>                                                                 \
+  void Tile<GPUDevice, T, NDIM>::operator()(                                  \
+      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out,               \
+      typename TTypes<T, NDIM>::ConstTensor in,                               \
+      const Eigen::array<int32, NDIM>& broadcast_array) const;                \
+  extern template struct Tile<GPUDevice, T, NDIM>;                            \
+  template <>                                                                 \
+  void TileGrad<GPUDevice, T, NDIM>::operator()(                              \
+      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out,               \
+      typename TTypes<T, NDIM>::ConstTensor in,                               \
+      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,                  \
+      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes, bool first) const; \
+  extern template struct TileGrad<GPUDevice, T, NDIM>;                        \
+  template <>                                                                 \
+  void ReduceAndReshape<GPUDevice, T, NDIM, 1>::operator()(                   \
+      const GPUDevice& d, typename TTypes<T, NDIM>::Tensor out,               \
+      typename TTypes<T, NDIM>::ConstTensor in,                               \
+      const Eigen::DSizes<Eigen::DenseIndex, 1>& reduce_dim,                  \
+      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const;       \
   extern template struct ReduceAndReshape<GPUDevice, T, NDIM, 1>;
 
 namespace functor {
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h
index 41c2deb42dc..1a614fe4f18 100644
--- a/tensorflow/core/kernels/tile_ops.h
+++ b/tensorflow/core/kernels/tile_ops.h
@@ -31,8 +31,8 @@ template <typename Device, typename T, int NDIM>
 struct TileGrad {
   void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
                   typename TTypes<T, NDIM>::ConstTensor in,
-                  const Eigen::DSizes<ptrdiff_t, NDIM>& indices,
-                  const Eigen::DSizes<ptrdiff_t, NDIM>& sizes,
+                  const Eigen::DSizes<Eigen::DenseIndex, NDIM>& indices,
+                  const Eigen::DSizes<Eigen::DenseIndex, NDIM>& sizes,
                   bool first) const {
     if (first) {
       out.device(d) = in.slice(indices, sizes);
@@ -58,10 +58,11 @@ struct TileGrad<Device, T, 0> {
 
 template <typename Device, typename T, int NDIM, int REDUCEDNDIM>
 struct ReduceAndReshape {
-  void operator()(const Device& d, typename TTypes<T, NDIM>::Tensor out,
-                  typename TTypes<T, NDIM>::ConstTensor in,
-                  const Eigen::DSizes<ptrdiff_t, REDUCEDNDIM>& reduce_dim,
-                  const Eigen::DSizes<ptrdiff_t, NDIM>& reshape_dim) const {
+  void operator()(
+      const Device& d, typename TTypes<T, NDIM>::Tensor out,
+      typename TTypes<T, NDIM>::ConstTensor in,
+      const Eigen::DSizes<Eigen::DenseIndex, REDUCEDNDIM>& reduce_dim,
+      const Eigen::DSizes<Eigen::DenseIndex, NDIM>& reshape_dim) const {
     out.device(d) = in.sum(reduce_dim).reshape(reshape_dim);
   }
 };
diff --git a/tensorflow/core/kernels/typed_queue.h b/tensorflow/core/kernels/typed_queue.h
new file mode 100644
index 00000000000..ae2878d87b6
--- /dev/null
+++ b/tensorflow/core/kernels/typed_queue.h
@@ -0,0 +1,54 @@
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
+
+#include <vector>
+
+#include "tensorflow/core/kernels/queue_base.h"
+
+namespace tensorflow {
+
+// TypedQueue builds on QueueBase, with backing class (SubQueue)
+// known and stored within.  Shared methods that need to have access
+// to the backed data sit in this class.
+template <typename SubQueue>
+class TypedQueue : public QueueBase {
+ public:
+  TypedQueue(const int32 capacity, const DataTypeVector& component_dtypes,
+             const std::vector<TensorShape>& component_shapes,
+             const string& name);
+
+  virtual Status Initialize();  // Must be called before any other method.
+
+ protected:
+  std::vector<SubQueue> queues_ GUARDED_BY(mu_);
+};  // class TypedQueue
+
+template <typename SubQueue>
+TypedQueue<SubQueue>::TypedQueue(
+    int32 capacity, const DataTypeVector& component_dtypes,
+    const std::vector<TensorShape>& component_shapes, const string& name)
+    : QueueBase(capacity, component_dtypes, component_shapes, name) {}
+
+template <typename SubQueue>
+Status TypedQueue<SubQueue>::Initialize() {
+  if (component_dtypes_.empty()) {
+    return errors::InvalidArgument("Empty component types for queue ", name_);
+  }
+  if (!component_shapes_.empty() &&
+      component_dtypes_.size() != component_shapes_.size()) {
+    return errors::InvalidArgument("Different number of component types (",
+                                   component_dtypes_.size(), ") vs. shapes (",
+                                   component_shapes_.size(), ").");
+  }
+
+  mutex_lock lock(mu_);
+  queues_.reserve(num_components());
+  for (int i = 0; i < num_components(); ++i) {
+    queues_.push_back(SubQueue());
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_TYPED_QUEUE_H_
diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc
index 36cfb2c8e59..5d1376be839 100644
--- a/tensorflow/core/kernels/unpack_op.cc
+++ b/tensorflow/core/kernels/unpack_op.cc
@@ -63,8 +63,8 @@ class UnpackOp : public OpKernel {
                      context->allocate_output(i, output_shape, &output));
       auto output_shaped = output->shaped<T, 3>({1, 1, output_size});
 
-      Eigen::DSizes<ptrdiff_t, 3> indices{0, i, 0};
-      Eigen::DSizes<ptrdiff_t, 3> sizes{1, 1, output_size};
+      Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, i, 0};
+      Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1, output_size};
       functor::Split<Device, T>()(context->eigen_device<Device>(),
                                   output_shaped, input_reshaped, indices,
                                   sizes);
diff --git a/tensorflow/g3doc/api_docs/cc/index.md b/tensorflow/g3doc/api_docs/cc/index.md
index e30022d9c2c..ddbdd72e164 100644
--- a/tensorflow/g3doc/api_docs/cc/index.md
+++ b/tensorflow/g3doc/api_docs/cc/index.md
@@ -23,28 +23,37 @@ write the graph to a file.
 
 1. Run the graph with a call to `session->Run()`
 
-
-##Classes <a class="md-anchor" id="AUTOGENERATED-classes"></a>
+## Env <a class="md-anchor" id="AUTOGENERATED-env"></a>
 
 * [tensorflow::Env](../../api_docs/cc/ClassEnv.md)
 * [tensorflow::RandomAccessFile](../../api_docs/cc/ClassRandomAccessFile.md)
 * [tensorflow::WritableFile](../../api_docs/cc/ClassWritableFile.md)
 * [tensorflow::EnvWrapper](../../api_docs/cc/ClassEnvWrapper.md)
+
+## Session <a class="md-anchor" id="AUTOGENERATED-session"></a>
+
 * [tensorflow::Session](../../api_docs/cc/ClassSession.md)
+* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)
+
+## Status <a class="md-anchor" id="AUTOGENERATED-status"></a>
+
 * [tensorflow::Status](../../api_docs/cc/ClassStatus.md)
+* [tensorflow::Status::State](../../api_docs/cc/StructState.md)
+
+## Tensor <a class="md-anchor" id="AUTOGENERATED-tensor"></a>
+
 * [tensorflow::Tensor](../../api_docs/cc/ClassTensor.md)
 * [tensorflow::TensorShape](../../api_docs/cc/ClassTensorShape.md)
-* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)
-* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
-
-##Structs <a class="md-anchor" id="AUTOGENERATED-structs"></a>
-
-* [tensorflow::SessionOptions](../../api_docs/cc/StructSessionOptions.md)
-* [tensorflow::Status::State](../../api_docs/cc/StructState.md)
 * [tensorflow::TensorShapeDim](../../api_docs/cc/StructTensorShapeDim.md)
+* [tensorflow::TensorShapeUtils](../../api_docs/cc/ClassTensorShapeUtils.md)
+
+## Thread <a class="md-anchor" id="AUTOGENERATED-thread"></a>
+
+* [tensorflow::Thread](../../api_docs/cc/ClassThread.md)
 * [tensorflow::ThreadOptions](../../api_docs/cc/StructThreadOptions.md)
 
 
+
 <div class='sections-order' style="display: none;">
 <!--
 <!-- ClassEnv.md -->
@@ -52,14 +61,14 @@ write the graph to a file.
 <!-- ClassWritableFile.md -->
 <!-- ClassEnvWrapper.md -->
 <!-- ClassSession.md -->
+<!-- StructSessionOptions.md -->
 <!-- ClassStatus.md -->
+<!-- StructState.md -->
 <!-- ClassTensor.md -->
 <!-- ClassTensorShape.md -->
+<!-- StructTensorShapeDim.md -->
 <!-- ClassTensorShapeUtils.md -->
 <!-- ClassThread.md -->
-<!-- StructSessionOptions.md -->
-<!-- StructState.md -->
-<!-- StructTensorShapeDim.md -->
 <!-- StructThreadOptions.md -->
 -->
 </div>
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 5bb7cd9c7a7..4f4292ff5ef 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -597,7 +597,7 @@ For so-called "global normalization" needed for convolutional filters pass
 
 ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
 
-  Two `Tensors`: `mean` and `variance`.
+  Two `Tensor` objects: `mean` and `variance`.
 
 
 
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 9014a811501..aa0301028a1 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -175,20 +175,20 @@ depends on.
 
 
 Follow instructions [here](http://bazel.io/docs/install.html) to install the
-dependencies for Bazel. Then download and build the Bazel source with the
-following commands:
+dependencies for Bazel. Then download bazel version 0.1.1 using the
+[installer for your system](https://github.com/bazelbuild/bazel/releases) and
+run the installer as mentioned there:
 
 ```bash
-$ git clone https://github.com/bazelbuild/bazel.git
-$ cd bazel
-$ git checkout tags/0.1.0
-$ ./compile.sh
+$ chmod +x PATH_TO_INSTALL.SH
+$ ./PATH_TO_INSTALL.SH --user
 ```
 
-These commands use the commit tag `0.1.0`, which is known to work with
-TensorFlow. `HEAD` may be unstable.
+Remember to replace `PATH_TO_INSTALL.SH` to point to the location where you
+downloaded the installer.
 
-Add the executable `output/bazel` to your `$PATH` environment variable.
+Finally, follow the instructions in that script to place bazel into your binary
+path.
 
 #### Install other dependencies <a class="md-anchor" id="AUTOGENERATED-install-other-dependencies"></a>
 
diff --git a/tensorflow/g3doc/resources/index.md b/tensorflow/g3doc/resources/index.md
index a2bc5733485..57c88f41677 100644
--- a/tensorflow/g3doc/resources/index.md
+++ b/tensorflow/g3doc/resources/index.md
@@ -15,6 +15,11 @@ system, we suggest you cite the paper above.
 You can use this [BibTeX entry](../resources/bib.md).  As the project progresses, we
 may update the suggested citation with new papers.
 
+Please only use the TensorFlow name and marks when accurately referencing this
+software distribution, and do not use our marks in a way that suggests you are
+endorsed by or otherwise affiliated with Google. When referring to our marks,
+please include the following attribution statement: "TensorFlow, the TensorFlow
+logo and any related marks are trademarks of Google Inc."
 
 ## Community <a class="md-anchor" id="AUTOGENERATED-community"></a>
 
diff --git a/tensorflow/models/embedding/BUILD b/tensorflow/models/embedding/BUILD
index f8f7e7bcb23..fe52778fa91 100644
--- a/tensorflow/models/embedding/BUILD
+++ b/tensorflow/models/embedding/BUILD
@@ -12,6 +12,7 @@ py_binary(
     srcs = [
         "word2vec.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":gen_word2vec",
         "//tensorflow:tensorflow_py",
@@ -24,6 +25,7 @@ py_binary(
     srcs = [
         "word2vec_optimized.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":gen_word2vec",
         "//tensorflow:tensorflow_py",
@@ -35,6 +37,7 @@ py_test(
     name = "word2vec_test",
     size = "small",
     srcs = ["word2vec_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":word2vec",
         "//tensorflow:tensorflow_py",
@@ -45,6 +48,7 @@ py_test(
     name = "word2vec_optimized_test",
     size = "small",
     srcs = ["word2vec_optimized_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":word2vec_optimized",
         "//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/image/alexnet/BUILD b/tensorflow/models/image/alexnet/BUILD
index e1b9cd69652..bbe29da6f5c 100644
--- a/tensorflow/models/image/alexnet/BUILD
+++ b/tensorflow/models/image/alexnet/BUILD
@@ -10,6 +10,7 @@ py_binary(
     srcs = [
         "alexnet_benchmark.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",
     ],
diff --git a/tensorflow/models/image/cifar10/BUILD b/tensorflow/models/image/cifar10/BUILD
index adf9aaffd41..25dce65f282 100644
--- a/tensorflow/models/image/cifar10/BUILD
+++ b/tensorflow/models/image/cifar10/BUILD
@@ -8,6 +8,7 @@ exports_files(["LICENSE"])
 py_library(
     name = "cifar10_input",
     srcs = ["cifar10_input.py"],
+    srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",
     ],
@@ -16,6 +17,7 @@ py_library(
 py_test(
     name = "cifar10_input_test",
     srcs = ["cifar10_input_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":cifar10_input",
         "//tensorflow:tensorflow_py",
@@ -27,6 +29,7 @@ py_test(
 py_library(
     name = "cifar10",
     srcs = ["cifar10.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":cifar10_input",
         "//tensorflow:tensorflow_py",
@@ -38,6 +41,7 @@ py_binary(
     srcs = [
         "cifar10_eval.py",
     ],
+    srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
         ":cifar10",
@@ -49,6 +53,7 @@ py_binary(
     srcs = [
         "cifar10_train.py",
     ],
+    srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
         ":cifar10",
@@ -60,6 +65,7 @@ py_binary(
     srcs = [
         "cifar10_multi_gpu_train.py",
     ],
+    srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = [
         ":cifar10",
diff --git a/tensorflow/models/image/mnist/BUILD b/tensorflow/models/image/mnist/BUILD
index 6774810e823..6dd96e1e6f9 100644
--- a/tensorflow/models/image/mnist/BUILD
+++ b/tensorflow/models/image/mnist/BUILD
@@ -10,6 +10,7 @@ py_binary(
     srcs = [
         "convolutional.py",
     ],
+    srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
     deps = ["//tensorflow:tensorflow_py"],
 )
@@ -24,6 +25,7 @@ py_test(
         "--self_test=True",
     ],
     main = "convolutional.py",
+    srcs_version = "PY2AND3",
     deps = ["//tensorflow:tensorflow_py"],
 )
 
diff --git a/tensorflow/models/rnn/BUILD b/tensorflow/models/rnn/BUILD
index 3e5e6b37ca3..1a81ce2801e 100644
--- a/tensorflow/models/rnn/BUILD
+++ b/tensorflow/models/rnn/BUILD
@@ -14,6 +14,7 @@ py_library(
     srcs = [
         "linear.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         "//tensorflow:tensorflow_py",
     ],
@@ -23,6 +24,7 @@ py_test(
     name = "linear_test",
     size = "small",
     srcs = ["linear_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":linear",
         "//tensorflow:tensorflow_py",
@@ -34,6 +36,7 @@ py_library(
     srcs = [
         "rnn_cell.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":linear",
         "//tensorflow:tensorflow_py",
@@ -44,6 +47,7 @@ py_test(
     name = "rnn_cell_test",
     size = "small",
     srcs = ["rnn_cell_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":rnn_cell",
         "//tensorflow:tensorflow_py",
@@ -55,6 +59,7 @@ py_library(
     srcs = [
         "__init__.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":rnn",
         ":rnn_cell",
@@ -67,6 +72,7 @@ py_library(
     srcs = [
         "rnn.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":rnn_cell",
         "//tensorflow:tensorflow_py",
@@ -88,6 +94,7 @@ py_library(
     srcs = [
         "seq2seq.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":rnn",
         "//tensorflow:tensorflow_py",
@@ -99,6 +106,7 @@ py_test(
     srcs = [
         "seq2seq_test.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":seq2seq",
         "//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/rnn/ptb/BUILD b/tensorflow/models/rnn/ptb/BUILD
index 56d459a0f18..c5feb191d55 100644
--- a/tensorflow/models/rnn/ptb/BUILD
+++ b/tensorflow/models/rnn/ptb/BUILD
@@ -10,12 +10,14 @@ exports_files(["LICENSE"])
 py_library(
     name = "reader",
     srcs = ["reader.py"],
+    srcs_version = "PY2AND3",
     deps = ["//tensorflow:tensorflow_py"],
 )
 
 py_test(
     name = "reader_test",
     srcs = ["reader_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":reader",
         "//tensorflow:tensorflow_py",
@@ -27,6 +29,7 @@ py_binary(
     srcs = [
         "ptb_word_lm.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":reader",
         "//tensorflow:tensorflow_py",
diff --git a/tensorflow/models/rnn/translate/BUILD b/tensorflow/models/rnn/translate/BUILD
index 57f17fb5abe..cf3780165b3 100644
--- a/tensorflow/models/rnn/translate/BUILD
+++ b/tensorflow/models/rnn/translate/BUILD
@@ -12,6 +12,7 @@ py_library(
     srcs = [
         "data_utils.py",
     ],
+    srcs_version = "PY2AND3",
     deps = ["//tensorflow:tensorflow_py"],
 )
 
@@ -20,6 +21,7 @@ py_library(
     srcs = [
         "seq2seq_model.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":data_utils",
         "//tensorflow:tensorflow_py",
@@ -32,6 +34,7 @@ py_binary(
     srcs = [
         "translate.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":data_utils",
         ":seq2seq_model",
@@ -49,6 +52,7 @@ py_test(
         "--self_test=True",
     ],
     main = "translate.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":data_utils",
         ":seq2seq_model",
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7002ebfd65d..5c6e08ae44a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -27,6 +27,7 @@ numpy_macosx_include_dir = select({
 py_library(
     name = "python",
     srcs = ["__init__.py"],
+    srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__pkg__"],
     deps = [
         ":client",
@@ -43,6 +44,7 @@ py_library(
 py_library(
     name = "platform",
     srcs = glob(["platform/**/*.py"]),
+    srcs_version = "PY2AND3",
 )
 
 py_library(
@@ -51,6 +53,7 @@ py_library(
         "platform/default/_googletest.py",
         "platform/googletest.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [":platform"],
 )
 
@@ -94,6 +97,7 @@ py_test(
     name = "pywrap_status_test",
     size = "small",
     srcs = ["lib/core/pywrap_status_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":platform_test",
@@ -133,6 +137,7 @@ py_library(
         "framework/tensor_util.py",
         "ops/common_shapes.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":platform",
         "//tensorflow/core:protos_all_py",
@@ -143,6 +148,7 @@ py_library(
 
 py_library(
     name = "extra_py_tests_deps",
+    srcs_version = "PY2AND3",
     deps = ["//tensorflow:tensorflow_py"],
 )
 
@@ -151,6 +157,7 @@ py_library(
     srcs = [
         "framework/test_util.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework",
         ":platform_test",
@@ -165,6 +172,7 @@ py_library(
     srcs = [
         "platform/test.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":platform_test",
@@ -175,6 +183,7 @@ py_test(
     name = "framework_errors_test",
     srcs = ["framework/errors_test.py"],
     main = "framework/errors_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":platform_test",
@@ -187,6 +196,7 @@ py_test(
     name = "framework_importer_test",
     srcs = ["framework/importer_test.py"],
     main = "framework/importer_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":ops",
@@ -213,6 +223,7 @@ py_test(
     name = "framework_ops_test",
     srcs = ["framework/ops_test.py"],
     main = "framework/ops_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":ops",
@@ -226,6 +237,19 @@ py_test(
     name = "framework_tensor_shape_test",
     srcs = ["framework/tensor_shape_test.py"],
     main = "framework/tensor_shape_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework_test_lib",
+        ":platform_test",
+        "//tensorflow/core:protos_all_py",
+    ],
+)
+
+py_test(
+    name = "framework_tensor_shape_div_test",
+    srcs = ["framework/tensor_shape_div_test.py"],
+    main = "framework/tensor_shape_div_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":platform_test",
@@ -237,6 +261,7 @@ py_test(
     name = "framework_tensor_util_test",
     srcs = ["framework/tensor_util_test.py"],
     main = "framework/tensor_util_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":ops",
@@ -248,6 +273,7 @@ py_test(
     name = "framework_test_util_test",
     srcs = ["framework/test_util_test.py"],
     main = "framework/test_util_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":ops",
@@ -259,6 +285,7 @@ py_test(
     name = "framework_types_test",
     srcs = ["framework/types_test.py"],
     main = "framework/types_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":platform_test",
@@ -271,6 +298,7 @@ py_test(
     name = "op_def_library_test",
     srcs = ["ops/op_def_library_test.py"],
     main = "ops/op_def_library_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":ops",
@@ -565,6 +593,7 @@ py_library(
         "ops/variables.py",
         "user_ops/user_ops.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":array_ops",
         ":candidate_sampling_ops",
@@ -591,6 +620,7 @@ py_library(
         ["training/**/*.py"],
         exclude = ["**/*test*"],
     ),
+    srcs_version = "PY2AND3",
     deps = [
         ":client",
         ":framework",
@@ -609,6 +639,7 @@ py_library(
         ["client/**/*.py"],
         exclude = ["**/*test*"],
     ),
+    srcs_version = "PY2AND3",
     deps = [
         ":framework",
         ":ops",
@@ -620,6 +651,7 @@ py_library(
 py_library(
     name = "util",
     srcs = glob(["util/**/*.py"]),
+    srcs_version = "PY2AND3",
     deps = ["//google/protobuf:protobuf_python"],
 )
 
@@ -641,6 +673,7 @@ py_test(
     name = "protobuf_compare_test",
     srcs = ["util/protobuf/compare_test.py"],
     main = "util/protobuf/compare_test.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":compare_test_proto_py",
         ":platform_test",
@@ -654,6 +687,7 @@ py_test(
     srcs = [
         "client/events_writer_test.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework_test_lib",
         ":lib",
@@ -719,6 +753,7 @@ tf_py_wrap_cc(
 py_library(
     name = "lib",
     srcs = glob(["lib/**/*.py"]),
+    srcs_version = "PY2AND3",
     deps = [
         ":pywrap_tensorflow",
     ],
@@ -727,6 +762,7 @@ py_library(
 py_library(
     name = "session",
     srcs = ["client/session.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework",
         ":ops",
@@ -750,6 +786,7 @@ tf_cuda_library(
 py_test(
     name = "session_test",
     srcs = ["client/session_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework",
         ":framework_test_lib",
@@ -760,6 +797,7 @@ py_test(
 py_test(
     name = "graph_util_test",
     srcs = ["client/graph_util_test.py"],
+    srcs_version = "PY2AND3",
     deps = [
         ":framework",
         ":framework_test_lib",
@@ -770,6 +808,7 @@ py_test(
 py_library(
     name = "kernel_tests/gradient_checker",
     srcs = ["kernel_tests/gradient_checker.py"],
+    srcs_version = "PY2AND3",
 )
 
 cpu_only_kernel_test_list = glob([
@@ -899,6 +938,7 @@ py_library(
         ["summary/**/*.py"],
         exclude = ["**/*test*"],
     ),
+    srcs_version = "PY2AND3",
     deps = [
         ":client",
         ":framework",
@@ -921,6 +961,7 @@ py_library(
     srcs = [
         "framework/docs.py",
     ],
+    srcs_version = "PY2AND3",
     deps = [
         ":platform",
     ],
@@ -932,6 +973,7 @@ py_binary(
         "framework/gen_docs_combined.py",
     ],
     main = "framework/gen_docs_combined.py",
+    srcs_version = "PY2AND3",
     deps = [
         ":docs",
         ":platform",
diff --git a/tensorflow/python/framework/tensor_shape.py b/tensorflow/python/framework/tensor_shape.py
index b5462dcd177..865533cf92d 100644
--- a/tensorflow/python/framework/tensor_shape.py
+++ b/tensorflow/python/framework/tensor_shape.py
@@ -170,18 +170,18 @@ class Dimension(object):
   def __floordiv__(self, other):
     """Returns the quotient of `self` and `other` rounded down.
 
-    Dimensions are summed as follows:
+    Dimensions are divided as follows:
 
-      Dimension(m)    / Dimension(n)    == Dimension(m / n)
-      Dimension(m)    / Dimension(None) == Dimension(None)
-      Dimension(None) / Dimension(n)    == Dimension(None)
-      Dimension(None) / Dimension(None) == Dimension(None)
+      Dimension(m)    // Dimension(n)    == Dimension(m // n)
+      Dimension(m)    // Dimension(None) == Dimension(None)
+      Dimension(None) // Dimension(n)    == Dimension(None)
+      Dimension(None) // Dimension(None) == Dimension(None)
 
     Args:
-      other: Another Dimension.
+      other: Another `Dimension`.
 
     Returns:
-      A Dimension whose value is the sum of `self` and `other`.
+      A `Dimension` whose value is the integer quotient of `self` and `other`.
     """
     other = as_dimension(other)
     if self._value is None or other.value is None:
@@ -189,6 +189,22 @@ class Dimension(object):
     else:
       return Dimension(self._value // other.value)
 
+  def __div__(self, other):
+    """DEPRECATED: Use `__floordiv__` via `x // y` instead.
+
+    This function exists only for backwards compatibility purposes; new code
+    should use `__floordiv__` via the syntax `x // y`.  Using `x // y`
+    communicates clearly that the result rounds down, and is forward compatible
+    to Python 3.
+
+    Args:
+      other: Another `Dimension`.
+
+    Returns:
+      A `Dimension` whose value is the integer quotient of `self` and `other`.
+    """
+    return self // other
+
   def __mod__(self, other):
     """Returns `self` modulo `other.
 
diff --git a/tensorflow/python/framework/tensor_shape_div_test.py b/tensorflow/python/framework/tensor_shape_div_test.py
new file mode 100644
index 00000000000..27219dbb9a9
--- /dev/null
+++ b/tensorflow/python/framework/tensor_shape_div_test.py
@@ -0,0 +1,24 @@
+"""Test that old style division works for Dimension."""
+from __future__ import absolute_import
+# from __future__ import division  # Intentionally skip this import
+from __future__ import print_function
+
+import tensorflow.python.platform
+
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class DimensionDivTest(test_util.TensorFlowTestCase):
+
+  def testDivSucceeds(self):
+    """Without from __future__ import division, __div__ should work."""
+    values = [tensor_shape.Dimension(x) for x in 3, 7, 11, None]
+    for x in values:
+      for y in values:
+        self.assertEqual((x / y).value, (x // y).value)
+
+
+if __name__ == "__main__":
+  googletest.main()
diff --git a/tensorflow/python/framework/tensor_shape_test.py b/tensorflow/python/framework/tensor_shape_test.py
index be5fbb51cb6..43cadc7b139 100644
--- a/tensorflow/python/framework/tensor_shape_test.py
+++ b/tensorflow/python/framework/tensor_shape_test.py
@@ -233,6 +233,12 @@ class ShapeTest(test_util.TensorFlowTestCase):
     tensor_shape.TensorShape(
         [94, 43]).assert_is_compatible_with(tensor_shape.matrix(94, 43))
 
+  def testTruedivFails(self):
+    unknown = tensor_shape.Dimension(None)
+    self.assertEqual((unknown // unknown).value, None)
+    with self.assertRaisesRegexp(TypeError, r"unsupported operand type"):
+      unknown / unknown  # pylint: disable=pointless-statement
+
 
 if __name__ == "__main__":
   googletest.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 28138fbf39f..78262a55f39 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -409,7 +409,7 @@ def split(split_dim, num_split, value, name="split"):
   Args:
     split_dim: A 0-D `int32` `Tensor`. The dimension along which to split.
       Must be in the range `[0, rank(value))`.
-    num_split: A 0-D `int32` `Tensor`. The number of ways to split.
+    num_split: A Python integer. The number of ways to split.
     value: The `Tensor` to split.
     name: A name for the operation (optional).
 
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 63826170bec..c5730dce213 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -138,7 +138,7 @@ class Optimizer(object):
     self._slots = {}
 
   def minimize(self, loss, global_step=None, var_list=None,
-               gate_gradients=GATE_OP, name=None):
+               gate_gradients=GATE_OP, aggregation_method=None, name=None):
     """Add operations to minimize 'loss' by updating 'var_list'.
 
     This method simply combines calls compute_gradients() and
@@ -155,6 +155,8 @@ class Optimizer(object):
         under the key GraphKeys.TRAINABLE_VARIABLES.
       gate_gradients: How to gate the computation of gradients.  Can be
         GATE_NONE, GATE_OP, or  GATE_GRAPH.
+      aggregation_method: Specifies the method used to combine gradient terms.
+        Valid values are defined in the class `AggregationMethod`.
       name: Optional name for the returned operation.
 
     Returns:
@@ -164,12 +166,14 @@ class Optimizer(object):
     Raises:
       ValueError: if some of the variables are not variables.Variable objects.
     """
-    grads_and_vars = self.compute_gradients(loss, var_list=var_list,
-                                            gate_gradients=gate_gradients)
+    grads_and_vars = self.compute_gradients(
+        loss, var_list=var_list, gate_gradients=gate_gradients,
+        aggregation_method=aggregation_method)
     return self.apply_gradients(grads_and_vars, global_step=global_step,
                                 name=name)
 
-  def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP):
+  def compute_gradients(self, loss, var_list=None, gate_gradients=GATE_OP,
+                        aggregation_method=None):
     """Compute gradients of "loss" for the variables in "var_list".
 
     This is the first part of minimize().  It returns a list
@@ -185,6 +189,8 @@ class Optimizer(object):
         under the key GraphKey.TRAINABLE_VARIABLES.
       gate_gradients: How to gate the computation of gradients.  Can be
         GATE_NONE, GATE_OP, or  GATE_GRAPH.
+      aggregation_method: Specifies the method used to combine gradient terms.
+        Valid values are defined in the class `AggregationMethod`.
 
     Returns:
       A list of (gradient, variable) pairs.
@@ -205,7 +211,8 @@ class Optimizer(object):
       if not isinstance(var, variables.Variable):
         raise TypeError("Argument is not a variables.Variable: %s" % var)
     grads = gradients.gradients(
-        loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP))
+        loss, var_list, gate_gradients=(gate_gradients == Optimizer.GATE_OP),
+        aggregation_method=aggregation_method)
     if gate_gradients == Optimizer.GATE_GRAPH:
       grads = control_flow_ops.tuple(grads)
     grads_and_vars = list(zip(grads, var_list))
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
new file mode 100644
index 00000000000..a743240d8a6
--- /dev/null
+++ b/tensorflow/python/training/optimizer_test.py
@@ -0,0 +1,54 @@
+"""Functional test for optimizer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+
+class OptimizerTest(tf.test.TestCase):
+
+  def testBasic(self):
+    with self.test_session():
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([3.0, 4.0])
+      cost = 5 * var0 + 3 * var1
+      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+      sgd_op = tf.train.GradientDescentOptimizer(3.0)
+      opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+      # Run 1 step of sgd through optimizer
+      opt_op.run()
+      # Validate updated params
+      self.assertAllClose([-14., -13.], var0.eval())
+      self.assertAllClose([-6., -5.], var1.eval())
+
+  def testAggregationMethod(self):
+    with self.test_session():
+      var0 = tf.Variable([1.0, 2.0])
+      var1 = tf.Variable([3.0, 4.0])
+      cost = 5 * var0 + 3 * var1
+      global_step = tf.Variable(tf.zeros([], tf.int64), name='global_step')
+      sgd_op = tf.train.GradientDescentOptimizer(3.0)
+      opt_op = sgd_op.minimize(
+          cost, global_step, [var0, var1], aggregation_method=
+          tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
+
+      tf.initialize_all_variables().run()
+      # Fetch params to validate initial values
+      self.assertAllClose([1.0, 2.0], var0.eval())
+      self.assertAllClose([3.0, 4.0], var1.eval())
+      # Run 1 step of sgd through optimizer
+      opt_op.run()
+      # Validate updated params
+      self.assertAllClose([-14., -13.], var0.eval())
+      self.assertAllClose([-6., -5.], var1.eval())
+
+
+if __name__ == "__main__":
+  tf.test.main()
diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD
index 2dcb5e4fa9f..74bdd1d6ab8 100644
--- a/tensorflow/tensorboard/BUILD
+++ b/tensorflow/tensorboard/BUILD
@@ -20,11 +20,13 @@ py_library(
         "//tensorflow/python:platform",
         "//tensorflow/python:summary",
     ],
+    srcs_version = "PY2AND3",
 )
 
 py_library(
     name = "float_wrapper",
     srcs = ["float_wrapper.py"],
+    srcs_version = "PY2AND3",
 )
 
 py_test(
@@ -35,6 +37,7 @@ py_test(
         ":float_wrapper",
         "//tensorflow/python:platform_test",
     ],
+    srcs_version = "PY2AND3",
 )
 
 py_binary(
@@ -46,4 +49,5 @@ py_binary(
         "//tensorflow/python:platform",
         "//tensorflow/python:summary",
     ],
+    srcs_version = "PY2AND3",
 )
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 4bcfd6234c4..f88f3eb2a69 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -332,7 +332,8 @@ def py_tests(name,
                    deps=[
                        "//tensorflow/python:extra_py_tests_deps",
                        "//tensorflow/python:kernel_tests/gradient_checker",
-                   ] + additional_deps)
+                   ] + additional_deps,
+                   srcs_version="PY2AND3")
 
 
 def cuda_py_tests(name, srcs, additional_deps=[], data=[], shard_count=1):
diff --git a/tensorflow/tools/docker/BUILD b/tensorflow/tools/docker/BUILD
index 2cc540ed3b2..7d5ae0a94d8 100644
--- a/tensorflow/tools/docker/BUILD
+++ b/tensorflow/tools/docker/BUILD
@@ -10,6 +10,7 @@ exports_files(["LICENSE"])
 py_binary(
     name = "simple_console",
     srcs = ["simple_console.py"],
+    srcs_version = "PY2AND3",
     deps = ["//tensorflow:tensorflow_py"],
 )
 
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index b9a50e4288c..d0eb717ad69 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -7,6 +7,7 @@ py_binary(
     name = "simple_console",
     srcs = ["simple_console.py"],
     deps = ["//tensorflow:tensorflow_py"],
+    srcs_version = "PY2AND3",
 )
 
 sh_binary(
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/g3doc/README.md b/third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/g3doc/README.md
index 1c3fe32f9b0..9bc11619768 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/g3doc/README.md
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/g3doc/README.md
@@ -1195,7 +1195,7 @@ for the last dimension).
     input.setRandom();
     kernel.setRandom();
 
-    Eigen::array<ptrdiff_t, 2> dims({1, 2});  // Specify second and third dimension for convolution.
+    Eigen::array<Eigen::DenseIndex, 2> dims({1, 2});  // Specify second and third dimension for convolution.
     output = input.convolve(kernel, dims);
 
     for (int i = 0; i < 3; ++i) {
@@ -1577,7 +1577,7 @@ For example, given the following input tensor:
 Six 2x2 patches can be extracted and indexed using the following code:
 
     Eigen::Tensor<float, 3, DataLayout> patch;
-    Eigen::array<ptrdiff_t, 2> patch_dims;
+    Eigen::array<Eigen::DenseIndex, 2> patch_dims;
     patch_dims[0] = 2;
     patch_dims[1] = 2;
     patch = tensor.extract_patches(patch_dims);
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
index 0f419a8b61e..cdbd1823ea1 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
@@ -232,9 +232,14 @@ def InvokeNvcc(argv, log=False):
   srcs = ' '.join(src_files)
   out = ' -o ' + out_file[0]
 
-  nvccopts = ' '.join([
-      r'-gencode=arch=compute_35,\"code=sm_35,compute_35\"',
-      r'-gencode=arch=compute_52,\"code=sm_52,compute_52\"',])
+  # "configure" uses the specific format to substitute the following string.
+  # If you change it, make sure you modify "configure" as well.
+  supported_cuda_compute_capabilities = [ "3.5", "5.2" ]
+  nvccopts = ''
+  for capability in supported_cuda_compute_capabilities:
+    capability = capability.replace('.', '')
+    nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
+        capability, capability, capability)
   nvccopts += ' ' + nvcc_compiler_options
   nvccopts += undefines
   nvccopts += defines
@@ -260,8 +265,8 @@ def InvokeNvcc(argv, log=False):
          ' -I .' +
          ' -x cu ' + opt + includes + ' -c ' + srcs + out)
 
-  # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'. 
-  # Need to investigate and fix. 
+  # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
+  # Need to investigate and fix.
   cmd = 'PATH=' + PREFIX_DIR + ' ' + cmd
   if log: Log(cmd)
   return os.system(cmd)