diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 2b1a0412354..4667f096e00 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -495,6 +495,11 @@ static void SimplifyGraph(Graph* g) {
 
 void OptimizeGraph(FunctionLibraryRuntime* lib, Graph** g) {
   DumpGraph("Initial", *g);
+
+  // Run SimplifyGraph at least once to rewrite away ops such as
+  // _ListToArray, _ArrayToList, etc.
+  SimplifyGraph(*g);
+
   const int kNumInlineRounds = 10;
   for (int i = 0; i < kNumInlineRounds; ++i) {
     if (!ExpandInlineFunctions(lib, *g)) break;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 3df833594f8..8979c94e3d9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -65,7 +65,7 @@ GPUBFCAllocator::GPUBFCAllocator(int device_id, size_t total_memory)
   ptr_to_chunk_map_.insert(std::make_pair(c->ptr, c));
 
   // Insert the chunk into the right bin.
-  ReassignChunkToBin(c);
+  InsertFreeChunkIntoBin(c);
 }
 
 GPUBFCAllocator::~GPUBFCAllocator() {
@@ -76,6 +76,7 @@ GPUBFCAllocator::~GPUBFCAllocator() {
   }
 
   gtl::STLDeleteValues(&bins_);
+  gtl::STLDeleteValues(&ptr_to_chunk_map_);
 }
 
 void* GPUBFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
@@ -115,10 +116,12 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
     // Start searching from the first bin for the smallest chunk that fits
     // rounded_bytes.
     Bin* b = it->second;
-    for (GPUBFCAllocator::Chunk* chunk : b->chunks) {
-      if (!chunk->in_use && chunk->size > rounded_bytes) {
-        // We found an existing chunk that fits us that wasn't in use.
-        chunk->in_use = true;
+    for (GPUBFCAllocator::Chunk* chunk : b->free_chunks) {
+      DCHECK(!chunk->in_use);
+      if (chunk->size >= rounded_bytes) {
+        // We found an existing chunk that fits us that wasn't in use, so remove
+        // it from the free bin structure prior to using.
+        RemoveFreeChunkFromBin(chunk);
 
         // If we can break the size of the chunk into two reasonably
         // large pieces, do so.
@@ -132,6 +135,7 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
         // The requested size of the returned chunk is what the user
         // has allocated.
         chunk->requested_size = num_bytes;
+        chunk->in_use = true;
 
         VLOG(4) << "Returning: " << chunk->ptr;
         return chunk->ptr;
@@ -152,6 +156,8 @@ void* GPUBFCAllocator::AllocateRawInternal(size_t unused_alignment,
 }
 
 void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
+  CHECK(!c->in_use && !c->bin);
+
   // Create a new chunk starting num_bytes after c
   GPUBFCAllocator::Chunk* new_chunk = new GPUBFCAllocator::Chunk();
   new_chunk->ptr = static_cast<void*>(static_cast<char*>(c->ptr) + num_bytes);
@@ -176,9 +182,8 @@ void GPUBFCAllocator::SplitChunk(GPUBFCAllocator::Chunk* c, size_t num_bytes) {
     c_neighbor->prev = new_chunk;
   }
 
-  // Maintain the bins
-  ReassignChunkToBin(new_chunk);
-  ReassignChunkToBin(c);
+  // Add the newly free chunk to the free bin.
+  InsertFreeChunkIntoBin(new_chunk);
 }
 
 void GPUBFCAllocator::DeallocateRaw(void* ptr) {
@@ -200,11 +205,9 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
 
   GPUBFCAllocator::Chunk* c = it->second;
   VLOG(6) << "Chunk at " << c->ptr << " no longer in use";
-  // Mark the chunk as no longer in use
-  c->in_use = false;
 
   // Consider coalescing it.
-  MaybeCoalesce(c);
+  FreeAndMaybeCoalesce(c);
 }
 
 // Merges c1 and c2 when c1->next is c2 and c2->prev is c1.
@@ -212,7 +215,7 @@ void GPUBFCAllocator::DeallocateRawInternal(void* ptr) {
 void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
                             GPUBFCAllocator::Chunk* c2) {
   // We can only merge chunks that are not in use.
-  DCHECK(!c1->in_use && !c2->in_use);
+  CHECK(!c1->in_use && !c2->in_use);
 
   // c1's prev doesn't change, still points to the same ptr, and is
   // still not in use.
@@ -231,62 +234,42 @@ void GPUBFCAllocator::Merge(GPUBFCAllocator::Chunk* c1,
   // Set the new size
   c1->size += c2->size;
 
-  // Delete c2 and cleanup all state
-  RemoveChunkFromBin(c2);
+  DeleteChunk(c2);
 }
 
-void GPUBFCAllocator::ReassignChunkToBin(GPUBFCAllocator::Chunk* c) {
+void GPUBFCAllocator::DeleteChunk(Chunk* c) {
+  // Delete c2 and cleanup all state
+  VLOG(4) << "Removing: " << c->ptr;
+  ptr_to_chunk_map_.erase(c->ptr);
+  delete c;
+}
+
+void GPUBFCAllocator::InsertFreeChunkIntoBin(GPUBFCAllocator::Chunk* c) {
+  CHECK(!c->in_use && !c->bin);
   auto it = bins_.lower_bound(c->size);
   CHECK(it != bins_.end()) << " Tried to reassign to non-existent bin for size "
                            << c->size;
-
   Bin* new_bin = it->second;
-
-  // If the bin has not changed, do nothing.
-  Bin* old_bin = c->bin;
-  if (old_bin != nullptr && new_bin == old_bin) {
-    return;
-  }
-
-  // The bin has changed.  Add the chunk to the new bin and remove
-  // the chunk from the old bin.
-  new_bin->chunks.insert(c);
   c->bin = new_bin;
-
-  if (old_bin == nullptr) {
-    return;
-  }
-
-  // Remove chunk from old bin
-  for (auto it = old_bin->chunks.begin(); it != old_bin->chunks.end(); ++it) {
-    if (*it == c) {
-      old_bin->chunks.erase(it);
-      return;
-    }
-  }
-  CHECK(false) << "Could not find chunk in old bin";
+  new_bin->free_chunks.insert(c);
 }
 
-void GPUBFCAllocator::RemoveChunkFromBin(GPUBFCAllocator::Chunk* c) {
-  Bin* b = c->bin;
-  for (auto it = b->chunks.begin(); it != b->chunks.end(); ++it) {
-    Chunk* other_c = *it;
-    if (other_c->ptr == c->ptr) {
-      b->chunks.erase(it);
-      VLOG(4) << "Removing: " << c->ptr;
-      ptr_to_chunk_map_.erase(c->ptr);
-      delete c;
-      return;
-    }
-  }
-
-  CHECK(false) << "Could not find chunk in bin";
+void GPUBFCAllocator::RemoveFreeChunkFromBin(GPUBFCAllocator::Chunk* c) {
+  CHECK(!c->in_use && c->bin);
+  int count = c->bin->free_chunks.erase(c);
+  CHECK(count > 0) << "Could not find chunk in bin";
+  c->bin = nullptr;
 }
 
-void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+void GPUBFCAllocator::FreeAndMaybeCoalesce(GPUBFCAllocator::Chunk* c) {
+  CHECK(c->in_use && !c->bin);
+
+  // Mark the chunk as no longer in use
+  c->in_use = false;
+
   // This chunk is no longer in-use, consider coalescing the chunk
   // with adjacent chunks.
-  Chunk* chunk_to_reassign = nullptr;
+  Chunk* chunk_to_reassign = c;
 
   // If the next chunk is free, coalesce the two, if the result would
   // fit in an existing bin.
@@ -296,6 +279,7 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
     chunk_to_reassign = c;
 
     // Deletes c->next
+    RemoveFreeChunkFromBin(c->next);
     Merge(c, c->next);
   }
 
@@ -307,13 +291,11 @@ void GPUBFCAllocator::MaybeCoalesce(GPUBFCAllocator::Chunk* c) {
     chunk_to_reassign = c->prev;
 
     // Deletes c
+    RemoveFreeChunkFromBin(c->prev);
     Merge(c->prev, c);
   }
 
-  // Reassign the final merged chunk into the right bin.
-  if (chunk_to_reassign) {
-    ReassignChunkToBin(chunk_to_reassign);
-  }
+  InsertFreeChunkIntoBin(chunk_to_reassign);
 }
 
 void GPUBFCAllocator::AddAllocVisitor(Visitor visitor) {
@@ -354,7 +336,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
     size_t total_requested_bytes_in_bin = 0;
     size_t total_chunks_in_use = 0;
     size_t total_chunks_in_bin = 0;
-    for (Chunk* c : b->chunks) {
+    for (Chunk* c : b->free_chunks) {
       total_bytes_in_bin += c->size;
       total_requested_bytes_in_bin += c->requested_size;
       ++total_chunks_in_bin;
@@ -388,7 +370,7 @@ void GPUBFCAllocator::DumpMemoryLog(size_t num_bytes) {
               << " was " << strings::HumanReadableNumBytes(b->bin_size)
               << ", Chunk State: ";
 
-    for (Chunk* c : b->chunks) {
+    for (Chunk* c : b->free_chunks) {
       LOG(INFO) << c->DebugString(true);
     }
   }
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index 3d1601e132a..417df6f4136 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -102,28 +102,33 @@ class GPUBFCAllocator : public VisitableAllocator {
   Chunk* AllocateNewChunk(size_t num_bytes);
   void SplitChunk(Chunk* c, size_t num_bytes);
   void Merge(Chunk* c1, Chunk* c2);
-  void MaybeCoalesce(Chunk* c);
-
-  void ReassignChunkToBin(Chunk* c);
-  void RemoveChunkFromBin(Chunk* c);
+  void FreeAndMaybeCoalesce(Chunk* c);
+  void InsertFreeChunkIntoBin(Chunk* c);
+  void RemoveFreeChunkFromBin(Chunk* c);
+  void DeleteChunk(Chunk* c);
 
   void DumpMemoryLog(size_t num_bytes);
 
-  // A Bin is a collection of similar-sized Chunks.
+  // A Bin is a collection of similar-sized free chunks.
   struct Bin {
     // All chunks in this bin have >= bin_size memory.
     size_t bin_size = 0;
 
     struct ChunkComparator {
-      bool operator()(Chunk* a, Chunk* b) { return a->size < b->size; }
+      // Sort first by size and then use pointer address as a tie breaker.
+      bool operator()(const Chunk* a, const Chunk* b) const {
+        if (a->size != b->size) {
+          return a->size < b->size;
+        }
+        return a->ptr < b->ptr;
+      }
     };
 
-    // List of chunks within the bin, sorted by chunk size.
-    std::multiset<Chunk*, ChunkComparator> chunks;
+    // List of free chunks within the bin, sorted by chunk size.
+    // Chunk * not owned.
+    std::set<Chunk*, ChunkComparator> free_chunks;
 
     explicit Bin(size_t bs) : bin_size(bs) {}
-
-    ~Bin() { gtl::STLDeleteElements(&chunks); }
   };
 
   GPUAllocatorRetry retry_helper_;
@@ -142,7 +147,7 @@ class GPUBFCAllocator : public VisitableAllocator {
 
   // Structures mutable after construction
   mutable mutex lock_;
-  // Not owned.
+  // Chunk * owned.
   std::unordered_map<void*, Chunk*> ptr_to_chunk_map_;
 
   // Called once on each region, ASAP.
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 474b988d2fa..fa9c0170f54 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -20,7 +20,7 @@ DEFINE_bool(record_mem_types, false,
 DEFINE_bool(brain_mem_reg_cuda_dma, true,
             "If true, register CPU RAM used to copy to/from GPU RAM "
             "with the CUDA driver.");
-DEFINE_bool(brain_gpu_use_bfc_allocator, false,
+DEFINE_bool(brain_gpu_use_bfc_allocator, true,
             "If true, uses the Best-Fit GPU allocator.");
 DEFINE_bool(brain_gpu_region_allocator_debug, false,
             "If true, checks for memory overwrites by writing "
@@ -34,7 +34,7 @@ bool FLAGS_record_mem_types = false;
 bool FLAGS_brain_mem_reg_cuda_dma = true;
 bool FLAGS_brain_gpu_region_allocator_debug = false;
 bool FLAGS_brain_gpu_region_allocator_reset_to_nan = false;
-bool FLAGS_brain_gpu_use_bfc_allocator = false;
+bool FLAGS_brain_gpu_use_bfc_allocator = true;
 #endif
 
 namespace gpu = ::perftools::gputools;
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index b68fcec5151..adc802cb452 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -135,6 +135,7 @@ REGISTER_CONCAT(bfloat16);
                           ConcatOp<GPUDevice, type>)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
 #undef REGISTER_GPU
 
 // A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/concat_op_gpu.cu.cc b/tensorflow/core/kernels/concat_op_gpu.cu.cc
index aed36dccef7..0f21d5f07c5 100644
--- a/tensorflow/core/kernels/concat_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/concat_op_gpu.cu.cc
@@ -6,6 +6,7 @@
 
 #include <memory>
 
+#include "tensorflow/core/framework/bfloat16.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor_types.h"
 
@@ -34,6 +35,7 @@ void ConcatGPU(const GPUDevice& d,
       typename TTypes<T, 2>::Matrix* output);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
+REGISTER_GPU(bfloat16);
 #undef REGISTER_GPU
 
 }  // end namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 7bfe8b095f6..eeb9fa4c388 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -541,12 +541,36 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
     // The output image size is the spatial size of the output.
     const int output_image_size = out_rows * out_cols;
 
+    // Shard 'batch' images into 'shard_size' groups of images to be fed
+    // into the parallel matmul. Calculate 'shard_size' by dividing the L3 cache
+    // size ('target_working_set_size') by the matmul size of an individual
+    // image ('work_unit_size').
+
+    // TODO(andydavis)
+    // *) Get L3 cache size from device at runtime (30MB is from ivybridge).
+    // *) Consider reducing 'target_working_set_size' if L3 is shared by
+    //    other concurrently running tensorflow ops.
+    const size_t target_working_set_size = (30LL << 20) / sizeof(T);
+
+    const size_t size_A = output_image_size * filter_total_size;
+
+    const size_t size_B = output_image_size * out_depth;
+
+    const size_t size_C = filter_total_size * out_depth;
+
+    const size_t work_unit_size = size_A + size_B + size_C;
+
+    const size_t shard_size =
+        (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
     Tensor col_buffer;
-    OP_REQUIRES_OK(
-        context,
-        context->allocate_temp(
-            DataTypeToEnum<T>::value,
-            TensorShape({output_image_size, filter_total_size}), &col_buffer));
+    OP_REQUIRES_OK(context,
+                   context->allocate_temp(
+                       DataTypeToEnum<T>::value,
+                       TensorShape({static_cast<int64>(shard_size),
+                                    static_cast<int64>(output_image_size),
+                                    static_cast<int64>(filter_total_size)}),
+                       &col_buffer));
 
     // The input offset corresponding to a single input image.
     const int input_offset = input_rows * input_cols * in_depth;
@@ -571,21 +595,29 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
     contract_dims[0].first = 0;
     contract_dims[0].second = 0;
 
-    for (int image_id = 0; image_id < batch; ++image_id) {
-      // When we compute the gradient with respect to the filters, we need to do
-      // im2col to allow gemm-type computation.
-      Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
-                filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
-                stride, col_buffer_data);
+    for (int image_id = 0; image_id < batch; image_id += shard_size) {
+      const int shard_limit = std::min(static_cast<int>(shard_size),
+                                       static_cast<int>(batch) - image_id);
+      for (int shard_id = 0; shard_id < shard_limit; ++shard_id) {
+        // TODO(andydavis) Parallelize this loop.
+        // When we compute the gradient with respect to the filters, we need
+        // to do im2col to allow gemm-type computation.
+        Im2col<T>(input_data, in_depth, input_rows, input_cols, filter_rows,
+                  filter_cols, pad_top, pad_left, pad_bottom, pad_right, stride,
+                  stride, col_buffer_data + shard_id * size_A);
 
-      ConstTensorMap A(col_buffer_data, output_image_size, filter_total_size);
-      ConstTensorMap B(out_backprop_data + output_offset * image_id,
-                       output_image_size, out_depth);
+        input_data += input_offset;
+      }
+
+      ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+                       filter_total_size);
+      ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+                       out_depth);
 
       // Gradient with respect to filter.
       C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
 
-      input_data += input_offset;
+      out_backprop_data += output_offset * shard_limit;
     }
   }
 
diff --git a/tensorflow/core/kernels/tile_ops.h b/tensorflow/core/kernels/tile_ops.h
index 1a614fe4f18..99455adce2a 100644
--- a/tensorflow/core/kernels/tile_ops.h
+++ b/tensorflow/core/kernels/tile_ops.h
@@ -46,8 +46,9 @@ template <typename Device, typename T>
 struct TileGrad<Device, T, 0> {
   void operator()(const Device& d, typename TTypes<T, 0>::Tensor out,
                   typename TTypes<T, 0>::ConstTensor in,
-                  const Eigen::DSizes<ptrdiff_t, 0>&,
-                  const Eigen::DSizes<ptrdiff_t, 0>&, bool first) const {
+                  const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+                  const Eigen::DSizes<Eigen::DenseIndex, 0>&,
+                  bool first) const {
     if (first) {
       out.device(d) = in;
     } else {
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 8c0571b50e9..321d9c02768 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -407,7 +407,7 @@ reshape(t, [3, 3]) ==> [[1, 2, 3]
 
 # tensor 't' is [[[1, 1], [2, 2]]
 #                [[3, 3], [4, 4]]]
-# tensor 't' has shape [2, 2]
+# tensor 't' has shape [2, 2, 2]
 reshape(t, [2, 4]) ==> [[1, 1, 2, 2]
                         [3, 3, 4, 4]]
 
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md
index f34e12b1dcb..10a6df41ff2 100644
--- a/tensorflow/g3doc/api_docs/python/array_ops.md
+++ b/tensorflow/g3doc/api_docs/python/array_ops.md
@@ -529,7 +529,7 @@ tf.shape(split0) ==> [5, 10]
 
 *  <b>`split_dim`</b>: A 0-D `int32` `Tensor`. The dimension along which to split.
     Must be in the range `[0, rank(value))`.
-*  <b>`num_split`</b>: A 0-D `int32` `Tensor`. The number of ways to split.
+*  <b>`num_split`</b>: A Python integer. The number of ways to split.
 *  <b>`value`</b>: The `Tensor` to split.
 *  <b>`name`</b>: A name for the operation (optional).
 
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index edbb8101a49..0e3abf66766 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -17,7 +17,7 @@ Note: Functions taking `Tensor` arguments can also take anything accepted by
   * [`tf.constant(value, dtype=None, shape=None, name='Const')`](#constant)
 * [Sequences](#AUTOGENERATED-sequences)
   * [`tf.linspace(start, stop, num, name=None)`](#linspace)
-  * [`tf.range(start, limit, delta=1, name='range')`](#range)
+  * [`tf.range(start, limit=None, delta=1, name='range')`](#range)
 * [Random Tensors](#AUTOGENERATED-random-tensors)
   * [Examples:](#AUTOGENERATED-examples-)
   * [`tf.random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None, name=None)`](#random_normal)
@@ -273,12 +273,15 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0  11.0  12.0]
 
 - - -
 
-### `tf.range(start, limit, delta=1, name='range')` <a class="md-anchor" id="range"></a>
+### `tf.range(start, limit=None, delta=1, name='range')` <a class="md-anchor" id="range"></a>
 
 Creates a sequence of integers.
 
-This operation creates a sequence of integers that begins at `start` and
-extends by increments of `delta` up to but not including `limit`.
+Creates a sequence of integers that begins at `start` and extends by
+increments of `delta` up to but not including `limit`.
+
+Like the Python builtin `range`, `start` defaults to 0, so that
+`range(n) = range(0, n)`.
 
 For example:
 
@@ -287,12 +290,16 @@ For example:
 # 'limit' is 18
 # 'delta' is 3
 tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+
+# 'limit' is 5
+tf.range(limit) ==> [0, 1, 2, 3, 4]
 ```
 
 ##### Args: <a class="md-anchor" id="AUTOGENERATED-args-"></a>
 
 
 *  <b>`start`</b>: A 0-D (scalar) of type `int32`. First entry in sequence.
+    Defaults to 0.
 *  <b>`limit`</b>: A 0-D (scalar) of type `int32`. Upper limit of sequence,
     exclusive.
 *  <b>`delta`</b>: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index 90c70bec24f..85739e6d3bd 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -143,7 +143,7 @@ This must be called by the constructors of subclasses.
 
 - - -
 
-#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a>
+#### `tf.train.Optimizer.minimize(loss, global_step=None, var_list=None, gate_gradients=1, aggregation_method=None, name=None)` <a class="md-anchor" id="Optimizer.minimize"></a>
 
 Add operations to minimize 'loss' by updating 'var_list'.
 
@@ -163,6 +163,8 @@ this function.
     under the key GraphKeys.TRAINABLE_VARIABLES.
 *  <b>`gate_gradients`</b>: How to gate the computation of gradients.  Can be
     GATE_NONE, GATE_OP, or  GATE_GRAPH.
+*  <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
+    Valid values are defined in the class `AggregationMethod`.
 *  <b>`name`</b>: Optional name for the returned operation.
 
 ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
@@ -178,7 +180,7 @@ this function.
 
 - - -
 
-#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1)` <a class="md-anchor" id="Optimizer.compute_gradients"></a>
+#### `tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None)` <a class="md-anchor" id="Optimizer.compute_gradients"></a>
 
 Compute gradients of "loss" for the variables in "var_list".
 
@@ -197,6 +199,8 @@ given variable.
     under the key GraphKey.TRAINABLE_VARIABLES.
 *  <b>`gate_gradients`</b>: How to gate the computation of gradients.  Can be
     GATE_NONE, GATE_OP, or  GATE_GRAPH.
+*  <b>`aggregation_method`</b>: Specifies the method used to combine gradient terms.
+    Valid values are defined in the class `AggregationMethod`.
 
 ##### Returns: <a class="md-anchor" id="AUTOGENERATED-returns-"></a>
 
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index aa0301028a1..e5cd0165443 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -268,6 +268,34 @@ $ bazel-bin/tensorflow/cc/tutorials_example_trainer --use_gpu
 
 Note that "--config=cuda" is needed to enable the GPU support.
 
+##### Enabling Cuda 3.0. <a class="md-anchor" id="AUTOGENERATED-enabling-cuda-3.0."></a>
+TensorFlow officially supports Cuda devices with 3.5 and 5.2 compute
+capabilities. In order to enable earlier Cuda devices such as Grid K520, you
+need to target Cuda 3.0. This can be done through TensorFlow unofficial
+settings with "configure".
+
+```bash
+$ TF_UNOFFICIAL_SETTING=1 ./configure
+
+# Same as the official settings above
+
+WARNING: You are configuring unofficial settings in TensorFlow. Because some
+external libraries are not backward compatible, these settings are largely
+untested and unsupported.
+
+Please specify a list of comma-separated Cuda compute capabilities you want to 
+build with. You can find the compute capability of your device at: 
+https://developer.nvidia.com/cuda-gpus. 
+Please note that each additional compute capability significantly increases 
+your build time and binary size. [Default is: "3.5,5.2"]: 3.0
+
+Setting up Cuda include
+Setting up Cuda lib64
+Setting up Cuda bin
+Setting up Cuda nvvm
+Configuration finished
+```
+
 ##### Known issues <a class="md-anchor" id="AUTOGENERATED-known-issues"></a>
 
 * Although it is possible to build both Cuda and non-Cuda configs under the same
@@ -360,14 +388,20 @@ Make sure you followed the the GPU installation [instructions](#install_cuda).
 
 #### Can't find setup.py <a class="md-anchor" id="AUTOGENERATED-can-t-find-setup.py"></a>
 
-If, during pip install, you encounter an error like:
+If, during `pip install`, you encounter an error like:
 
 ```bash
 ...
 IOError: [Errno 2] No such file or directory: '/tmp/pip-o6Tpui-build/setup.py'
 ```
 
-Solution: upgrade your version of pip.
+Solution: upgrade your version of `pip`:
+
+```bash
+pip install --upgrade pip
+```
+
+This may require `sudo`, depending on how `pip` is installed.
 
 #### SSLError: SSL_VERIFY_FAILED <a class="md-anchor" id="AUTOGENERATED-sslerror--ssl_verify_failed"></a>
 
diff --git a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md
index 8de7b080ebd..e431f4c26d6 100644
--- a/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md
+++ b/tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md
@@ -66,7 +66,7 @@ every hundred steps or so, as in the following code example.
 
 ```python
 merged_summary_op = tf.merge_all_summaries()
-summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph)
+summary_writer = tf.train.SummaryWriter('/tmp/mnist_logs', sess.graph_def)
 total_step = 0
 while training:
   total_step += 1
diff --git a/tensorflow/g3doc/tutorials/mnist/input_data.py b/tensorflow/g3doc/tutorials/mnist/input_data.py
index 391d133ea1f..890a552010f 100644
--- a/tensorflow/g3doc/tutorials/mnist/input_data.py
+++ b/tensorflow/g3doc/tutorials/mnist/input_data.py
@@ -5,9 +5,9 @@ from __future__ import print_function
 
 import gzip
 import os
-import urllib
 
 import numpy
+from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 
 SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
@@ -19,7 +19,7 @@ def maybe_download(filename, work_directory):
     os.mkdir(work_directory)
   filepath = os.path.join(work_directory, filename)
   if not os.path.exists(filepath):
-    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
     statinfo = os.stat(filepath)
     print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
   return filepath
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist.py b/tensorflow/g3doc/tutorials/mnist/mnist.py
index 64be52293a5..925debac6ef 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist.py
+++ b/tensorflow/g3doc/tutorials/mnist/mnist.py
@@ -91,7 +91,7 @@ def loss(logits, labels):
   # be a 1.0 in the entry corresponding to the label).
   batch_size = tf.size(labels)
   labels = tf.expand_dims(labels, 1)
-  indices = tf.expand_dims(tf.range(0, batch_size, 1), 1)
+  indices = tf.expand_dims(tf.range(batch_size), 1)
   concated = tf.concat(1, [indices, labels])
   onehot_labels = tf.sparse_to_dense(
       concated, tf.pack([batch_size, NUM_CLASSES]), 1.0, 0.0)
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index a83d3eabd52..0a7d1aeae0f 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -39,7 +39,7 @@ Tensorflow relies on a highly efficient C++ backend to do its computation. The
 connection to this backend is called a session.  The common usage for TensorFlow
 programs is to first create a graph and then launch it in a session.
 
-Here we instead use the convenience `InteractiveSession` class, which
+Here we instead use the convenient `InteractiveSession` class, which
 makes TensorFlow more flexible about how you
 structure your code.
 It allows you to interleave operations which build a
@@ -232,7 +232,7 @@ print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 ## Build a Multilayer Convolutional Network <a class="md-anchor" id="AUTOGENERATED-build-a-multilayer-convolutional-network"></a>
 
 Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
-section, we'll fix that, jumping from a very simple model to something moderatly
+section, we'll fix that, jumping from a very simple model to something moderately
 sophisticated: a small convolutional neural network. This will get us to around
 99.2% accuracy -- not state of the art, but respectable.
 
diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
index a9e0f284369..d7acb3960ef 100644
--- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
@@ -9,9 +9,9 @@ import math
 import numpy as np
 import os
 import random
+from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
-import urllib
 import zipfile
 
 # Step 1: Download the data.
@@ -20,7 +20,7 @@ url = 'http://mattmahoney.net/dc/'
 def maybe_download(filename, expected_bytes):
   """Download a file if not present, and make sure it's the right size."""
   if not os.path.exists(filename):
-    filename, _ = urllib.urlretrieve(url + filename, filename)
+    filename, _ = urllib.request.urlretrieve(url + filename, filename)
   statinfo = os.stat(filename)
   if statinfo.st_size == expected_bytes:
     print('Found and verified', filename)
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 8fcd7901300..627cf01b6a2 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -25,9 +25,9 @@ import os
 import re
 import sys
 import tarfile
-import urllib
 
 import tensorflow.python.platform
+from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
@@ -366,7 +366,7 @@ def loss(logits, labels):
   # Reshape the labels into a dense Tensor of
   # shape [batch_size, NUM_CLASSES].
   sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
-  indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
+  indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
   concated = tf.concat(1, [indices, sparse_labels])
   dense_labels = tf.sparse_to_dense(concated,
                                     [FLAGS.batch_size, NUM_CLASSES],
@@ -478,7 +478,8 @@ def maybe_download_and_extract():
       sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
           float(count * block_size) / float(total_size) * 100.0))
       sys.stdout.flush()
-    filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
+    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath,
+                                             reporthook=_progress)
     print()
     statinfo = os.stat(filepath)
     print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index e388b772fe9..8e9275ab6a1 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -11,11 +11,11 @@ from __future__ import print_function
 import gzip
 import os
 import sys
-import urllib
 
 import tensorflow.python.platform
 
 import numpy
+from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 import tensorflow as tf
 
@@ -41,7 +41,7 @@ def maybe_download(filename):
     os.mkdir(WORK_DIRECTORY)
   filepath = os.path.join(WORK_DIRECTORY, filename)
   if not os.path.exists(filepath):
-    filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
     statinfo = os.stat(filepath)
     print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
   return filepath
diff --git a/tensorflow/models/rnn/rnn_cell_test.py b/tensorflow/models/rnn/rnn_cell_test.py
index 937e1557bd6..447ddfebd43 100644
--- a/tensorflow/models/rnn/rnn_cell_test.py
+++ b/tensorflow/models/rnn/rnn_cell_test.py
@@ -118,7 +118,7 @@ class RNNCellTest(tf.test.TestCase):
       with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
         x = tf.zeros([1, 3])
         m = tf.zeros([1, 3])
-        keep = tf.zeros([1]) + 1
+        keep = tf.zeros([]) + 1
         g, new_m = rnn_cell.DropoutWrapper(rnn_cell.GRUCell(3),
                                            keep, keep)(x, m)
         sess.run([tf.variables.initialize_all_variables()])
diff --git a/tensorflow/models/rnn/seq2seq.py b/tensorflow/models/rnn/seq2seq.py
index 875bcb5e6ee..63e6181a263 100644
--- a/tensorflow/models/rnn/seq2seq.py
+++ b/tensorflow/models/rnn/seq2seq.py
@@ -636,7 +636,7 @@ def sequence_loss_by_example(logits, targets, weights, num_decoder_symbols,
         # SparseToDense does not accept batched inputs, we need to do this by
         # re-indexing and re-sizing. When TensorFlow adds SparseCrossEntropy,
         # rewrite this method.
-        indices = targets[i] + num_decoder_symbols * tf.range(0, batch_size)
+        indices = targets[i] + num_decoder_symbols * tf.range(batch_size)
         with tf.device("/cpu:0"):  # Sparse-to-dense must happen on CPU for now.
           dense = tf.sparse_to_dense(indices, tf.expand_dims(length, 0), 1.0,
                                      0.0)
diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py
index b9d951ccd7a..00f77599af5 100644
--- a/tensorflow/models/rnn/translate/data_utils.py
+++ b/tensorflow/models/rnn/translate/data_utils.py
@@ -7,9 +7,9 @@ import gzip
 import os
 import re
 import tarfile
-import urllib
 
 from tensorflow.python.platform import gfile
+from six.moves import urllib
 
 # Special vocabulary symbols - we always put them at the start.
 _PAD = "_PAD"
@@ -40,7 +40,7 @@ def maybe_download(directory, filename, url):
   filepath = os.path.join(directory, filename)
   if not os.path.exists(filepath):
     print("Downloading %s to %s" % (url, filepath))
-    filepath, _ = urllib.urlretrieve(url, filepath)
+    filepath, _ = urllib.request.urlretrieve(url, filepath)
     statinfo = os.stat(filepath)
     print("Succesfully downloaded", filename, statinfo.st_size, "bytes")
   return filepath
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 2cbdf191c63..1b7e0eb9a9a 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -13,8 +13,16 @@ import tensorflow as tf
 
 """
 
-import tensorflow.python.platform
-from tensorflow.core.framework.graph_pb2 import *
+try:
+  import tensorflow.python.platform
+  from tensorflow.core.framework.graph_pb2 import *
+except ImportError as e:
+  msg = """Error importing tensorflow: you should not try to import
+  tensorflow from its source directory; please exit the tensorflow source tree,
+  and relaunch your python interpreter from there.
+  Original ImportError: %s""" % str(e)
+  raise ImportError(msg)
+
 from tensorflow.core.framework.summary_pb2 import *
 from tensorflow.core.framework.config_pb2 import *
 from tensorflow.core.util.event_pb2 import *
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 70321e76dcb..2801d588e89 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1551,6 +1551,8 @@ class Graph(object):
     # True if the graph is considered "finalized".  In that case no
     # new operations can be added.
     self._finalized = False
+    # Functions defined in the graph
+    self._functions = []
 
   def _check_not_finalized(self):
     """Check if the graph is finalized.
@@ -1655,8 +1657,30 @@ class Graph(object):
         bytesize += op.node_def.ByteSize()
         if bytesize >= (1 << 31) or bytesize < 0:
           raise ValueError("GraphDef cannot be larger than 2GB.")
+    if self._functions:
+      for f in self._functions:
+        bytesize += f.ByteSize()
+        if bytesize >= (1 << 31) or bytesize < 0:
+          raise ValueError("GraphDef cannot be larger than 2GB.")
+      graph.library.function.extend(self._functions)
     return graph
 
+  def _add_function(self, function_def):
+    """Adds a function to the graph.
+
+    The function is specified as a [`FunctionDef`]
+    (https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/core/framework/graph.proto)
+    protocol buffer.
+
+    After the function has been added, you can call to the function by
+    passing the function name in place of an op name to
+    `Graph.create_op()`.
+
+    Args:
+      function_def: A `FunctionDef` protocol buffer.
+    """
+    self._functions.append(function_def)
+
   # Helper functions to create operations.
   def create_op(self, op_type, inputs, dtypes,
                 input_types=None, name=None, attrs=None, op_def=None,
@@ -1869,7 +1893,6 @@ class Graph(object):
       A list of Operations.
     """
     return list(self._nodes_by_id.values())
-
   def get_operation_by_name(self, name):
     """Returns the `Operation` with the given `name`.
 
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index b4959485543..16efddc6973 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -75,18 +75,28 @@ class ConcatOpTest(tf.test.TestCase):
     # Random dim to concat on
     concat_dim = np.random.randint(5)
     params = {}
+    if dtype == tf.bfloat16:
+      dtype_feed = tf.float32
+    else:
+      dtype_feed = dtype
     with self.test_session(use_gpu=use_gpu):
       p = []
       for i in np.arange(num_tensors):
         input_shape = shape
         input_shape[concat_dim] = np.random.randint(1, 5)
-        placeholder = tf.placeholder(dtype, shape=input_shape)
+        placeholder = tf.placeholder(dtype_feed, shape=input_shape)
         p.append(placeholder)
 
-        t = dtype.as_numpy_dtype
+        t = dtype_feed.as_numpy_dtype
         params[placeholder] = np.random.rand(*input_shape).astype(t)
 
-      c = tf.concat(concat_dim, p)
+      if dtype != dtype_feed:
+        concat_inputs = [tf.cast(p_i, dtype) for p_i in p]
+      else:
+        concat_inputs = p
+      c = tf.concat(concat_dim, concat_inputs)
+      if dtype != dtype_feed:
+        c = tf.cast(c, dtype_feed)
       result = c.eval(feed_dict=params)
 
     self.assertEqual(result.shape, c.get_shape())
@@ -100,15 +110,17 @@ class ConcatOpTest(tf.test.TestCase):
       ind[concat_dim] = slice(cur_offset,
                               cur_offset + params[p[i]].shape[concat_dim])
       cur_offset += params[p[i]].shape[concat_dim]
-      self.assertAllEqual(result[ind], params[p[i]])
+      if dtype == dtype_feed:
+        self.assertAllEqual(result[ind], params[p[i]])
+      else:
+        self.assertAllClose(result[ind], params[p[i]], 0.01)
 
   def testRandom(self):
     self._testRandom(tf.float32)
     self._testRandom(tf.int16)
     self._testRandom(tf.int32, use_gpu=True)
-    # Note that the following does not work since bfloat16 is not supported in
-    # numpy.
-    # self._testRandom(tf.bfloat16)
+    self._testRandom(tf.bfloat16)
+    self._testRandom(tf.bfloat16, use_gpu=True)
 
   def _testGradientsSimple(self, use_gpu):
     with self.test_session(use_gpu=use_gpu):
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 03844d61771..5a987c912c6 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -262,7 +262,7 @@ class EmbeddingLookupTest(tf.test.TestCase):
           self.assertAllEqual(simple, tf.gather(params, ids).eval())
           # Run a few random sharded versions
           for procs in 1, 2, 3:
-            stride = procs * tf.range(0, params.shape[0] // procs)
+            stride = procs * tf.range(params.shape[0] // procs)
             split_params = [tf.gather(params, stride + p)
                             for p in xrange(procs)]
             sharded = tf.nn.embedding_lookup(split_params, ids).eval()
diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py
index 1b9f1323e8b..8036989a0e6 100644
--- a/tensorflow/python/kernel_tests/init_ops_test.py
+++ b/tensorflow/python/kernel_tests/init_ops_test.py
@@ -190,6 +190,10 @@ class RangeTest(tf.test.TestCase):
         self._Range(100, 500, 100), np.array([100, 200, 300, 400])))
     self.assertEqual(tf.range(0, 5, 1).dtype, tf.int32)
 
+  def testLimitOnly(self):
+    with self.test_session():
+      self.assertAllEqual(np.arange(5), tf.range(5).eval())
+
   def testEmpty(self):
     for start in 0, 5:
       self.assertTrue(np.array_equal(self._Range(start, start, 1), []))
diff --git a/tensorflow/python/kernel_tests/lookup_table_op_test.py b/tensorflow/python/kernel_tests/lookup_table_op_test.py
deleted file mode 100644
index 7b5942cacd4..00000000000
--- a/tensorflow/python/kernel_tests/lookup_table_op_test.py
+++ /dev/null
@@ -1,218 +0,0 @@
-"""Tests for lookup table ops from tf."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow.python.platform
-
-import numpy as np
-import tensorflow as tf
-
-
-class HashTableOpTest(tf.test.TestCase):
-
-  def testHashTable(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testHashTableFindHighRank(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant([['brain', 'salad'], ['tank', 'tarkus']])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([[0, 1], [-1, -1]], result)
-
-  def testHashTableInitWithPythonArrays(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-      # Empty table.
-      self.assertAllEqual(0, table.size().eval())
-
-      # Initialize with keys and values tensors.
-      keys = ['brain', 'salad', 'surgery']
-      values = [0, 1, 2]
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testHashTableInitWithNumPyArrays(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = np.array(['brain', 'salad', 'surgery'], dtype=np.str)
-      values = np.array([0, 1, 2], dtype=np.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-      self.assertAllEqual(3, table.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testMultipleHashTables(self):
-    with self.test_session() as sess:
-      shared_name = ''
-      default_val = -1
-      table1 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-      table2 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-      table3 = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      table1.initialize_from(keys, values)
-      table2.initialize_from(keys, values)
-      table3.initialize_from(keys, values)
-
-      tf.initialize_all_tables().run()
-      self.assertAllEqual(3, table1.size().eval())
-      self.assertAllEqual(3, table2.size().eval())
-      self.assertAllEqual(3, table3.size().eval())
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output1 = table1.lookup(input_string)
-      output2 = table2.lookup(input_string)
-      output3 = table3.lookup(input_string)
-
-      out1, out2, out3 = sess.run([output1, output2, output3])
-      self.assertAllEqual([0, 1, -1], out1)
-      self.assertAllEqual([0, 1, -1], out2)
-      self.assertAllEqual([0, 1, -1], out3)
-
-  def testHashTableWithTensorDefault(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = tf.constant(-1, tf.int64)
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-
-      input_string = tf.constant(['brain', 'salad', 'tank'])
-      output = table.lookup(input_string)
-
-      result = output.eval()
-      self.assertAllEqual([0, 1, -1], result)
-
-  def testSignatureMismatch(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-
-      input_string = tf.constant([1, 2, 3], tf.int64)
-      with self.assertRaises(TypeError):
-        table.lookup(input_string)
-
-      with self.assertRaises(TypeError):
-        tf.HashTable(tf.string, tf.int64, 'UNK', shared_name)
-
-  def testDTypes(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      with self.assertRaises(TypeError):
-        tf.HashTable([tf.string], tf.string, default_val, shared_name)
-
-  def testNotInitialized(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      input_string = tf.constant(['brain', 'salad', 'surgery'])
-      output = table.lookup(input_string)
-
-      with self.assertRaisesOpError('Table not initialized'):
-        output.eval()
-
-  def testInitializeTwice(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2], tf.int64)
-      init = table.initialize_from(keys, values)
-      init.run()
-
-      with self.assertRaisesOpError('Table already initialized'):
-        init.run()
-
-  def testInitializationWithInvalidDimensions(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = tf.constant(['brain', 'salad', 'surgery'])
-      values = tf.constant([0, 1, 2, 3, 4], tf.int64)
-      with self.assertRaises(ValueError):
-        table.initialize_from(keys, values)
-
-  def testInitializationWithInvalidDataTypes(self):
-    with self.test_session():
-      shared_name = ''
-      default_val = -1
-      table = tf.HashTable(tf.string, tf.int64, default_val, shared_name)
-
-      # Initialize with keys and values tensors.
-      keys = [0, 1, 2]
-      values = ['brain', 'salad', 'surgery']
-      with self.assertRaises(TypeError):
-        table.initialize_from(keys, values)
-
-
-if __name__ == '__main__':
-  tf.test.main()
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 3dedd33cb9a..9893c9b824b 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -74,7 +74,7 @@ def clip_by_norm(t, clip_norm, name=None):
 
     # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
     l2norm_inv = math_ops.rsqrt(
-        math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
     tclip = array_ops.identity(t * clip_norm * math_ops.minimum(
         l2norm_inv, constant_op.constant(1.0 / clip_norm)), name=name)
 
@@ -228,7 +228,7 @@ def clip_by_average_norm(t, clip_norm, name=None):
     # L2-norm per element
     n_element = math_ops.cast(array_ops.size(t), types.float32)
     l2norm_inv = math_ops.rsqrt(
-        math_ops.reduce_sum(t * t, math_ops.range(0, array_ops.rank(t))))
+        math_ops.reduce_sum(t * t, math_ops.range(array_ops.rank(t))))
     tclip = array_ops.identity(
         t * clip_norm * math_ops.minimum(
             l2norm_inv * n_element, constant_op.constant(1.0 / clip_norm)),
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 178f716e48a..ed09ff3655e 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -10,7 +10,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework import types
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import common_shapes
 from tensorflow.python.ops import control_flow_ops
@@ -408,174 +407,12 @@ class FIFOQueue(QueueBase):
 # TODO(josh11b): class BatchQueue(QueueBase):
 
 
-# pylint: disable=protected-access
-class LookupTableBase(object):
-  """Represents a lookup table that persists across different steps."""
-
-  def __init__(self, key_dtype, value_dtype, default_value, table_ref):
-    """Construct a table object from a table reference.
-
-    Args:
-      key_dtype:  The table key type.
-      value_dtype:  The table value type.
-      default_value: The value to use if a key is missing in the table.
-      table_ref: The table reference, i.e. the output of the lookup table ops.
-    """
-    self._key_dtype = types.as_dtype(key_dtype)
-    self._value_dtype = types.as_dtype(value_dtype)
-    self._shapes = [tensor_shape.TensorShape([1])]
-    self._table_ref = table_ref
-    self._name = self._table_ref.op.name.split("/")[-1]
-    self._default_value = ops.convert_to_tensor(default_value,
-                                                dtype=self._value_dtype)
-    self._default_value.get_shape().merge_with(tensor_shape.scalar())
-
-  @property
-  def table_ref(self):
-    """Get the underlying table reference."""
-    return self._table_ref
-
-  @property
-  def key_dtype(self):
-    """The table key dtype."""
-    return self._key_dtype
-
-  @property
-  def value_dtype(self):
-    """The table value dtype."""
-    return self._value_dtype
-
-  @property
-  def name(self):
-    """The name of the table."""
-    return self._name
-
-  @property
-  def default_value(self):
-    """The default value of the table."""
-    return self._default_value
-
-  def size(self, name=None):
-    """Compute the number of elements in this table.
-
-    Args:
-      name: A name for the operation (optional).
-
-    Returns:
-      A scalar tensor containing the number of elements in this table.
-    """
-    if name is None:
-      name = "%s_Size" % self._name
-    return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
-
-  def lookup(self, keys, name=None):
-    """Looks up `keys` in a table, outputs the corresponding values.
-
-    The `default_value` is use for keys not present in the table.
-
-    Args:
-      keys: Keys to look up.
-      name: Optional name for the op.
-
-    Returns:
-      The operation that looks up the keys.
-
-    Raises:
-      TypeError: when `keys` or `default_value` doesn't match the table data
-        types.
-    """
-    if name is None:
-      name = "%s_lookup_table_find" % self._name
-
-    if keys.dtype != self._key_dtype:
-      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." % (
-          self._key_dtype, keys.dtype))
-
-    return gen_data_flow_ops._lookup_table_find(
-        self._table_ref, keys, self._default_value, name=name)
-
-  def initialize_from(self, keys, values, name=None):
-    """Initialize the table with the provided keys and values tensors.
-
-    Construct an initializer object from keys and value tensors.
-
-    Args:
-      keys: The tensor for the keys.
-      values: The tensor for the values.
-      name: Optional name for the op.
-
-    Returns:
-      The operation that initializes the table.
-
-    Raises:
-      TypeError: when the keys and values data types do not match the table
-      key and value data types.
-    """
-    if name is None:
-      name = "%s_initialize_table" % self.name
-    with ops.op_scope([keys, values], None, name):
-      keys = ops.convert_to_tensor(keys, dtype=self.key_dtype, name="keys")
-      values = ops.convert_to_tensor(values, dtype=self.value_dtype,
-                                     name="values")
-
-    init_op = gen_data_flow_ops._initialize_table(
-        self.table_ref, keys, values, name=name)
-    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
-    return init_op
-
-  def _check_table_dtypes(self, key_dtype, value_dtype):
-    """Check that the given key_dtype and value_dtype matches the table dtypes'.
-
-    Args:
-      key_dtype: The key data type to check.
-      value_dtype: The value data type to check.
-
-    Raises:
-      TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
-        types.
-    """
-    if key_dtype != self.key_dtype:
-      raise TypeError("Invalid key dtype, expected %s but got %s." % (
-          self.key_dtype, key_dtype))
-    if value_dtype != self.value_dtype:
-      raise TypeError("Invalid value dtype, expected %s but got %s." % (
-          self.value_dtype, value_dtype))
-
-
-class HashTable(LookupTableBase):
-  """A generic hash table implementation."""
-
-  def __init__(self, key_dtype, value_dtype, default_value, shared_name=None,
-               name="hash_table"):
-    """Creates a non-initialized hash table.
-
-    This op creates a hash table, specifying the type of its keys and values.
-    Before using the table you will have to initialize it.  After initialization
-    the table will be immutable.
-
-    Args:
-      key_dtype:  Type of the table keys.
-      value_dtype:  Type of the table values.
-      default_value: The scalar tensor to be used when a key is missing in the
-        table.
-      shared_name: Optional. If non-empty, this table will be shared under
-        the given name across multiple sessions.
-      name: Optional name for the hash table op.
-
-    Returns:
-      A `HashTable` object.
-    """
-    table_ref = gen_data_flow_ops._hash_table(
-        shared_name=shared_name, key_dtype=key_dtype,
-        value_dtype=value_dtype, name=name)
-
-    super(HashTable, self).__init__(key_dtype, value_dtype, default_value,
-                                    table_ref)
-
-
 def initialize_all_tables(name="init_all_tables"):
   """Returns an Op that initializes all tables of the default graph.
 
+  Args:
+    name: Optional name for the initialization op.
+
   Returns:
     An Op that initializes all tables.  Note that if there are
     not tables the returned Op is a NoOp.
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 80bedd49846..b74f8f5426a 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -51,7 +51,7 @@ def embedding_lookup(params, ids, name=None):
     else:
       ids = ops.convert_to_tensor(ids, name="ids")
       flat_ids = array_ops.reshape(ids, [-1])
-      original_indices = math_ops.range(0, array_ops.size(flat_ids))
+      original_indices = math_ops.range(array_ops.size(flat_ids))
       # Compute flat_ids % partitions for each id
       ids_mod_p = flat_ids % np
       if ids_mod_p.dtype != types.int32:
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index b404fbc7d70..b8b0741e07a 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -21,7 +21,7 @@ def _ReductionGradAssist(op):
   indices = op.inputs[1]                            # [1, 2]
   indices_shape = array_ops.shape(indices)          # [2]
   new_output_shape = data_flow_ops.dynamic_stitch(  # [2, 1, 1, 7]
-      [math_ops.range(0, input_rank),               # [0, 1, 2, 3]
+      [math_ops.range(input_rank),                  # [0, 1, 2, 3]
        indices],                                    # [1, 2]
       [input_shape,                                 # [2, 3, 5, 7]
        array_ops.fill(indices_shape, 1)])           # [1, 1]
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index f7289ff2340..4a2473cae50 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -536,11 +536,14 @@ ops.Tensor._override_operator("__gt__", greater)
 ops.Tensor._override_operator("__ge__", greater_equal)
 
 
-def range(start, limit, delta=1, name="range"):
+def range(start, limit=None, delta=1, name="range"):
   """Creates a sequence of integers.
 
-  This operation creates a sequence of integers that begins at `start` and
-  extends by increments of `delta` up to but not including `limit`.
+  Creates a sequence of integers that begins at `start` and extends by
+  increments of `delta` up to but not including `limit`.
+
+  Like the Python builtin `range`, `start` defaults to 0, so that
+  `range(n) = range(0, n)`.
 
   For example:
 
@@ -549,10 +552,14 @@ def range(start, limit, delta=1, name="range"):
   # 'limit' is 18
   # 'delta' is 3
   tf.range(start, limit, delta) ==> [3, 6, 9, 12, 15]
+
+  # 'limit' is 5
+  tf.range(limit) ==> [0, 1, 2, 3, 4]
   ```
 
   Args:
     start: A 0-D (scalar) of type `int32`. First entry in sequence.
+      Defaults to 0.
     limit: A 0-D (scalar) of type `int32`. Upper limit of sequence,
       exclusive.
     delta: A 0-D `Tensor` (scalar) of type `int32`. Optional. Default is 1.
@@ -562,6 +569,8 @@ def range(start, limit, delta=1, name="range"):
   Returns:
     An 1-D `int32` `Tensor`.
   """
+  if limit is None:
+    start, limit = 0, start
   return gen_math_ops._range(start, limit, delta, name=name)
 
 
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index caf47b1431c..5a5c06f975e 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -173,6 +173,7 @@ from __future__ import print_function
 
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import types
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import candidate_sampling_ops
@@ -347,7 +348,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
 
   Args:
     x: A tensor.
-    keep_prob: A Python float. The probability that each element is kept.
+    keep_prob: A scalar `Tensor` with the same type as x. The probability
+      that each element is kept.
     noise_shape: A 1-D `Tensor` of type `int32`, representing the
       shape for randomly generated keep/drop flags.
     seed: A Python integer. Used to create random seeds. See
@@ -361,10 +363,15 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
   Raises:
     ValueError: If `keep_prob` is not in `(0, 1]`.
   """
-  if not (0 < keep_prob <= 1):
-    raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob)
   with ops.op_scope([x], name, "dropout") as name:
     x = ops.convert_to_tensor(x, name="x")
+    if isinstance(keep_prob, float) and not(0 < keep_prob <= 1):
+      raise ValueError("keep_prob must be a scalar tensor or a float in the "
+                       "range (0, 1], got %g" % keep_prob)
+    keep_prob = ops.convert_to_tensor(
+        keep_prob, dtype=x.dtype, name="keep_prob")
+    keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
     noise_shape = noise_shape or array_ops.shape(x)
     # uniform [keep_prob, 1.0 + keep_prob)
     random_tensor = keep_prob
@@ -372,7 +379,9 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
         noise_shape, seed=seed, dtype=x.dtype)
     # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
     binary_tensor = math_ops.floor(random_tensor)
-    return x * (1.0 / keep_prob) * binary_tensor
+    ret = x * math_ops.inv(keep_prob) * binary_tensor
+    ret.set_shape(x.get_shape())
+    return ret
 
 
 def depthwise_conv2d(input, filter, strides, padding, name=None):
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 535d55f00f0..919ce784918 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -85,7 +85,7 @@ def _BiasAddGrad(unused_bias_op, received_grad):
     Two tensors, the first one for the "tensor" input of the BiasOp,
     the second one for the "bias" input of the BiasOp.
   """
-  reduction_dim_tensor = math_ops.range(0, array_ops.rank(received_grad) - 1)
+  reduction_dim_tensor = math_ops.range(array_ops.rank(received_grad) - 1)
   return (received_grad, math_ops.reduce_sum(received_grad, reduction_dim_tensor))
 
 
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 48f7a4c9876..c24bd0e3726 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -405,28 +405,82 @@ class DropoutTest(test_util.TensorFlowTestCase):
             sorted_value = np.unique(np.sort(value[i, :]))
             self.assertEqual(sorted_value.size, 1)
 
+  def testDropoutPlaceholderKeepProb(self):
+    # Runs dropout with 0-1 tensor 10 times, sum the number of ones and validate
+    # that it is producing approximately the right number of ones over a large
+    # number of samples, based on the keep probability.
+    x_dim = 40
+    y_dim = 30
+    num_iter = 10
+    for keep_prob in [0.1, 0.5, 0.8]:
+      with self.test_session():
+        t = constant_op.constant(1.0,
+                                 shape=[x_dim, y_dim],
+                                 dtype=types.float32)
+        keep_prob_placeholder = array_ops.placeholder(types.float32)
+        dropout = nn.dropout(t, keep_prob_placeholder)
+        final_count = 0
+        self.assertEqual([x_dim, y_dim], dropout.get_shape())
+        for _ in xrange(0, num_iter):
+          value = dropout.eval(feed_dict={keep_prob_placeholder: keep_prob})
+          final_count += np.count_nonzero(value)
+          # Verifies that there are only two values: 0 and 1/keep_prob.
+          sorted_value = np.unique(np.sort(value))
+          self.assertEqual(0, sorted_value[0])
+          self.assertAllClose(1 / keep_prob, sorted_value[1])
+      # Check that we are in the 15% error range
+      expected_count = x_dim * y_dim * keep_prob * num_iter
+      rel_error = math.fabs(final_count - expected_count) / expected_count
+      print(rel_error)
+      self.assertTrue(rel_error < 0.15)
+
+  def testShapedDropoutUnknownShape(self):
+    x_dim = 40
+    y_dim = 30
+    keep_prob = 0.5
+    x = constant_op.constant(1.0, shape=[x_dim, y_dim], dtype=types.float32)
+    dropout_x = nn.dropout(
+        x, keep_prob, noise_shape=array_ops.placeholder(types.int32))
+    self.assertEqual(x.get_shape(), dropout_x.get_shape())
+
+  def testInvalidKeepProb(self):
+    x_dim = 40
+    y_dim = 30
+    t = constant_op.constant(1.0,
+                             shape=[x_dim, y_dim],
+                             dtype=types.float32)
+    with self.assertRaises(ValueError):
+      nn.dropout(t, -1.0)
+    with self.assertRaises(ValueError):
+      nn.dropout(t, 1.1)
+    with self.assertRaises(ValueError):
+      nn.dropout(t, [0.0, 1.0])
+    with self.assertRaises(ValueError):
+      nn.dropout(t, array_ops.placeholder(types.float64))
+    with self.assertRaises(ValueError):
+      nn.dropout(t, array_ops.placeholder(types.float32, shape=[2]))
+
   def testShapedDropoutShapeError(self):
     # Runs shaped dropout and verifies an error is thrown on misshapen noise.
     x_dim = 40
     y_dim = 30
     keep_prob = 0.5
-    with self.test_session():
-      t = constant_op.constant(1.0,
-                               shape=[x_dim, y_dim],
-                               dtype=types.float32)
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
-      with self.assertRaises(ValueError):
-        _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
-      # test that broadcasting proceeds
-      _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
-      _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
-      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
-      _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
+    t = constant_op.constant(1.0,
+                             shape=[x_dim, y_dim],
+                             dtype=types.float32)
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim + 10])
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, y_dim, 5])
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim + 3])
+    with self.assertRaises(ValueError):
+      _ = nn.dropout(t, keep_prob, noise_shape=[x_dim])
+    # test that broadcasting proceeds
+    _ = nn.dropout(t, keep_prob, noise_shape=[y_dim])
+    _ = nn.dropout(t, keep_prob, noise_shape=[1, y_dim])
+    _ = nn.dropout(t, keep_prob, noise_shape=[x_dim, 1])
+    _ = nn.dropout(t, keep_prob, noise_shape=[1, 1])
 
 
 class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 1a7af78a33a..23d3a974cfb 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -437,8 +437,7 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
         default_value, dtype=sp_input.values.dtype)
 
     num_rows = math_ops.cast(sp_input.shape[0], types.int32)
-    all_row_indices = math_ops.cast(
-        math_ops.range(0, num_rows, 1), types.int64)
+    all_row_indices = math_ops.cast(math_ops.range(num_rows), types.int64)
     empty_row_indices, _ = array_ops.list_diff(
         all_row_indices, sp_input.indices[:, 0])
     empty_row_indicator = gen_sparse_ops.sparse_to_dense(
diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py
index 66bf2c08890..23318691ee8 100644
--- a/tensorflow/python/platform/default/_logging.py
+++ b/tensorflow/python/platform/default/_logging.py
@@ -6,18 +6,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import logging
 import os
 import sys
 import time
 import thread
-from logging import getLogger
-from logging import log
-from logging import debug
-from logging import error
-from logging import fatal
-from logging import info
-from logging import warn
-from logging import warning
 from logging import DEBUG
 from logging import ERROR
 from logging import FATAL
@@ -25,13 +18,25 @@ from logging import INFO
 from logging import WARN
 
 # Controls which methods from pyglib.logging are available within the project
-# Do not add methods here without also adding to platform/default/_logging.py
+# Do not add methods here without also adding to platform/google/_logging.py
 __all__ = ['log', 'debug', 'error', 'fatal', 'info', 'warn', 'warning',
            'DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN',
            'flush', 'log_every_n', 'log_first_n', 'vlog',
            'TaskLevelStatusMessage', 'get_verbosity', 'set_verbosity']
 
-warning = warn
+# Scope the tensorflow logger to not conflict with users' loggers
+_logger = logging.getLogger('tensorflow')
+_handler = logging.StreamHandler()
+_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT, None))
+_logger.addHandler(_handler)
+
+log = _logger.log
+debug = _logger.debug
+error = _logger.error
+fatal = _logger.fatal
+info = _logger.info
+warn = _logger.warn
+warning = _logger.warn
 
 _level_names = {
     FATAL: 'FATAL',
@@ -61,7 +66,7 @@ def flush():
 
 # Code below is taken from pyglib/logging
 def vlog(level, msg, *args, **kwargs):
-  log(level, msg, *args, **kwargs)
+  _logger.log(level, msg, *args, **kwargs)
 
 
 def _GetNextLogCountPerToken(token):
@@ -169,12 +174,12 @@ def google2_log_prefix(level, timestamp=None, file_and_line=None):
 
 def get_verbosity():
   """Return how much logging output will be produced."""
-  return getLogger().getEffectiveLevel()
+  return _logger.getEffectiveLevel()
 
 
 def set_verbosity(verbosity):
   """Sets the threshold for what messages will be logged."""
-  getLogger().setLevel(verbosity)
+  _logger.setLevel(verbosity)
 
 
 def _get_thread_id():
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 6734690397c..77c496fb85c 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -163,7 +163,7 @@ def range_input_producer(limit, num_epochs=None, shuffle=True, seed=None,
     is added to the current Graph's QUEUE_RUNNER collection.
   """
   with ops.op_scope([limit], name, "input_producer") as name:
-    range_tensor = math_ops.range(0, limit)
+    range_tensor = math_ops.range(limit)
     return _input_producer(
         range_tensor, types.int32, num_epochs, shuffle, seed, capacity, name,
         "fraction_of_%d_full" % capacity)
diff --git a/tensorflow/python/training/learning_rate_decay.py b/tensorflow/python/training/learning_rate_decay.py
index 8450fae5cb6..c1f886d6644 100644
--- a/tensorflow/python/training/learning_rate_decay.py
+++ b/tensorflow/python/training/learning_rate_decay.py
@@ -33,9 +33,9 @@ def exponential_decay(learning_rate, global_step, decay_steps, decay_rate,
   ...
   global_step = tf.Variable(0, trainable=False)
   starter_learning_rate = 0.1
-  learning_rate = tf.exponential_decay(starter_learning_rate, global_step,
-                                       100000, 0.96, staircase=True)
-  optimizer = tf.GradientDescent(learning_rate)
+  learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
+                                             100000, 0.96, staircase=True)
+  optimizer = tf.GradientDescentOptimizer(learning_rate)
   # Passing global_step to minimize() will increment it at each step.
   optimizer.minimize(...my loss..., global_step=global_step)
   ```
diff --git a/tensorflow/tensorboard/bower.json b/tensorflow/tensorboard/bower.json
index bdd16d662aa..995ba30363f 100644
--- a/tensorflow/tensorboard/bower.json
+++ b/tensorflow/tensorboard/bower.json
@@ -19,26 +19,27 @@
     "dagre": "~0.7.4",
     "es6-promise": "~3.0.2",
     "graphlib": "~1.0.7",
-    "iron-ajax": "PolymerElements/iron-ajax#~1.0.8",
-    "iron-collapse": "PolymerElements/iron-collapse#~1.0.4",
-    "iron-list": "PolymerElements/iron-list#~1.1.5",
-    "paper-button": "PolymerElements/paper-button#~1.0.7",
-    "paper-checkbox": "PolymerElements/paper-checkbox#~1.0.6",
-    "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#~1.0.4",
-    "paper-header-panel": "PolymerElements/paper-header-panel#~1.0.5",
-    "paper-icon-button": "PolymerElements/paper-icon-button#~1.0.3",
-    "paper-input": "PolymerElements/paper-input#~1.0.15",
-    "paper-item": "PolymerElements/paper-item#~1.0.3",
-    "paper-menu": "PolymerElements/paper-menu#~1.1.1",
-    "paper-progress": "PolymerElements/paper-progress#~1.0.7",
-    "paper-radio-button": "PolymerElements/paper-radio-button#~1.0.8",
-    "paper-radio-group": "PolymerElements/paper-radio-group#~1.0.4",
-    "paper-slider": "PolymerElements/paper-slider#~1.0.4",
-    "paper-styles": "PolymerElements/paper-styles#~1.0.11",
-    "paper-toggle-button": "PolymerElements/paper-toggle-button#~1.0.6",
-    "paper-toolbar": "PolymerElements/paper-toolbar#~1.0.4",
+    "iron-ajax": "PolymerElements/iron-ajax#1.0.7",
+    "iron-collapse": "PolymerElements/iron-collapse#1.0.4",
+    "iron-list": "PolymerElements/iron-list#1.1.5",
+    "iron-selector": "PolymerElements/iron-selector#1.0.7",
+    "paper-button": "PolymerElements/paper-button#1.0.8",
+    "paper-checkbox": "PolymerElements/paper-checkbox#1.0.13",
+    "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.0.5",
+    "paper-header-panel": "PolymerElements/paper-header-panel#1.0.5",
+    "paper-icon-button": "PolymerElements/paper-icon-button#1.0.5",
+    "paper-input": "PolymerElements/paper-input#1.0.16",
+    "paper-item": "PolymerElements/paper-item#1.0.5",
+    "paper-menu": "PolymerElements/paper-menu#1.1.1",
+    "paper-progress": "PolymerElements/paper-progress#1.0.7",
+    "paper-radio-button": "PolymerElements/paper-radio-button#1.0.10",
+    "paper-radio-group": "PolymerElements/paper-radio-group#1.0.6",
+    "paper-slider": "PolymerElements/paper-slider#1.0.7",
+    "paper-styles": "PolymerElements/paper-styles#1.0.12",
+    "paper-toggle-button": "PolymerElements/paper-toggle-button#1.0.11",
+    "paper-toolbar": "PolymerElements/paper-toolbar#1.0.4",
     "plottable": "~1.16.1",
-    "polymer": "~1.2.0"
+    "polymer": "1.1.5"
   },
   "devDependencies": {
     "iron-component-page": "PolymerElements/iron-component-page#^1.0.0",
diff --git a/tensorflow/tensorboard/tensorboard_handler.py b/tensorflow/tensorboard/tensorboard_handler.py
index 0ea7f3a58d5..2cec3b8812b 100644
--- a/tensorflow/tensorboard/tensorboard_handler.py
+++ b/tensorflow/tensorboard/tensorboard_handler.py
@@ -17,9 +17,9 @@ import json
 import mimetypes
 import os
 import StringIO
-import urllib
 import urlparse
 
+from six.moves import urllib
 from six.moves import xrange  # pylint: disable=redefined-builtin
 from google.protobuf import text_format
 import tensorflow.python.platform
@@ -289,7 +289,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
       A string representation of a URL that will load the index-th
       sampled image in the given run with the given tag.
     """
-    query_string = urllib.urlencode({
+    query_string = urllib.parse.urlencode({
         'run': run,
         'tag': tag,
         'index': index