From 1dec49ebb0e076a2ebb513a3f3aaa725714330db Mon Sep 17 00:00:00 2001
From: Jianwei Xie <xiejw@google.com>
Date: Tue, 7 Nov 2017 16:18:26 -0800
Subject: [PATCH 001/115] Automated g4 rollback of changelist 174708213

PiperOrigin-RevId: 174930262
---
 tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 060b3f91292..5a3b8314291 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -66,7 +66,7 @@ _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
 _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY]
 
 # TODO(b/65703635): Flip the value and remove all dead code.
-_WRAP_INPUT_FN_INTO_WHILE_LOOP = True
+_WRAP_INPUT_FN_INTO_WHILE_LOOP = False
 
 
 def _create_global_step(graph):

From 72e0355c498f6f4531ffdb9c40997cad40684da5 Mon Sep 17 00:00:00 2001
From: Artem Belevich <tra@google.com>
Date: Tue, 7 Nov 2017 16:24:37 -0800
Subject: [PATCH 002/115] Added profiler traces for GPU back-end operations.

PiperOrigin-RevId: 174931093
---
 .../compiler/xla/service/gpu/gpu_compiler.cc   | 18 +++++++++++++-----
 .../gpu/llvm_gpu_backend/gpu_backend_lib.cc    |  4 ++++
 2 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index ceb0e530c15..187b4a705c5 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -75,6 +75,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/tracing.h"
 
 namespace se = ::perftools::gputools;
 
@@ -87,6 +88,7 @@ namespace gpu {
 
 namespace {
 
+using tensorflow::port::Tracing;
 using tensorflow::strings::StrCat;
 
 // Any address of a variable residing in global memory or returned by one of the
@@ -231,6 +233,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
 // code (i.e. a cubin) as a byte array.
 StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major,
                                         int cc_minor) {
+  Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true);
   const string ptxas_path =
       tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas");
   VLOG(2) << "Using ptxas at " << ptxas_path;
@@ -295,11 +298,15 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
     std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec) {
   TF_RET_CHECK(stream_exec != nullptr);
 
-  TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
-                                       stream_exec->GetDeviceDescription(),
-                                       ShapeSizeBytesFunction()));
-  TF_RETURN_IF_ERROR(
-      PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
+  {
+    Tracing::TraceMe annotation("HLO Transforms", module->name(),
+                                /*is_expensive=*/true);
+    TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(),
+                                         stream_exec->GetDeviceDescription(),
+                                         ShapeSizeBytesFunction()));
+    TF_RETURN_IF_ERROR(
+        PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction()));
+  }
 
   llvm::LLVMContext llvm_context;
   std::string buffer;
@@ -444,6 +451,7 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
 std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx,
                                                             int cc_major,
                                                             int cc_minor) {
+  Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true);
   bool inserted;
   decltype(compilation_cache_.begin()) iter;
   // Pointers into compilation_cache_ where the ptx and (optional) cubin are
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
index 817e95a31c5..1cb963be611 100644
--- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -60,6 +60,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
 
 namespace xla {
 namespace gpu {
@@ -488,6 +489,9 @@ StatusOr<string> CompileToPtx(llvm::Module* module,
 
   string ptx;
   {
+    tensorflow::port::Tracing::TraceMe annotation(
+        "Compiling IR", llvm_ir::AsString(module->getName()),
+        /*is_expensive=*/true);
     ScopedLoggingTimer compilation_timer(
         "Compile module " + llvm_ir::AsString(module->getName()),
         /*vlog_level=*/2);

From 2815673bcc5db2aa246083dc2fe08b0cc95711c4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 16:29:35 -0800
Subject: [PATCH 003/115] [tf.data] Saveable iterator for dataset.take(.),
 dataset.skip(.) and dataset.repeat(.).

PiperOrigin-RevId: 174931742
---
 .../contrib/data/python/kernel_tests/BUILD    |  2 +
 .../dataset_serialization_test_base.py        | 73 +++++++++++------
 .../kernel_tests/sequence_dataset_op_test.py  | 78 +++++++++++++++++++
 tensorflow/core/kernels/repeat_dataset_op.cc  | 32 ++++++++
 tensorflow/core/kernels/skip_dataset_op.cc    | 46 ++++++++++-
 tensorflow/core/kernels/take_dataset_op.cc    | 46 ++++++++++-
 6 files changed, 245 insertions(+), 32 deletions(-)
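
For reference, a minimal sketch of the pipelines whose iterators become
checkpointable here (dataset construction mirrors the new test helpers below):

    import numpy as np
    from tensorflow.contrib.data.python.ops import dataset_ops

    components = (np.arange(10),)
    # Each of these iterators can now save and restore its position:
    skipped = dataset_ops.Dataset.from_tensor_slices(components).skip(4)
    taken = dataset_ops.Dataset.from_tensor_slices(components).take(4)
    repeated = dataset_ops.Dataset.from_tensor_slices(
        components).take(3).repeat(10)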

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 7283f0ff0ae..c1f1d90c5da 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -365,6 +365,7 @@ py_test(
     srcs = ["sequence_dataset_op_test.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":dataset_serialization_test",
         "//tensorflow/contrib/data/python/ops:dataset_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
@@ -428,6 +429,7 @@ py_test(
     srcs = ["zip_dataset_op_test.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":dataset_serialization_test",
         "//tensorflow/contrib/data/python/ops:dataset_ops",
         "//tensorflow/contrib/data/python/ops:iterator_ops",
         "//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index df9147af6c0..369b789a521 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -32,7 +32,7 @@ from tensorflow.python.util import nest
 
 
 class DatasetSerializationTestBase(test.TestCase):
-  """Base class for testing finite serializable datasets."""
+  """Base class for testing serializable datasets."""
 
   def tearDown(self):
     self._delete_ckpt()
@@ -58,17 +58,19 @@ class DatasetSerializationTestBase(test.TestCase):
     if ds_fn2:
       self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs)
 
-  def verify_unused_iterator(self, ds_fn, num_outputs):
+  def verify_unused_iterator(self, ds_fn, num_outputs, verify_exhausted=True):
     """Verifies that saving and restoring an unused iterator works.
 
     Args:
       ds_fn: See `run_core_tests`.
       num_outputs: See `run_core_tests`.
+      verify_exhausted: See `gen_outputs`.
 
     Raises:
       AssertionError if any test fails.
     """
-    self.verify_run_with_breaks(ds_fn, [0], num_outputs)
+    self.verify_run_with_breaks(
+        ds_fn, [0], num_outputs, verify_exhausted=verify_exhausted)
 
   def verify_fully_used_iterator(self, ds_fn, num_outputs):
     """Verifies that saving and restoring a fully used iterator works.
@@ -104,12 +106,16 @@ class DatasetSerializationTestBase(test.TestCase):
         ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True)
     self.assertEqual(len(actual), 0)
 
-  def verify_init_before_restore(self, ds_fn, num_outputs):
+  def verify_init_before_restore(self,
+                                 ds_fn,
+                                 num_outputs,
+                                 verify_exhausted=True):
     """Verifies that retoring into an already initilized iterator works.
 
     Args:
       ds_fn: See `run_core_tests`.
       num_outputs: See `run_core_tests`.
+      verify_exhausted: See `gen_outputs`.
 
     Raises:
       AssertionError if any test fails.
@@ -118,9 +124,14 @@ class DatasetSerializationTestBase(test.TestCase):
         ds_fn,
         self.gen_break_points(num_outputs),
         num_outputs,
-        init_before_restore=True)
+        init_before_restore=True,
+        verify_exhausted=verify_exhausted)
 
-  def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10):
+  def verify_multiple_breaks(self,
+                             ds_fn,
+                             num_outputs,
+                             num_breaks=10,
+                             verify_exhausted=True):
     """Attempts to save/restore at multiple break points.
 
     Args:
@@ -128,16 +139,22 @@ class DatasetSerializationTestBase(test.TestCase):
       num_outputs: See `run_core_tests`.
       num_breaks: The number of break points. These are uniformly spread in
         [0, num_outputs] both inclusive.
+      verify_exhausted: See `gen_outputs`.
 
     Raises:
       AssertionError if any test fails.
     """
-    self.verify_run_with_breaks(ds_fn,
-                                self.gen_break_points(num_outputs, num_breaks),
-                                num_outputs)
+    self.verify_run_with_breaks(
+        ds_fn,
+        self.gen_break_points(num_outputs, num_breaks),
+        num_outputs,
+        verify_exhausted=verify_exhausted)
 
-  def verify_reset_restored_iterator(self, ds_fn, num_outputs,
-                                     break_point=None):
+  def verify_reset_restored_iterator(self,
+                                     ds_fn,
+                                     num_outputs,
+                                     break_point=None,
+                                     verify_exhausted=True):
     """Attempts to re-initialize a restored iterator.
 
     This is useful when restoring a training checkpoint during validation.
@@ -146,6 +163,7 @@ class DatasetSerializationTestBase(test.TestCase):
       ds_fn: See `run_core_tests`.
       num_outputs: See `run_core_tests`.
       break_point: Break point. Optional. Defaults to num_outputs/2.
+      verify_exhausted: See `gen_outputs`.
 
     Raises:
       AssertionError if any test fails.
@@ -153,7 +171,8 @@ class DatasetSerializationTestBase(test.TestCase):
     break_point = num_outputs // 2 if not break_point else break_point
 
     # Collect ground truth containing all outputs.
-    expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True)
+    expected = self.gen_outputs(
+        ds_fn, [], num_outputs, verify_exhausted=verify_exhausted)
 
     # Skip some items and save checkpoint.
     self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False)
@@ -168,15 +187,17 @@ class DatasetSerializationTestBase(test.TestCase):
         sess.run(init_op)
         for _ in range(num_outputs):
           actual.append(sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
+        if verify_exhausted:
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
     self.match(expected, actual)
 
   def verify_restore_in_modified_graph(self,
                                        ds_fn1,
                                        ds_fn2,
                                        num_outputs,
-                                       break_point=None):
+                                       break_point=None,
+                                       verify_exhausted=True):
     """Attempts to restore an iterator in a modified graph.
 
     Builds an input pipeline using ds_fn1, runs it for `break_point` steps
@@ -188,6 +209,7 @@ class DatasetSerializationTestBase(test.TestCase):
       ds_fn2: See `run_core_tests`.
       num_outputs: See `run_core_tests`.
       break_point: Break point. Optional. Defaults to num_outputs/2.
+      verify_exhausted: See `gen_outputs`.
 
     Raises:
       AssertionError if any test fails.
@@ -196,15 +218,15 @@ class DatasetSerializationTestBase(test.TestCase):
 
     # Skip `break_point` items and store the remaining produced from ds_fn1
     # in `expected`.
-    self.gen_outputs(ds_fn1, [], break_point)
+    self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False)
     expected = self.gen_outputs(
         ds_fn1, [],
         num_outputs - break_point,
         ckpt_saved=True,
-        verify_exhausted=True)
+        verify_exhausted=verify_exhausted)
 
     # Generate `break_point` items from ds_fn1 and save checkpoint.
-    self.gen_outputs(ds_fn1, [], break_point)
+    self.gen_outputs(ds_fn1, [], break_point, verify_exhausted=False)
 
     actual = []
     # Build graph for ds_fn2 but load checkpoint for ds_fn1.
@@ -214,8 +236,9 @@ class DatasetSerializationTestBase(test.TestCase):
         self._restore(saver, sess)
         for _ in range(num_outputs - break_point):
           actual.append(sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
+        if verify_exhausted:
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
 
     self.match(expected, actual)
 
@@ -223,6 +246,7 @@ class DatasetSerializationTestBase(test.TestCase):
                              ds_fn,
                              break_points,
                              num_outputs,
+                             verify_exhausted=True,
                              init_before_restore=False):
     """Verifies that ds_fn() produces the same outputs with and without breaks.
 
@@ -237,6 +261,7 @@ class DatasetSerializationTestBase(test.TestCase):
       ds_fn: See `gen_outputs`.
       break_points: See `gen_outputs`.
       num_outputs: See `gen_outputs`.
+      verify_exhausted: See `gen_outputs`.
       init_before_restore: See `gen_outputs`.
 
     Raises:
@@ -245,13 +270,13 @@ class DatasetSerializationTestBase(test.TestCase):
     expected = self.gen_outputs(
         ds_fn, [],
         num_outputs,
-        verify_exhausted=True,
+        verify_exhausted=verify_exhausted,
         init_before_restore=init_before_restore)
     actual = self.gen_outputs(
         ds_fn,
         break_points,
         num_outputs,
-        verify_exhausted=True,
+        verify_exhausted=verify_exhausted,
         init_before_restore=init_before_restore)
     self.match(expected, actual)
 
@@ -261,7 +286,7 @@ class DatasetSerializationTestBase(test.TestCase):
                   num_outputs,
                   ckpt_saved=False,
                   init_before_restore=False,
-                  verify_exhausted=False):
+                  verify_exhausted=True):
     """Generates elements from input dataset while stopping at break points.
 
     Produces `num_outputs` outputs and saves the state of the iterator in the
@@ -285,7 +310,7 @@ class DatasetSerializationTestBase(test.TestCase):
         after producing `num_outputs` elements.
 
     Returns:
-      A list if `num_outputs` items.
+      A list of `num_outputs` items.
     """
     outputs = []
 
diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
index 91615e9f620..1a26da82e53 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
 from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -207,5 +208,82 @@ class SequenceDatasetTest(test.TestCase):
         sess.run(get_next)
 
 
+class SequenceDatasetSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def _build_skip_dataset(self, count):
+    components = (np.arange(10),)
+    return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+
+  def testSkipFewerThanInputs(self):
+    count = 4
+    num_outputs = 10 - count
+    self.run_core_tests(lambda: self._build_skip_dataset(count),
+                        lambda: self._build_skip_dataset(count + 2),
+                        num_outputs)
+
+  def testSkipVarious(self):
+    # Skip more than inputs
+    self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
+    # Skip exactly the input size
+    self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
+    self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
+    # Skip nothing
+    self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
+
+  def _build_take_dataset(self, count):
+    components = (np.arange(10),)
+    return dataset_ops.Dataset.from_tensor_slices(components).take(count)
+
+  def testTakeFewerThanInputs(self):
+    count = 4
+    self.run_core_tests(
+        lambda: self._build_take_dataset(count),
+        lambda: self._build_take_dataset(count + 2),
+        count,
+    )
+
+  def testTakeVarious(self):
+    # Take more than inputs
+    self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
+    # Take exactly the input size
+    self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
+    # Take all
+    self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
+    # Take nothing
+    self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
+
+  def _build_repeat_dataset(self, count, take_count=3):
+    components = (np.arange(10),)
+    return dataset_ops.Dataset.from_tensor_slices(components).take(
+        take_count).repeat(count)
+
+  def testFiniteRepeat(self):
+    count = 10
+    self.run_core_tests(lambda: self._build_repeat_dataset(count),
+                        lambda: self._build_repeat_dataset(count + 2),
+                        3 * count)
+
+  def testEmptyRepeat(self):
+    self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
+
+  def testInfiniteRepeat(self):
+    self.verify_unused_iterator(
+        lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+    self.verify_init_before_restore(
+        lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+    self.verify_multiple_breaks(
+        lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+    self.verify_reset_restored_iterator(
+        lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+    self.verify_restore_in_modified_graph(
+        lambda: self._build_repeat_dataset(-1),
+        lambda: self._build_repeat_dataset(2),
+        20,
+        verify_exhausted=False)
+    # Test repeat empty dataset
+    self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 9813e99a70b..6c0f4118e6d 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -95,6 +95,15 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
         *end_of_sequence = true;
         return Status::OK();
       }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return Status::OK();
+      }
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        return Status::OK();
+      }
     };
 
     class FiniteIterator : public DatasetIterator<Dataset> {
@@ -183,6 +192,29 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
         } while (true);
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        if (input_impl_)
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        else
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("uninitialized"), ""));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        if (reader->Contains(full_name("uninitialized"))) {
+          input_impl_.reset();
+        } else {
+          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        }
+        return Status::OK();
+      }
+
      private:
       mutex mu_;
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc
index 52a6116a7cb..05152db1ae2 100644
--- a/tensorflow/core/kernels/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/skip_dataset_op.cc
@@ -35,14 +35,14 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
     int64 count;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
 
-    *output = new Dataset(count, input);
+    *output = new Dataset(ctx, count, input);
   }
 
  private:
-  class Dataset : public DatasetBase {
+  class Dataset : public GraphDatasetBase {
    public:
-    Dataset(int64 count, const DatasetBase* input)
-        : count_(count), input_(input) {
+    Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
+        : GraphDatasetBase(ctx), count_(count), input_(input) {
       input_->Ref();
     }
 
@@ -71,6 +71,18 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
 
     string DebugString() override { return "SkipDatasetOp::Dataset"; }
 
+   protected:
+    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
+      Node* count = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
+      TF_RETURN_IF_ERROR(
+          b->AddDataset(this, {input_graph_node, count}, output));
+      return Status::OK();
+    }
+
    private:
     class EmptyIterator : public DatasetIterator<Dataset> {
      public:
@@ -82,6 +94,16 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         *end_of_sequence = true;
         return Status::OK();
       }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        return Status::OK();
+      }
     };
 
     class FiniteIterator : public DatasetIterator<Dataset> {
@@ -119,6 +141,22 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         return Status::OK();
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
      private:
       mutex mu_;
       int64 i_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc
index c3f33d663cd..f9f675abdae 100644
--- a/tensorflow/core/kernels/take_dataset_op.cc
+++ b/tensorflow/core/kernels/take_dataset_op.cc
@@ -35,14 +35,14 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
     // Create a new TakeDatasetOp::Dataset, and return it as the output.
     int64 count;
     OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "count", &count));
-    *output = new Dataset(count, input);
+    *output = new Dataset(ctx, count, input);
   }
 
  private:
-  class Dataset : public DatasetBase {
+  class Dataset : public GraphDatasetBase {
    public:
-    Dataset(int64 count, const DatasetBase* input)
-        : count_(count), input_(input) {
+    Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
+        : GraphDatasetBase(ctx), count_(count), input_(input) {
       input_->Ref();
     }
 
@@ -72,6 +72,18 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
 
     string DebugString() override { return "TakeDatasetOp::Dataset"; }
 
+   protected:
+    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(input_, &input_graph_node));
+      Node* count = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(count_, &count));
+      TF_RETURN_IF_ERROR(
+          b->AddDataset(this, {input_graph_node, count}, output));
+      return Status::OK();
+    }
+
    private:
     class EmptyIterator : public DatasetIterator<Dataset> {
      public:
@@ -83,6 +95,16 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
         *end_of_sequence = true;
         return Status::OK();
       }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        return Status::OK();
+      }
     };
 
     class FiniteIterator : public DatasetIterator<Dataset> {
@@ -110,6 +132,22 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
         return Status::OK();
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        return Status::OK();
+      }
+
      private:
       mutex mu_;
       int64 i_ GUARDED_BY(mu_);

From 71e279c0c567c700fd02ba7b0a7481b1c1462227 Mon Sep 17 00:00:00 2001
From: Sergio Guadarrama <sguada@google.com>
Date: Tue, 7 Nov 2017 16:30:33 -0800
Subject: [PATCH 004/115] Allow passing other global_steps to summaries.

PiperOrigin-RevId: 174931874
---
 tensorflow/contrib/summary/summary_ops.py     | 43 +++++++++++--------
 .../contrib/summary/summary_ops_test.py       | 12 ++++++
 2 files changed, 38 insertions(+), 17 deletions(-)
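
For reference, a minimal sketch of the new keyword argument (mirrors the
added test; the logdir handling here is illustrative):

    import tempfile
    from tensorflow.contrib.summary import summary_ops
    from tensorflow.python.training import training_util

    step = training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_summary_file_writer(logdir).as_default(), \
        summary_ops.always_record_summaries():
      # Any step tensor can now be passed explicitly; previously these ops
      # always read training_util.get_global_step().
      summary_ops.scalar('loss', 0.5, global_step=step)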

diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 56e31985936..9238671c4a2 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -57,12 +57,14 @@ def should_record_summaries():
 
 # TODO(apassos) consider how to handle local step here.
 @tf_contextlib.contextmanager
-def record_summaries_every_n_global_steps(n):
+def record_summaries_every_n_global_steps(n, global_step=None):
   """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
+  if global_step is None:
+    global_step = training_util.get_global_step()
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
   old = collection_ref[:]
   with ops.device("cpu:0"):
-    collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)]
+    collection_ref[:] = [math_ops.equal(global_step % n, 0)]
   yield
   collection_ref[:] = old
 
@@ -204,68 +206,75 @@ def summary_writer_function(name, tensor, function, family=None):
   return op
 
 
-def generic(name, tensor, metadata, family=None):
+def generic(name, tensor, metadata, family=None, global_step=None):
   """Writes a tensor summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), array_ops.identity(tensor),
+        global_step, array_ops.identity(tensor),
         tag, metadata, name=scope)
   return summary_writer_function(name, tensor, function, family=family)
 
 
-def scalar(name, tensor, family=None):
+def scalar(name, tensor, family=None, global_step=None):
   """Writes a scalar summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_scalar_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), tag, array_ops.identity(tensor),
+        global_step, tag, array_ops.identity(tensor),
         name=scope)
 
   return summary_writer_function(name, tensor, function, family=family)
 
 
-def histogram(name, tensor, family=None):
+def histogram(name, tensor, family=None, global_step=None):
   """Writes a histogram summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_histogram_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), tag, array_ops.identity(tensor),
+        global_step, tag, array_ops.identity(tensor),
         name=scope)
 
   return summary_writer_function(name, tensor, function, family=family)
 
 
-def image(name, tensor, bad_color=None, max_images=3, family=None):
+def image(name, tensor, bad_color=None, max_images=3, family=None,
+          global_step=None):
   """Writes an image summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
                   if bad_color is None else bad_color)
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_image_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(), tag, array_ops.identity(tensor),
+        global_step, tag, array_ops.identity(tensor),
         bad_color_,
         max_images, name=scope)
 
   return summary_writer_function(name, tensor, function, family=family)
 
 
-def audio(name, tensor, sample_rate, max_outputs, family=None):
+def audio(name, tensor, sample_rate, max_outputs, family=None,
+          global_step=None):
   """Writes an audio summary if possible."""
-
+  if global_step is None:
+    global_step = training_util.get_global_step()
   def function(tag, scope):
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_audio_summary(
         context.context().summary_writer_resource,
-        training_util.get_global_step(),
+        global_step,
         tag,
         array_ops.identity(tensor),
         sample_rate=sample_rate,
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index de7ae6ec277..466e1940969 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -86,6 +86,18 @@ class TargetTest(test_util.TensorFlowTestCase):
       self.assertEqual(len(events), 2)
       self.assertEqual(events[1].summary.value[0].tag, 'scalar')
 
+  def testSummaryGlobalStep(self):
+    global_step = training_util.get_or_create_global_step()
+    logdir = tempfile.mkdtemp()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t2').as_default(), summary_ops.always_record_summaries():
+
+      summary_ops.scalar('scalar', 2.0, global_step=global_step)
+
+      events = summary_test_util.events_from_file(logdir)
+      self.assertEqual(len(events), 2)
+      self.assertEqual(events[1].summary.value[0].tag, 'scalar')
 
 if __name__ == '__main__':
   test.main()

From c8530b907a686b92c94d13f854dc504fa10901db Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Tue, 7 Nov 2017 16:50:18 -0800
Subject: [PATCH 005/115] tfe.Network naming under variable scopes. Networks
 take on the full prefix of their parent variable scopes.

Fixes #14164.

PiperOrigin-RevId: 174934769
---
 tensorflow/contrib/eager/python/network.py    |  63 ++++++----
 .../contrib/eager/python/network_test.py      | 108 ++++++++++++++++--
 tensorflow/python/layers/base.py              |  15 ++-
 3 files changed, 148 insertions(+), 38 deletions(-)
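
For reference, a sketch of the new behavior (MyNetwork is a minimal
stand-in modeled on the test fixture of the same name):

    import tensorflow as tf
    from tensorflow.contrib.eager.python import network
    from tensorflow.python.layers import core
    from tensorflow.python.ops import variable_scope

    class MyNetwork(network.Network):

      def __init__(self, name=None):
        super(MyNetwork, self).__init__(name=name)
        self.l1 = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.l1(x)

    with variable_scope.variable_scope("outside_scope"):
      net = MyNetwork()  # previously raised ValueError when first called
    net(tf.constant([[1.]]))
    # The Network's name now takes on the variable scope prefix:
    # net.name == "outside_scope/my_network_1"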

diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index c6e628b074e..1a5c6e8aec6 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -244,6 +244,12 @@ class Network(base.Layer):
     self._owned_layers = {}
     # The scope to use if we end up without a parent.
     self._default_parent_variable_scope = variable_scope.get_variable_scope()
+    # Hold on to the variable scope counts from init to check whether a scope
+    # with the name we want was ever created in our parent scope. Without this
+    # check we might have name collisions if the parent scope on init gets
+    # closed before build is called.
+    self._variable_scope_counts_on_init = (
+        variable_scope._get_default_variable_store().variable_scopes_count)
     self._custom_getter, self._deferred_restorations = (
         _make_custom_getter_for_deferred_restorations())
 
@@ -261,18 +267,29 @@ class Network(base.Layer):
 
   def _finalize_name(self, parent_network):
     if not self._name:
-      if not parent_network:
-        name_uid_map = base._get_default_graph_uid_map()
-      else:
-        name_uid_map = parent_network._sub_layer_name_uids
       # We were not passed a name explicitly (or it was blank), so this is an
       # anonymous Network. We make up a unique name.
       if parent_network:
         avoid_names = parent_network._owned_layers
+        name_uid_map = parent_network._sub_layer_name_uids
       else:
-        avoid_names = None
+        name_uid_map = base._get_default_graph_uid_map()
+        # Figure out which names we have to avoid based on which variable scope
+        # we're nested in.
+        strip_name = self._default_parent_variable_scope.name
+        if strip_name:
+          strip_name += "/"
+        def _strip_on_init_scope(name):
+          if name.startswith(strip_name):
+            return name[len(strip_name):]
+          else:
+            return None
+        avoid_names = set(
+            _strip_on_init_scope(name)
+            for name in self._variable_scope_counts_on_init.keys() if name)
       self._name, self._base_name = self._make_unique_name(
-          name_uid_map=name_uid_map, avoid_names=avoid_names)
+          name_uid_map=name_uid_map, avoid_names=avoid_names,
+          namespace=self._default_parent_variable_scope.name)
     if self._first_parent is None or (self._first_parent  # False = no parent
                                       and self._first_parent() is None):
       # Save a pointer to the parent Network so that we can later check that the
@@ -302,7 +319,13 @@ class Network(base.Layer):
         parent_scope = first_parent._scope
       else:
         parent_scope = self._default_parent_variable_scope
-      with variable_scope.variable_scope(parent_scope):
+      with variable_scope.variable_scope(parent_scope) as parent_vs:
+        expected_scope_name = parent_vs.name + "/" + self._name
+        if expected_scope_name in self._variable_scope_counts_on_init:
+          raise ValueError(
+              ("A Network named '%s' already exists (or a variable_scope was "
+               "created with this name). Names must be unique.") % (
+                   self._name,))
         # Make sure variables with this prefix will be unique.
         with variable_scope.variable_scope(
             None, use_resource=True, default_name=self._name) as scope:
@@ -319,25 +342,22 @@ class Network(base.Layer):
                  "created with this name). Names must be unique.") % (
                      self._name,))
           if (first_parent
-              and scope_prefix[:-1] != first_parent._scope.name):
+              and scope_prefix[:-1] != first_parent.scope_name):
             raise ValueError(
                 ("Network variable names must match a nesting of sub-Network "
                  "names. Expected prefix '%s' from parent network, but got "
                  "'%s' when attempting to create a variable_scope for Network "
                  "'%s'. Likely an explicit variable_scope was inserted into "
                  "the nesting.") % (
-                     first_parent._scope.name,
+                     first_parent.scope_name,
                      scope_prefix[:-1],
                      self._name))
           elif not first_parent and scope_prefix:
             # For the case when this Network is not nested inside any other
-            # Network, but is in a variable_scope. This is an error for now.
-            raise ValueError(
-                "Creating Networks inside named variable_scopes is currently "
-                "not supported (to ensure that variable names match the names "
-                "of Networks in which they were first created). To set "
-                "options, try `with tf.variable_scope(''):`. If this "
-                "limitation bothers you, please file a feature request.")
+            # Network, but is in a variable_scope. This Network's name takes on
+            # the full variable scope prefix.
+            self._name = scope_name
+
       for non_network_sublayer in self._non_network_sublayers:
         self._set_scope_for_nonnetwork_sublayer(non_network_sublayer)
 
@@ -355,8 +375,7 @@ class Network(base.Layer):
         raise ValueError(
             ("The parent of a Layer added to Network %s was garbage collected "
              "before the Layer was built. If this limitation bothers you "
-             "please, comment on "
-             "https://github.com/tensorflow/tensorflow/issues/14164.") %
+             "please file a feature request.") %
             (self.name,))
       with variable_scope.variable_scope(parent_scope):
         # Horrid hack to make Layer variable names which are direct
@@ -420,7 +439,9 @@ class Network(base.Layer):
             # name, and we should respect it (subject to error checking).
             layer._name, layer._base_name = layer._make_unique_name(
                 name_uid_map=self._sub_layer_name_uids,
-                avoid_names=self._owned_layers)
+                avoid_names=self._owned_layers
+                # No namespace required, since we've specified our own UID map.
+            )
           layer._first_parent = weakref.ref(self)
         self._non_network_sublayers.append(layer)
     if (not layer.built
@@ -556,7 +577,7 @@ class Network(base.Layer):
     if os.path.isdir(save_path):
       # If we were passed a directory, default to naming based on the Network
       # name.
-      save_path = os.path.join(save_path, self.name)
+      save_path = os.path.join(save_path, self.name.replace("/", "_"))
     user_map_func = map_func
     if map_func is None:
       map_func = _make_prefix_stripping_map_fn(self.scope_name)
@@ -750,7 +771,7 @@ class Network(base.Layer):
     self._set_scope()  # scope_name should be available to map_funcs
     if os.path.isdir(save_path):
       # If we don't have a name yet, set no parent.
-      save_path = os.path.join(save_path, self.name)
+      save_path = os.path.join(save_path, self.name.replace("/", "_"))
     user_map_func = map_func
     if map_func is None:
       map_func = _make_prefix_stripping_map_fn(self.scope_name)
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index 14adbafe573..1127055c050 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -410,19 +410,103 @@ class NetworkTest(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testWrappingInVariableScope(self):
+    one = constant_op.constant([[1.]])
+    # Naming happens in the order of first build rather than the order of
+    # construction, but for clarity they're the same here and construction is
+    # annotated.
+    outside_net_before = MyNetwork()  # name=my_network_1
+    outside_net_before(one)
+    captured_scope = variable_scope.get_variable_scope()
     with variable_scope.variable_scope("outside_scope"):
-      net = MyNetwork()
-      one = constant_op.constant([[1.]])
-      with self.assertRaisesRegexp(
-          ValueError,
-          ("Creating Networks inside named variable_scopes is currently not "
-           "supported")):
-        net(one)
-      # Alternatively, we could re-name the Network to match the variable_scope:
-      # self.assertEqual("outside_scope/my_network_1", net.name)
-      # self.assertStartsWith(
-      #     expected_start="outside_scope/my_network_1/dense/",
-      #     actual=net.trainable_weights[0].name)
+      net1 = MyNetwork()  # name=outside_scope/my_network_1
+      net1(one)
+      name_conflict1 = MyNetwork(name="name_conflict")  # fine, unique so far
+      name_conflict2 = MyNetwork(name="name_conflict")  # error on build
+      with variable_scope.variable_scope("inside_scope"):
+        # No issue here since the name is unique within its scope.
+        name_conflict3 = MyNetwork(name="name_conflict")
+      net2 = MyNetwork()  # name=outside_scope/my_network_3 to avoid the
+                          # variable_scope my_network_2 below.
+      vs_name_conflict = MyNetwork(name="vs_name_conflict")  # conflict below
+    with variable_scope.variable_scope("intervening_scope"):
+      with variable_scope.variable_scope(captured_scope):
+        with variable_scope.variable_scope("outside_scope"):
+          name_conflict4 = MyNetwork(name="name_conflict")  # error on build
+          with variable_scope.variable_scope("my_network_2"):
+            pass
+          with variable_scope.variable_scope("vs_name_conflict"):
+            pass
+          net3 = MyNetwork()  # name=outside_scope/my_network_4
+    name_conflict1(one)
+    with self.assertRaisesRegexp(
+        ValueError, "named 'name_conflict' already exists"):
+      name_conflict2(one)
+    name_conflict3(one)
+    net2(one)
+    with self.assertRaisesRegexp(
+        ValueError, "or a variable_scope was created with this name"):
+      vs_name_conflict(one)
+    with self.assertRaisesRegexp(
+        ValueError, "named 'name_conflict' already exists"):
+      name_conflict4(one)
+    self.assertEqual("outside_scope/name_conflict",
+                     name_conflict1.name)
+    self.assertStartsWith(
+        expected_start="outside_scope/name_conflict/dense_1/",
+        actual=name_conflict1.variables[0].name)
+    self.assertEqual("outside_scope/inside_scope/name_conflict",
+                     name_conflict3.name)
+    self.assertStartsWith(
+        expected_start="outside_scope/inside_scope/name_conflict/dense_1/",
+        actual=name_conflict3.variables[0].name)
+    self.assertEqual("outside_scope/my_network_1", net1.name)
+    self.assertStartsWith(
+        expected_start="outside_scope/my_network_1/dense_1/",
+        actual=net1.trainable_weights[0].name)
+    self.assertEqual("outside_scope/my_network_3", net2.name)
+    self.assertStartsWith(
+        expected_start="outside_scope/my_network_3/dense_1/",
+        actual=net2.trainable_weights[0].name)
+    net3(one)
+    self.assertEqual("outside_scope/my_network_4", net3.name)
+    self.assertStartsWith(
+        expected_start="outside_scope/my_network_4/dense_1/",
+        actual=net3.trainable_weights[0].name)
+    outside_net_after = MyNetwork()
+    outside_net_after(one)
+    self.assertEqual("my_network_1", outside_net_before.name)
+    self.assertStartsWith(
+        expected_start="my_network_1/dense_1/",
+        actual=outside_net_before.trainable_weights[0].name)
+    self.assertEqual("my_network_2", outside_net_after.name)
+    self.assertStartsWith(
+        expected_start="my_network_2/dense_1/",
+        actual=outside_net_after.trainable_weights[0].name)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testVariableScopeStripping(self):
+    with variable_scope.variable_scope("scope1"):
+      with variable_scope.variable_scope("scope2"):
+        net = MyNetwork()
+    net(constant_op.constant([[2.0]]))
+    self.evaluate(net.variables[0].assign([[42.]]))
+    self.assertEqual(net.name, "scope1/scope2/my_network_1")
+    self.assertStartsWith(
+        expected_start="scope1/scope2/my_network_1/dense_1/",
+        actual=net.trainable_weights[0].name)
+    save_path = net.save(self.get_temp_dir())
+    self.assertIn("scope1_scope2_my_network_1", save_path)
+    restore_net = MyNetwork()
+    # Delayed restoration
+    restore_net.restore(save_path)
+    restore_net(constant_op.constant([[1.0]]))
+    self.assertAllEqual([[42.]],
+                        self.evaluate(restore_net.variables[0]))
+    self.evaluate(restore_net.variables[0].assign([[-1.]]))
+    # Immediate restoration
+    restore_net.restore(save_path)
+    self.assertAllEqual([[42.]],
+                        self.evaluate(restore_net.variables[0]))
 
   @test_util.run_in_graph_and_eager_modes()
   def testLayerNamesRespected(self):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 07b9d9b7a62..8c8d774b754 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -401,10 +401,11 @@ class Layer(object):
     """
     return input_shape
 
-  def _make_unique_name(self, name_uid_map=None, avoid_names=None):
+  def _make_unique_name(self, name_uid_map=None, avoid_names=None,
+                        namespace=''):
     base_name = _to_snake_case(self.__class__.__name__)
     name = _unique_layer_name(base_name, name_uid_map=name_uid_map,
-                              avoid_names=avoid_names)
+                              avoid_names=avoid_names, namespace=namespace)
     return (name, base_name)
 
   def _set_scope(self, scope=None):
@@ -2370,7 +2371,7 @@ def _get_default_graph_uid_map():
   return name_uid_map
 
 
-def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
+def _unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace=''):
   """Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
 
   Arguments:
@@ -2379,6 +2380,9 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
       names. If None (default), uses a per-Graph dictionary.
     avoid_names: An optional set or dict with names which should not be used. If
       None (default) does not avoid any names.
+    namespace: Gets a name which is unique within the (graph, namespace). Layers
+      which are not Networks use a blank namespace and so get graph-global
+      names.
 
   Returns:
     Unique string name.
@@ -2396,6 +2400,7 @@ def _unique_layer_name(name, name_uid_map=None, avoid_names=None):
     avoid_names = set()
   proposed_name = None
   while proposed_name is None or proposed_name in avoid_names:
-    name_uid_map[name] += 1
-    proposed_name = name + '_' + str(name_uid_map[name])
+    name_key = (namespace, name)
+    name_uid_map[name_key] += 1
+    proposed_name = name + '_' + str(name_uid_map[name_key])
   return proposed_name

From 788344009ed1a9e550e980415be1d271bccb8bef Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar <suharshs@google.com>
Date: Tue, 7 Nov 2017 16:52:51 -0800
Subject: [PATCH 006/115] Fix FakeQuant to correctly set zero on CPU.

PiperOrigin-RevId: 174935134
---
 tensorflow/core/kernels/fake_quant_ops_functor.h | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h
index b41b22d634d..7aaad6e6c7a 100644
--- a/tensorflow/core/kernels/fake_quant_ops_functor.h
+++ b/tensorflow/core/kernels/fake_quant_ops_functor.h
@@ -132,7 +132,7 @@ struct FakeQuantWithMinMaxVarsFunctor {
     const float max_val = max();
     // If min and max are both zero, we should just return zero.
     if (min_val == 0.0f && max_val == 0.0f) {
-      outputs.setZero();
+      outputs.device(d) = outputs.constant(0.0f);
       return;
     }
     float nudged_min, nudged_max, nudged_scale;
@@ -163,8 +163,8 @@ struct FakeQuantWithMinMaxVarsGradientFunctor {
     // If min and max are both zero, we propagate everything to inputs.
     if (min_val == 0.0f && max_val == 0.0f) {
       backprops_wrt_input.device(d) = gradients;
-      backprop_wrt_min.setZero();
-      backprop_wrt_max.setZero();
+      backprop_wrt_min.device(d) = backprop_wrt_min.constant(0.0f);
+      backprop_wrt_max.device(d) = backprop_wrt_max.constant(0.0f);
       return;
     }
     float nudged_min, nudged_max, nudged_scale;
@@ -205,7 +205,8 @@ struct FakeQuantWithMinMaxVarsPerChannelFunctor {
       const float max_val = max(i);
       // If min and max are both zero, we should just return zero.
       if (min_val == 0.0f && max_val == 0.0f) {
-        outputs.chip<1>(i).setZero();
+        auto chip = outputs.chip<1>(i);
+        chip.device(d) = chip.constant(0.0f);
         continue;
       }
       float nudged_min, nudged_max, nudged_scale;
@@ -242,8 +243,10 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor {
       // If min and max are both zero, we propagate everything to inputs.
       if (min_val == 0.0f && max_val == 0.0f) {
         backprops_wrt_input.chip<1>(i).device(d) = gradients_chip;
-        backprop_wrt_min.chip<0>(i).setZero();
-        backprop_wrt_max.chip<0>(i).setZero();
+        auto min_chip = backprop_wrt_min.chip<0>(i);
+        auto max_chip = backprop_wrt_max.chip<0>(i);
+        min_chip.device(d) = min_chip.constant(0.0f);
+        max_chip.device(d) = max_chip.constant(0.0f);
         continue;
       }
       float nudged_min, nudged_max, nudged_scale;

From aefb02c008f9870c6de6bb10c05725d89427dcb9 Mon Sep 17 00:00:00 2001
From: Chris Leary <leary@google.com>
Date: Tue, 7 Nov 2017 17:08:18 -0800
Subject: [PATCH 007/115] [XLA] Add binary operation name to shape inference
 error message.

PiperOrigin-RevId: 174937290
---
 tensorflow/compiler/xla/service/shape_inference.cc | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 791d17365b1..9c7dc2185e3 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/lib/core/stringpiece.h"
 #include "tensorflow/core/lib/math/math_util.h"
 #include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -770,8 +771,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
   TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
 
-  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation"));
-  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation"));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+      lhs, tensorflow::strings::StrCat("lhs of binary operation ",
+                                       BinaryOperation_Name(operation))));
+  TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+      rhs, tensorflow::strings::StrCat("rhs of binary operation ",
+                                       BinaryOperation_Name(operation))));
   switch (operation) {
     case BINOP_DOT:
       return InferDotOpShape(lhs, rhs);

From 5278fa03a9e703d1e414ccebd858f7fdf22dbba5 Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar <suharshs@google.com>
Date: Tue, 7 Nov 2017 17:12:18 -0800
Subject: [PATCH 008/115] Make quant_delay work even if user didn't create
 global step.

PiperOrigin-RevId: 174937793
---
 tensorflow/contrib/quantize/python/quantize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
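
For reference, the pattern the fix relies on (a sketch; no global step has
been created in the graph beforehand):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import math_ops
    from tensorflow.python.training import training_util

    with ops.Graph().as_default():
      # get_global_step() would return None here, breaking the comparison;
      # get_or_create_global_step() creates the step tensor on demand.
      step = training_util.get_or_create_global_step()
      activate_quant = math_ops.greater_equal(step, 1000)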

diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 548e33663e8..6382d3f7b41 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -387,7 +387,7 @@ class _QuantizeContext(object):
 
     if delay_requested and self.quant_delay and self.quant_delay > 0:
       activate_quant = math_ops.greater_equal(
-          training_util.get_global_step(),
+          training_util.get_or_create_global_step(),
           self.quant_delay,
           name=scope + '/activate_quant')
       quant = control_flow_ops.cond(

From b4668cc0702a78a2195116c332ec63b743af274b Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 17:12:57 -0800
Subject: [PATCH 009/115] Used tf.where to simplify conditional expression in
 div-sharding.

PiperOrigin-RevId: 174937860
---
 tensorflow/python/ops/embedding_ops.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)
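
For reference, a small sketch of the equivalence (values are illustrative,
not taken from the change itself):

    import tensorflow as tf

    p_assignments = tf.constant([0, 1, 2], dtype=tf.int64)
    flat_ids = tf.constant([0, 5, 9], dtype=tf.int64)
    extras, ids_per_partition = 1, 3
    # Old style: emulate the conditional with a cast-and-multiply indicator.
    ind = tf.cast(p_assignments < extras, flat_ids.dtype)
    old = (ind * (flat_ids % (ids_per_partition + 1)) +
           (1 - ind) * ((flat_ids - extras) % ids_per_partition))
    # New style: tf.where selects elementwise and reads as a conditional.
    new = tf.where(p_assignments < extras,
                   flat_ids % (ids_per_partition + 1),
                   (flat_ids - extras) % ids_per_partition)
    # old and new evaluate to the same partition-local ids.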

diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 8c1ccc68404..f4561d1a830 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -191,12 +191,9 @@ def _embedding_lookup_and_transform(params,
             (flat_ids - extras) // ids_per_partition)
 
         # Emulate a conditional using a boolean indicator tensor
-        is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
-                                                      flat_ids.dtype)
-        new_ids = (is_in_first_extras_partitions * (flat_ids %
-                                                    (ids_per_partition + 1)) +
-                   (1 - is_in_first_extras_partitions) *
-                   ((flat_ids - extras) % ids_per_partition))
+        new_ids = array_ops.where(p_assignments < extras,
+                                  flat_ids % (ids_per_partition + 1),
+                                  (flat_ids - extras) % ids_per_partition)
       else:
         raise ValueError("Unrecognized partition strategy: " +
                          partition_strategy)

From a6de80a90d10797279c950559eed5c101cee6030 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 17:16:55 -0800
Subject: [PATCH 010/115] Removed an unused temporary variable from
 DeviceNameUtils::ParseFullName.

PiperOrigin-RevId: 174938299
---
 tensorflow/core/util/device_name_utils.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index 2d797c855a5..90c3fed2e82 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -116,7 +116,6 @@ bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
   if (fullname == "/") {
     return true;
   }
-  StringPiece tmp;
   while (!fullname.empty()) {
     bool progress = false;
     if (str_util::ConsumePrefix(&fullname, "/job:")) {

From e32e74d0c350be042647b0cbef9e7a619832e7d5 Mon Sep 17 00:00:00 2001
From: Kay Zhu <kayzhu@google.com>
Date: Tue, 7 Nov 2017 17:22:20 -0800
Subject: [PATCH 011/115] [XLA] Fix comments for arg_literals parameter in
 HloEvaluator::Evaluate.

PiperOrigin-RevId: 174939009
---
 tensorflow/compiler/xla/service/hlo_evaluator.h | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 67b6e215fcb..7557aaa2484 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -39,16 +39,18 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
   HloEvaluator();
   // Evaluates an HLO module and an array of pointers to literals.
   // Returns the evaluated result as a literal if successful.
-  // Precondition: argument literals correspond to each input computation's
-  // parameters in their post-ordering. See comment below for example.
+  // Precondition: The indices of arg_literals correspond to the parameter
+  // numbers of the HLO parameters in the computation. See comment below for an
+  // example.
   StatusOr<std::unique_ptr<Literal>> Evaluate(
       const HloModule& module,
       tensorflow::gtl::ArraySlice<const Literal*> arg_literals);
 
   // Evaluates an HLO computation and an array of pointers to literals.
   // Returns the evaluated result as a literal if successful.
-  // Precondition: argument literals correspond to the input computation's
-  // parameters in their post-ordering. For e.g., consider the following graph:
+  // Precondition: The indices of arg_literals correspond to the parameter
+  // numbers of the HLO parameters in the computation. For example, consider
+  // following graph:
   //
   //                *
   //            /       \
@@ -57,8 +59,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
   //       /        \
   //    Parameter0  Constant
   //
-  // The input literals array will have its first literal map to Parameter0 and
-  // the second map to Parameter1.
+  // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number
+  // 1 in this computation. The input literals array will then have its first
+  // literal map to Parameter0 and the second map to Parameter1.
   StatusOr<std::unique_ptr<Literal>> Evaluate(
       const HloComputation& computation,
       tensorflow::gtl::ArraySlice<const Literal*> arg_literals);

From 5fd3810acf0e22130491f300cb75cf450bc9d290 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 17:47:46 -0800
Subject: [PATCH 012/115] [tf.data] Saveable iterator for dataset.zip(..).

PiperOrigin-RevId: 174941651
---
 .../kernel_tests/zip_dataset_op_test.py       | 27 +++++++++++++
 tensorflow/core/kernels/zip_dataset_op.cc     | 40 +++++++++++++++++--
 2 files changed, 63 insertions(+), 4 deletions(-)

diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
index b0e72183019..5d34b0024c4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
 from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -110,5 +111,31 @@ class ZipDatasetTest(test.TestCase):
         sess.run(get_next)
 
 
+class ZipDatasetSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def _build_dataset(self, arr):
+    components = [
+        np.tile(np.array([[1], [2], [3], [4]]), 20),
+        np.tile(np.array([[12], [13], [14], [15]]), 22),
+        np.array(arr)
+    ]
+    datasets = [
+        dataset_ops.Dataset.from_tensor_slices(component)
+        for component in components
+    ]
+    return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+
+  def testCore(self):
+    # Equal length components
+    arr = [37.0, 38.0, 39.0, 40.0]
+    num_outputs = len(arr)
+    self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
+    # Variable length components
+    diff_size_arr = [1.0, 2.0]
+    self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
+                        lambda: self._build_dataset(arr), 2)
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index a80b9edbe46..30d64ea6343 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -35,14 +35,15 @@ class ZipDatasetOp : public DatasetOpKernel {
       OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
       inputs.push_back(input);
     }
-    *output = new Dataset(inputs);
+    *output = new Dataset(ctx, inputs);
   }
 
  private:
-  class Dataset : public DatasetBase {
+  class Dataset : public GraphDatasetBase {
    public:
-    explicit Dataset(const std::vector<DatasetBase*>& inputs)
-        : inputs_(inputs) {
+    explicit Dataset(OpKernelContext* ctx,
+                     const std::vector<DatasetBase*>& inputs)
+        : GraphDatasetBase(ctx), inputs_(inputs) {
       for (const auto& input : inputs_) {
         input->Ref();
         for (DataType dt : input->output_dtypes()) {
@@ -76,6 +77,21 @@ class ZipDatasetOp : public DatasetOpKernel {
 
     string DebugString() override { return "ZipDatasetOp::Dataset"; }
 
+   protected:
+    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      std::vector<NodeBuilder::NodeOut> input_graph_nodes;
+      input_graph_nodes.reserve(inputs_.size());
+      for (const auto& input : inputs_) {
+        Node* input_node;
+        TF_RETURN_IF_ERROR(b->AddParentDataset(input, &input_node));
+        input_graph_nodes.emplace_back(input_node);
+      }
+      TF_RETURN_IF_ERROR(
+          b->AddDatasetWithInputAsList(this, input_graph_nodes, output));
+      return Status::OK();
+    }
+
    private:
     class Iterator : public DatasetIterator<Dataset> {
      public:
@@ -109,6 +125,22 @@ class ZipDatasetOp : public DatasetOpKernel {
         return Status::OK();
       }
 
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        for (auto& input_impl : input_impls_)
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
+        return Status::OK();
+      }
+
+      Status RestoreInternal(OpKernelContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        for (auto& input_impl : input_impls_)
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
+        return Status::OK();
+      }
+
      private:
       mutex mu_;
       std::vector<std::unique_ptr<IteratorBase>> input_impls_ GUARDED_BY(mu_);
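
The new `SaveInternal`/`RestoreInternal` hooks simply checkpoint each input iterator in turn, which works because zip preserves the structure of its inputs. A minimal sketch of the nested structure the new test exercises (assuming the TF 1.x `tf.data` API; illustrative, not part of the patch):

```python
# Illustrative sketch of Dataset.zip with nested structure; TF 1.x API.
import numpy as np
import tensorflow as tf

components = [
    np.tile(np.array([[1], [2], [3], [4]]), 20),
    np.tile(np.array([[12], [13], [14], [15]]), 22),
    np.array([37.0, 38.0, 39.0, 40.0]),
]
d0, d1, d2 = [tf.data.Dataset.from_tensor_slices(c) for c in components]

# Zip accepts arbitrary nesting; elements come out with the same structure.
zipped = tf.data.Dataset.zip((d0, (d1, d2)))
first, (second, third) = zipped.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    a, b, c = sess.run([first, second, third])  # one slice per component
```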

From 980b74475f3674bd729d35dbc9b2de9f39a8dd6c Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Tue, 7 Nov 2017 18:16:47 -0800
Subject: [PATCH 013/115] Register int64 for GPU StridedSlice kernel

PiperOrigin-RevId: 174944857
---
 tensorflow/core/kernels/strided_slice_op.cc      |  1 +
 tensorflow/core/kernels/strided_slice_op_impl.h  |  2 ++
 tensorflow/python/kernel_tests/array_ops_test.py | 13 ++++++++++++-
 3 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 8fc40db3cc2..73b6d4cf6a2 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -427,6 +427,7 @@ REGISTER_STRIDED_SLICE(bfloat16);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
 
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index de651475724..afe3a051e64 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -284,6 +284,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_GPU);
 TF_CALL_complex64(DECLARE_FOR_N_GPU);
 TF_CALL_complex128(DECLARE_FOR_N_GPU);
 DECLARE_FOR_N_GPU(int32);
+DECLARE_FOR_N_GPU(int64);
 #endif  // END GOOGLE_CUDA
 
 TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU);
@@ -299,6 +300,7 @@ DECLARE_FOR_N_CPU(bfloat16);
 TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL);
 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_FOR_N_SYCL);
 DECLARE_FOR_N_SYCL(int32);
+DECLARE_FOR_N_SYCL(int64);
 
 #undef DECLARE_FOR_N_SYCL
 #endif // TENSORFLOW_USE_SYCL
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 8f4c94f318b..6eb9c66d068 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -486,7 +486,7 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
         _ = checker2[...]
         _ = checker2[tuple()]
 
-  def testInt64GPU(self):
+  def testFloatSlicedArrayAndInt64IndicesGPU(self):
     if not test_util.is_gpu_available():
       self.skipTest("No GPU available")
     with self.test_session(use_gpu=True, force_gpu=True):
@@ -497,6 +497,17 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
       s = array_ops.strided_slice(x, begin, end, strides)
       self.assertAllEqual([3.], self.evaluate(s))
 
+  def testInt64SlicedArrayAndIndicesGPU(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+    with self.test_session(use_gpu=True, force_gpu=True):
+      x = constant_op.constant([1, 2, 3], dtype=dtypes.int64)
+      begin = constant_op.constant([2], dtype=dtypes.int64)
+      end = constant_op.constant([3], dtype=dtypes.int64)
+      strides = constant_op.constant([1], dtype=dtypes.int64)
+      s = array_ops.strided_slice(x, begin, end, strides)
+      self.assertAllEqual([3], self.evaluate(s))
+
   def testDegenerateSlices(self):
     with self.test_session(use_gpu=True):
       checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
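
With the kernel registered, the same slice can be expressed through the public API. A minimal sketch mirroring the new test (assuming TF 1.x; whether it actually lands on a GPU depends on the build and available devices):

```python
# Illustrative sketch; mirrors testInt64SlicedArrayAndIndicesGPU above.
import tensorflow as tf

x = tf.constant([1, 2, 3], dtype=tf.int64)
begin = tf.constant([2], dtype=tf.int64)
end = tf.constant([3], dtype=tf.int64)
strides = tf.constant([1], dtype=tf.int64)
s = tf.strided_slice(x, begin, end, strides)  # int64 now has a GPU kernel

with tf.Session() as sess:
    print(sess.run(s))  # => [3]
```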

From db85753667aa6eb52c2eefc0b9c5446c6b1a6cd7 Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar <suharshs@google.com>
Date: Tue, 7 Nov 2017 18:38:12 -0800
Subject: [PATCH 014/115] Add functionality to train an additional fixed-point
 layer on top of a quantized base model. Also modify retrain_test to test
 creation of model info for fixed-point Mobilenet.

PiperOrigin-RevId: 174946745
---
 .../examples/image_retraining/retrain.py      | 82 +++++++++++++++----
 .../examples/image_retraining/retrain_test.py | 23 +++++-
 2 files changed, 85 insertions(+), 20 deletions(-)

diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 3549891461e..ebddfb20f4b 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -69,11 +69,18 @@ to validate that you have gathered good training data, but if you want to deploy
 on resource-limited platforms, you can try the `--architecture` flag with a
 Mobilenet model. For example:
 
+Run the floating-point version of Mobilenet:
 ```bash
 python tensorflow/examples/image_retraining/retrain.py \
     --image_dir ~/flower_photos --architecture mobilenet_1.0_224
 ```
 
+Run the quantized version of Mobilenet:
+```bash
+python tensorflow/examples/image_retraining/retrain.py \
+    --image_dir ~/flower_photos --architecture mobilenet_1.0_224_quantized
+```
+
 There are 32 different Mobilenet models to choose from, with a variety of file
 size and latency options. The first number can be '1.0', '0.75', '0.50', or
 '0.25' to control the size, and the second controls the input image size, either
@@ -107,6 +114,7 @@ import numpy as np
 from six.moves import urllib
 import tensorflow as tf
 
+from tensorflow.contrib.quantize.python import quant_ops
 from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.platform import gfile
@@ -271,6 +279,7 @@ def create_model_graph(model_info):
   """
   with tf.Graph().as_default() as graph:
     model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
+    print('Model path: ', model_path)
     with gfile.FastGFile(model_path, 'rb') as f:
       graph_def = tf.GraphDef()
       graph_def.ParseFromString(f.read())
@@ -337,7 +346,10 @@ def maybe_download_and_extract(data_url):
     statinfo = os.stat(filepath)
     tf.logging.info('Successfully downloaded', filename, statinfo.st_size,
                     'bytes.')
-  tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+    print('Extracting file from ', filepath)
+    tarfile.open(filepath, 'r:gz').extractall(dest_directory)
+  else:
+    print('Not extracting or downloading files, model already present on disk')
 
 
 def ensure_dir_exists(dir_name):
@@ -733,7 +745,7 @@ def variable_summaries(var):
 
 
 def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
-                           bottleneck_tensor_size):
+                           bottleneck_tensor_size, quantize_layer):
   """Adds a new softmax and fully-connected layer for training.
 
   We need to retrain the top layer to identify our new classes, so this function
@@ -745,10 +757,12 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
 
   Args:
     class_count: Integer of how many categories of things we're trying to
-    recognize.
+        recognize.
     final_tensor_name: Name string for the new final node that produces results.
     bottleneck_tensor: The output of the main CNN graph.
     bottleneck_tensor_size: How many entries in the bottleneck vector.
+    quantize_layer: Boolean, specifying whether the newly added layer should be
+        quantized.
 
   Returns:
     The tensors for the training and cross entropy results, and tensors for the
@@ -771,18 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
     with tf.name_scope('weights'):
       initial_value = tf.truncated_normal(
           [bottleneck_tensor_size, class_count], stddev=0.001)
-
       layer_weights = tf.Variable(initial_value, name='final_weights')
+      if quantize_layer:
+        quantized_layer_weights = quant_ops.MovingAvgQuantize(
+            layer_weights, is_training=True)
+        variable_summaries(quantized_layer_weights)
 
       variable_summaries(layer_weights)
     with tf.name_scope('biases'):
       layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases')
+      if quantize_layer:
+        quantized_layer_biases = quant_ops.MovingAvgQuantize(
+            layer_biases, is_training=True)
+        variable_summaries(quantized_layer_biases)
+
       variable_summaries(layer_biases)
+
     with tf.name_scope('Wx_plus_b'):
-      logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
-      tf.summary.histogram('pre_activations', logits)
+      if quantize_layer:
+        logits = tf.matmul(bottleneck_input,
+                           quantized_layer_weights) + quantized_layer_biases
+        logits = quant_ops.MovingAvgQuantize(
+            logits,
+            init_min=-32.0,
+            init_max=32.0,
+            is_training=True,
+            num_bits=8,
+            narrow_range=False,
+            ema_decay=0.5)
+        tf.summary.histogram('pre_activations', logits)
+      else:
+        logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases
+        tf.summary.histogram('pre_activations', logits)
 
   final_tensor = tf.nn.softmax(logits, name=final_tensor_name)
+
   tf.summary.histogram('activations', final_tensor)
 
   with tf.name_scope('cross_entropy'):
@@ -790,6 +827,7 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor,
         labels=ground_truth_input, logits=logits)
     with tf.name_scope('total'):
       cross_entropy_mean = tf.reduce_mean(cross_entropy)
+
   tf.summary.scalar('cross_entropy', cross_entropy_mean)
 
   with tf.name_scope('train'):
@@ -825,6 +863,7 @@ def add_evaluation_step(result_tensor, ground_truth_tensor):
 def save_graph_to_file(sess, graph, graph_file_name):
   output_graph_def = graph_util.convert_variables_to_constants(
       sess, graph.as_graph_def(), [FLAGS.final_tensor_name])
+
   with gfile.FastGFile(graph_file_name, 'wb') as f:
     f.write(output_graph_def.SerializeToString())
   return
@@ -858,6 +897,7 @@ def create_model_info(architecture):
     ValueError: If architecture name is unknown.
   """
   architecture = architecture.lower()
+  is_quantized = False
   if architecture == 'inception_v3':
     # pylint: disable=line-too-long
     data_url = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
@@ -902,19 +942,28 @@ def create_model_info(architecture):
             architecture)
         return None
       is_quantized = True
-    data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
-    data_url += version_string + '_' + size_string + '_frozen.tgz'
-    bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+
+    if is_quantized:
+      data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+      data_url += version_string + '_' + size_string + '_quantized_frozen.tgz'
+      bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+      resized_input_tensor_name = 'Placeholder:0'
+      model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string +
+                        '_quantized_frozen')
+      model_base_name = 'quantized_frozen_graph.pb'
+
+    else:
+      data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
+      data_url += version_string + '_' + size_string + '_frozen.tgz'
+      bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
+      resized_input_tensor_name = 'input:0'
+      model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
+      model_base_name = 'frozen_graph.pb'
+
     bottleneck_tensor_size = 1001
     input_width = int(size_string)
     input_height = int(size_string)
     input_depth = 3
-    resized_input_tensor_name = 'input:0'
-    if is_quantized:
-      model_base_name = 'quantized_graph.pb'
-    else:
-      model_base_name = 'frozen_graph.pb'
-    model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
     model_file_name = os.path.join(model_dir_name, model_base_name)
     input_mean = 127.5
     input_std = 127.5
@@ -933,6 +982,7 @@ def create_model_info(architecture):
       'model_file_name': model_file_name,
       'input_mean': input_mean,
       'input_std': input_std,
+      'quantize_layer': is_quantized,
   }
 
 
@@ -1028,7 +1078,7 @@ def main(_):
     (train_step, cross_entropy, bottleneck_input, ground_truth_input,
      final_tensor) = add_final_training_ops(
          len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor,
-         model_info['bottleneck_tensor_size'])
+         model_info['bottleneck_tensor_size'], model_info['quantize_layer'])
 
     # Create the operations we need to evaluate the accuracy of our new layer.
     evaluation_step, prediction = add_evaluation_step(
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index c342a17dd86..2de4c4ec99f 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -70,10 +70,18 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
   def testAddFinalTrainingOps(self, flags_mock):
     with tf.Graph().as_default():
       with tf.Session() as sess:
-        bottleneck = tf.placeholder(
-            tf.float32, [1, 1024],
-            name='bottleneck')
-        retrain.add_final_training_ops(5, 'final', bottleneck, 1024)
+        bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+        # Test creating final training op without quantization
+        retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False)
+        self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
+
+  @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01)
+  def testAddFinalTrainingOpsQuantized(self, flags_mock):
+    with tf.Graph().as_default():
+      with tf.Session() as sess:
+        bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
+        # Test creating final training op with quantization
+        retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True)
         self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0'))
 
   def testAddEvaluationStep(self):
@@ -99,5 +107,12 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
     self.assertIsNotNone(model_info)
     self.assertEqual(299, model_info['input_width'])
 
+  def testCreateModelInfoQuantized(self):
+    # Test for quantized mobilenet
+    model_info = retrain.create_model_info('mobilenet_1.0_224_quantized')
+    self.assertIsNotNone(model_info)
+    self.assertEqual(224, model_info['input_width'])
+
+
 if __name__ == '__main__':
   tf.test.main()
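
The heart of the retrain change, condensed into a self-contained sketch (TF 1.x plus the contrib `quant_ops` module this patch series touches; names and sizes here are illustrative): the new layer's weights and biases pass through moving-average fake quantization before the matmul, and the logits are fake-quantized to a fixed [-32, 32] range.

```python
# Illustrative sketch of the quantized final layer; not the patch itself.
import tensorflow as tf
from tensorflow.contrib.quantize.python import quant_ops

bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck')
weights = tf.Variable(tf.truncated_normal([1024, 5], stddev=0.001))
biases = tf.Variable(tf.zeros([5]))

# Fake-quantize the parameters with moving-average min/max tracking.
q_weights = quant_ops.MovingAvgQuantize(weights, is_training=True)
q_biases = quant_ops.MovingAvgQuantize(biases, is_training=True)

logits = tf.matmul(bottleneck, q_weights) + q_biases
logits = quant_ops.MovingAvgQuantize(
    logits, init_min=-32.0, init_max=32.0, is_training=True,
    num_bits=8, narrow_range=False, ema_decay=0.5)
final = tf.nn.softmax(logits, name='final')
```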

From 5d35b03064268e05626d9a65348c1359e83ddcc2 Mon Sep 17 00:00:00 2001
From: Austin Anderson <angerson@google.com>
Date: Tue, 7 Nov 2017 18:39:17 -0800
Subject: [PATCH 015/115] Fix Bazel builds for the TF Lite demo app

Adds a new remote repository for the mobilenet tflite models necessary
for running the TF Lite demo app.

PiperOrigin-RevId: 174946867
---
 tensorflow/workspace.bzl           | 37 ++++++++++++++++--------------
 third_party/tflite_mobilenet.BUILD | 13 +++++++++++
 2 files changed, 33 insertions(+), 17 deletions(-)
 create mode 100644 third_party/tflite_mobilenet.BUILD

diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index afcae6eade1..3081a8d1dcd 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,21 +1,24 @@
 # TensorFlow external dependencies that can be loaded in WORKSPACE files.
 
 load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+
 load("//third_party/sycl:sycl_configure.bzl", "sycl_configure")
 load("//third_party/mkl:build_defs.bzl", "mkl_repository")
-load("@io_bazel_rules_closure//closure/private:java_import_external.bzl",
-     "java_import_external")
+load(
+    "@io_bazel_rules_closure//closure/private:java_import_external.bzl",
+    "java_import_external",
+)
 load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external")
 load("//third_party/py:python_configure.bzl", "python_configure")
-load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",
-     "arm_compiler_configure")
-
+load(
+    "//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl",
+    "arm_compiler_configure",
+)
 
 def _is_windows(repository_ctx):
   """Returns true if the host operating system is windows."""
   return repository_ctx.os.name.lower().find("windows") != -1
 
-
 def _get_env_var(repository_ctx, name):
   """Find an environment variable."""
   if name in repository_ctx.os.environ:
@@ -23,7 +26,6 @@ def _get_env_var(repository_ctx, name):
   else:
     return None
 
-
 # Parse the bazel version string from `native.bazel_version`.
 def _parse_bazel_version(bazel_version):
   # Remove commit from version.
@@ -39,7 +41,6 @@ def _parse_bazel_version(bazel_version):
     version_tuple += (str(number),)
   return version_tuple
 
-
 # Check that a specific bazel version is being used.
 def check_version(bazel_version):
   if "bazel_version" not in dir(native):
@@ -56,11 +57,9 @@ def check_version(bazel_version):
       fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
           native.bazel_version, bazel_version))
 
-
 def _repos_are_siblings():
   return Label("@foo//bar").workspace_root.startswith("../")
 
-
 # Temporary workaround to support including TensorFlow as a submodule until this
 # use-case is supported in the next Bazel release.
 def _temp_workaround_http_archive_impl(repo_ctx):
@@ -73,9 +72,7 @@ def _temp_workaround_http_archive_impl(repo_ctx):
   if repo_ctx.attr.patch_file != None:
     _apply_patch(repo_ctx, repo_ctx.attr.patch_file)
 
-
 temp_workaround_http_archive = repository_rule(
-    implementation = _temp_workaround_http_archive_impl,
     attrs = {
         "build_file": attr.label(),
         "repository": attr.string(),
@@ -84,6 +81,7 @@ temp_workaround_http_archive = repository_rule(
         "sha256": attr.string(default = ""),
         "strip_prefix": attr.string(default = ""),
     },
+    implementation = _temp_workaround_http_archive_impl,
 )
 
 # Executes specified command with arguments and calls 'fail' if it exited with
@@ -95,7 +93,6 @@ def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
           + "Stderr: {3}").format(" ".join(cmd_and_args), result.return_code,
                                   result.stdout, result.stderr))
 
-
 # Apply a patch_file to the repository root directory
 # Runs 'patch -p1'
 def _apply_patch(repo_ctx, patch_file):
@@ -113,7 +110,6 @@ def _apply_patch(repo_ctx, patch_file):
     cmd = [bazel_sh, "-c", " ".join(cmd)]
   _execute_and_check_ret_code(repo_ctx, cmd)
 
-
 # Download the repository and apply a patch to its root
 def _patched_http_archive_impl(repo_ctx):
   repo_ctx.download_and_extract(
@@ -122,9 +118,7 @@ def _patched_http_archive_impl(repo_ctx):
       stripPrefix=repo_ctx.attr.strip_prefix)
   _apply_patch(repo_ctx, repo_ctx.attr.patch_file)
 
-
 patched_http_archive = repository_rule(
-    implementation = _patched_http_archive_impl,
     attrs = {
         "patch_file": attr.label(),
         "build_file": attr.label(),
@@ -133,9 +127,9 @@ patched_http_archive = repository_rule(
         "sha256": attr.string(default = ""),
         "strip_prefix": attr.string(default = ""),
     },
+    implementation = _patched_http_archive_impl,
 )
 
-
 # If TensorFlow is linked as a submodule.
 # path_prefix is no longer used.
 # tf_repo_name is thought to be under consideration.
@@ -821,3 +815,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
           "https://github.com/google/flatbuffers/archive/971a68110e4fc1bace10fcb6deeb189e7e1a34ce.tar.gz",
       ],
   )
+
+  native.new_http_archive(
+      name = "tflite_mobilenet",
+      build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+      sha256 = "eb71679d23a0cbdb173b36ea39f3d3096de0a9b0410d148a8237f20cc1157a61",
+      urls = [
+          "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_1.0_224_quantized_2017_11_01.zip"
+      ],
+  )
diff --git a/third_party/tflite_mobilenet.BUILD b/third_party/tflite_mobilenet.BUILD
new file mode 100644
index 00000000000..75663eff485
--- /dev/null
+++ b/third_party/tflite_mobilenet.BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"])  # Apache 2.0
+
+filegroup(
+    name = "model_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "BUILD",
+        ],
+    ),
+)

From 4476ea391fcca4f6af0994242e3453fe4a159bb3 Mon Sep 17 00:00:00 2001
From: Suharsh Sivakumar <suharshs@google.com>
Date: Tue, 7 Nov 2017 18:45:36 -0800
Subject: [PATCH 016/115] MovingAvgQuantize and LastValueQuantize should use
 the updated value from the Assign op; otherwise the min/max variables never
 get updated.

PiperOrigin-RevId: 174947421
---
 tensorflow/contrib/quantize/BUILD             | 18 +++-
 .../contrib/quantize/python/quant_ops.py      | 57 +++++-------
 .../contrib/quantize/python/quant_ops_test.py | 87 +++++++++++++++++++
 .../python/quantize_parameterized_test.py     | 65 +++++++-------
 4 files changed, 160 insertions(+), 67 deletions(-)
 create mode 100644 tensorflow/contrib/quantize/python/quant_ops_test.py

diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 935af80e7a0..45a98c7f858 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -133,7 +133,6 @@ py_library(
     deps = [
         "//tensorflow/contrib/framework:framework_py",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:check_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
@@ -143,6 +142,23 @@ py_library(
     ],
 )
 
+py_test(
+    name = "quant_ops_test",
+    size = "small",
+    srcs = ["python/quant_ops_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":quant_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:session",
+        "//tensorflow/python:variables",
+    ],
+)
+
 py_library(
     name = "quantize",
     srcs = ["python/quantize.py"],
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 0a38ef9fcd6..f80d427ff0a 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -22,15 +22,12 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope
 from tensorflow.contrib.framework.python.ops import model_variable
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import moving_averages
 
-EPSILON = 1e-5
-
 
 @add_arg_scope
 def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
@@ -133,12 +130,10 @@ def LastValueQuantize(inputs,
         batch_min = inputs
     else:
       batch_min = math_ops.reduce_min(inputs, name='BatchMin')
-    batch_min -= EPSILON
-    # B-eng requires that 0.0 if always in the [min; max] range.
+    # TFLite requires that 0.0 is always in the [min; max] range.
     batch_min = math_ops.minimum(batch_min, 0.0)
-    assign_min_op = state_ops.assign(
-        min_var, batch_min, name='AssignMinLast').op
-    ops.add_to_collection(updates_collection, assign_min_op)
+    assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast')
+    ops.add_to_collection(updates_collection, assign_min.op)
 
     if per_channel:
       if input_dim >= 2:
@@ -148,17 +143,15 @@ def LastValueQuantize(inputs,
         batch_max = inputs
     else:
       batch_max = math_ops.reduce_max(inputs, name='BatchMax')
-    batch_max += EPSILON
-    # B-eng requires that 0.0 if always in the [min; max] range.
+    # TFLite requires that 0.0 is always in the [min; max] range.
     batch_max = math_ops.maximum(batch_max, 0.0)
-    assign_max_op = state_ops.assign(
-        max_var, batch_max, name='AssignMaxLast').op
-    ops.add_to_collection(updates_collection, assign_max_op)
+    assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast')
+    ops.add_to_collection(updates_collection, assign_max.op)
 
     return _FakeQuantWithMinMaxVars(
         inputs,
-        batch_min,
-        batch_max,
+        assign_min,
+        assign_max,
         per_channel=per_channel,
         num_bits=num_bits,
         narrow_range=narrow_range)
@@ -251,9 +244,9 @@ def MovingAvgQuantize(inputs,
       batch_min = math_ops.reduce_min(inputs, name='BatchMin')
     # B-eng requires that 0.0 is always in the [min; max] range.
     batch_min = math_ops.minimum(batch_min, 0.0)
-    assign_min_op = moving_averages.assign_moving_average(
-        min_var, batch_min, ema_decay, name='AssignMinEma').op
-    ops.add_to_collection(updates_collection, assign_min_op)
+    assign_min = moving_averages.assign_moving_average(
+        min_var, batch_min, ema_decay, name='AssignMinEma')
+    ops.add_to_collection(updates_collection, assign_min.op)
 
     if per_channel:
       if input_dim >= 2:
@@ -265,14 +258,14 @@ def MovingAvgQuantize(inputs,
       batch_max = math_ops.reduce_max(inputs, name='BatchMax')
     # B-eng requires that 0.0 is always in the [min; max] range.
     batch_max = math_ops.maximum(batch_max, 0.0)
-    assign_max_op = moving_averages.assign_moving_average(
-        max_var, batch_max, ema_decay, name='AssignMaxEma').op
-    ops.add_to_collection(updates_collection, assign_max_op)
+    assign_max = moving_averages.assign_moving_average(
+        max_var, batch_max, ema_decay, name='AssignMaxEma')
+    ops.add_to_collection(updates_collection, assign_max.op)
 
     return _FakeQuantWithMinMaxVars(
         inputs,
-        min_var,
-        max_var,
+        assign_min,
+        assign_max,
         per_channel=per_channel,
         num_bits=num_bits,
         narrow_range=narrow_range)
@@ -301,20 +294,10 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
   if per_channel:
     assert len(min_var.get_shape()) == 1
     assert len(max_var.get_shape()) == 1
-    with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]):
-      return array_ops.fake_quant_with_min_max_vars_per_channel(
-          inputs,
-          min_var,
-          max_var,
-          num_bits=num_bits,
-          narrow_range=narrow_range)
+    return array_ops.fake_quant_with_min_max_vars_per_channel(
+        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
   else:
     assert min_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
     assert max_var.get_shape() == []  # pylint: disable=g-explicit-bool-comparison
-    with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]):
-      return array_ops.fake_quant_with_min_max_vars(
-          inputs,
-          min_var,
-          max_var,
-          num_bits=num_bits,
-          narrow_range=narrow_range)
+    return array_ops.fake_quant_with_min_max_vars(
+        inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py
new file mode 100644
index 00000000000..38846796028
--- /dev/null
+++ b/tensorflow/contrib/quantize/python/quant_ops_test.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for third_party.tensorflow.contrib.quantize.python.quant_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.quantize.python import quant_ops
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+_MIN_MAX_VARS = 'min_max_vars'
+
+
+class QuantOpsTest(googletest.TestCase):
+
+  def testLastValueQuantizeTrainingAssign(self):
+    g = ops.Graph()
+    with session.Session(graph=g) as sess:
+      x = array_ops.placeholder(dtypes.float32, shape=[2])
+      y = quant_ops.LastValueQuantize(
+          x,
+          init_min=0.0,
+          init_max=0.0,
+          is_training=True,
+          vars_collection=_MIN_MAX_VARS)
+
+      # Run the step.
+      sess.run(variables.global_variables_initializer())
+      sess.run(y, feed_dict={x: [-1.0, 1.0]})
+      # Now check that the min_max_vars were, in fact, updated.
+      min_value, max_value = self._GetMinMaxValues(sess)
+      self.assertEqual(min_value, -1.0)
+      self.assertEqual(max_value, 1.0)
+
+  def testMovingAvgQuantizeTrainingAssign(self):
+    g = ops.Graph()
+    with session.Session(graph=g) as sess:
+      x = array_ops.placeholder(dtypes.float32, shape=[2])
+      y = quant_ops.MovingAvgQuantize(
+          x,
+          init_min=0.0,
+          init_max=0.0,
+          is_training=True,
+          vars_collection=_MIN_MAX_VARS)
+
+      # Run the step.
+      sess.run(variables.global_variables_initializer())
+      # Do two runs to avoid zero debias.
+      sess.run(y, feed_dict={x: [-1.0, 1.0]})
+      sess.run(y, feed_dict={x: [0.0, 0.0]})
+      # Now check that the min_max_vars were, in fact, updated.
+      min_value, max_value = self._GetMinMaxValues(sess)
+      self.assertGreater(min_value, -1.0)
+      self.assertLess(min_value, 0.0)
+      self.assertGreater(max_value, 0.0)
+      self.assertLess(max_value, 1.0)
+
+  def _GetMinMaxValues(self, sess):
+    min_max_vars = ops.get_collection(_MIN_MAX_VARS)
+    self.assertEqual(len(min_max_vars), 2)
+    min_idx = 0 if 'min' in min_max_vars[0].name else 1
+    max_idx = (min_idx + 1) % 2
+    min_var, max_var = min_max_vars[min_idx], min_max_vars[max_idx]
+    min_max_values = sess.run([min_var, max_var])
+    return min_max_values[0], min_max_values[1]
+
+
+if __name__ == '__main__':
+  googletest.main()
diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
index 3e62f95bd63..57dab03f162 100644
--- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py
@@ -97,8 +97,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                 quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
     expected_inputs = [
-        scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
-        scope + '/weights/read'
+        scope + '/weights_quant/AssignMinLast',
+        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
     ]
     self._AssertInputOpsAre(weights_quant, expected_inputs)
     output_op_name = scope + '/Conv2D'
@@ -109,8 +109,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       expected_inputs = [
-          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
-          scope + '/BiasAdd'
+          scope + '/conv_quant/AssignMinEma',
+          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
       ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -122,7 +122,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
     self.assertEqual(act_quant.type, quantization_node_name)
 
     expected_inputs = [
-        'test/act_quant/min/read', 'test/act_quant/max/read',
+        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
         'test/' + activation_op_name
     ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -172,8 +172,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                 quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
     expected_inputs = [
-        scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
-        scope + '/weights/read'
+        scope + '/weights_quant/AssignMinLast',
+        scope + '/weights_quant/AssignMaxLast', scope + '/weights/read'
     ]
     self._AssertInputOpsAre(weights_quant, expected_inputs)
     output_op_name = scope + '/MatMul'
@@ -184,8 +184,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       expected_inputs = [
-          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
-          scope + '/BiasAdd'
+          scope + '/conv_quant/AssignMinEma',
+          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
       ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -196,7 +196,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
     expected_inputs = [
-        'test/act_quant/min/read', 'test/act_quant/max/read',
+        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
         'test/' + activation_op_name
     ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -247,7 +247,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                 quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
     expected_inputs = [
-        scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum',
+        scope + '/weights_quant/AssignMinLast',
+        scope + '/weights_quant/AssignMaxLast',
         scope + '/depthwise_weights/read'
     ]
     self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -259,8 +260,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       expected_inputs = [
-          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
-          scope + '/BiasAdd'
+          scope + '/conv_quant/AssignMinEma',
+          scope + '/conv_quant/AssignMaxEma', scope + '/BiasAdd'
       ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -271,7 +272,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
     expected_inputs = [
-        'test/act_quant/min/read', 'test/act_quant/max/read',
+        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
         'test/' + activation_op_name
     ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -401,8 +402,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                 quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
     expected_inputs = [
-        scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
-        scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
+        scope + '/weights_quant/' + ('AssignMinEma'
+                                     if use_ema else 'AssignMinLast'),
+        scope + '/weights_quant/' + ('AssignMaxEma'
+                                     if use_ema else 'AssignMaxLast'),
         scope + '/mul_fold'
     ]
     self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -415,8 +418,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       expected_inputs = [
-          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
-          scope + '/add_fold'
+          scope + '/conv_quant/AssignMinEma',
+          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
       ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -427,7 +430,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
     expected_inputs = [
-        'test/act_quant/min/read', 'test/act_quant/max/read',
+        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
         'test/' + activation_op_name
     ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -518,8 +521,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                 quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
     expected_inputs = [
-        scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
-        scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
+        scope + '/weights_quant/' + ('AssignMinEma'
+                                     if use_ema else 'AssignMinLast'),
+        scope + '/weights_quant/' + ('AssignMaxEma'
+                                     if use_ema else 'AssignMaxLast'),
         scope + '/mul_fold'
     ]
     self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -532,8 +537,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       expected_inputs = [
-          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
-          scope + '/add_fold'
+          scope + '/conv_quant/AssignMinEma',
+          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
       ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -544,7 +549,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
     expected_inputs = [
-        'test/act_quant/min/read', 'test/act_quant/max/read',
+        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
         'test/' + activation_op_name
     ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
@@ -639,8 +644,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                 quantization_node_name)
     self.assertEqual(weights_quant.type, quantization_node_name)
     expected_inputs = [
-        scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'),
-        scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'),
+        scope + '/weights_quant/' + ('AssignMinEma'
+                                     if use_ema else 'AssignMinLast'),
+        scope + '/weights_quant/' + ('AssignMaxEma'
+                                     if use_ema else 'AssignMaxLast'),
         scope + '/mul_fold'
     ]
     self._AssertInputOpsAre(weights_quant, expected_inputs)
@@ -653,8 +660,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                                quantization_node_name)
       self.assertEqual(conv_quant.type, quantization_node_name)
       expected_inputs = [
-          scope + '/conv_quant/min/read', scope + '/conv_quant/max/read',
-          scope + '/add_fold'
+          scope + '/conv_quant/AssignMinEma',
+          scope + '/conv_quant/AssignMaxEma', scope + '/add_fold'
       ]
       self._AssertInputOpsAre(conv_quant, expected_inputs)
       output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1'
@@ -665,7 +672,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                             quantization_node_name)
     self.assertEqual(act_quant.type, quantization_node_name)
     expected_inputs = [
-        'test/act_quant/min/read', 'test/act_quant/max/read',
+        'test/act_quant/AssignMinEma', 'test/act_quant/AssignMaxEma',
         'test/' + activation_op_name
     ]
     self._AssertInputOpsAre(act_quant, expected_inputs)
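
The gist of the fix, as a minimal TF 1.x sketch (simplified to a plain `tf.assign` instead of `assign_moving_average`; illustrative, not part of the patch): threading the output tensor of the assign into the fake-quant op makes the variable update a data dependency of the quantized output, whereas reading the variable directly carries no such dependency.

```python
# Illustrative sketch of why the assign outputs are passed through.
import tensorflow as tf

inputs = tf.placeholder(tf.float32, shape=[2])
min_var = tf.Variable(0.0, trainable=False)
max_var = tf.Variable(0.0, trainable=False)
batch_min = tf.minimum(tf.reduce_min(inputs), 0.0)
batch_max = tf.maximum(tf.reduce_max(inputs), 0.0)

# The assign ops return tensors carrying the *updated* values; using them
# as the fake-quant bounds forces the update to run whenever y is computed.
assign_min = tf.assign(min_var, batch_min)
assign_max = tf.assign(max_var, batch_max)
y = tf.fake_quant_with_min_max_vars(inputs, assign_min, assign_max, num_bits=8)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(y, feed_dict={inputs: [-1.0, 1.0]})
    print(sess.run([min_var, max_var]))  # => [-1.0, 1.0], not [0.0, 0.0]
```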

From dc2636aa6c88f41e1aec2a367e341eb42ceead54 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <bsteiner@google.com>
Date: Tue, 7 Nov 2017 18:46:10 -0800
Subject: [PATCH 017/115] Silenced an unnecessary warning

PiperOrigin-RevId: 174947453
---
 tensorflow/core/grappler/utils.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index d9f4cdb5ae5..11bd8fa5cb3 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -45,7 +45,6 @@ NodeDef* NodeMap::GetNode(const string& name) const {
   string node_name = NodeName(name);
   auto it = nodes_.find(node_name);
   if (it == nodes_.end()) {
-    LOG(WARNING) << "Node " << node_name << " is not in the graph.";
     return nullptr;
   }
   return it->second;

From d484522eb3d58aac70130f5c02a732c7442046bc Mon Sep 17 00:00:00 2001
From: Mark Daoust <markdaoust@google.com>
Date: Tue, 7 Nov 2017 19:06:01 -0800
Subject: [PATCH 018/115] Fix link (the link tool expects these to be on one
 line)

PiperOrigin-RevId: 174948909
---
 tensorflow/docs_src/mobile/index.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/docs_src/mobile/index.md b/tensorflow/docs_src/mobile/index.md
index a6f1422f6f1..06ad47bc62a 100644
--- a/tensorflow/docs_src/mobile/index.md
+++ b/tensorflow/docs_src/mobile/index.md
@@ -35,8 +35,8 @@ speech-driven interface, and many of these require on-device processing. Most of
 the time a user isn’t giving commands, and so streaming audio continuously to a
 remote server would be a waste of bandwidth, since it would mostly be silence or
 background noises. To solve this problem it’s common to have a small neural
-network running on-device @{$tutorials/audio_recognition$listening out for a
-particular keyword}. Once that keyword has been spotted, the rest of the
+network running on-device @{$tutorials/audio_recognition$listening out for a particular keyword}.
+Once that keyword has been spotted, the rest of the
 conversation can be transmitted over to the server for further processing if
 more computing power is needed.
 

From fa5672bddf3f78283d7d1552a42ffc8708f863bb Mon Sep 17 00:00:00 2001
From: Colin Raffel <craffel@google.com>
Date: Tue, 7 Nov 2017 21:05:37 -0800
Subject: [PATCH 019/115] Fix tf.contrib.seq2seq._monotonic_probability_fn to
 use a hard sigmoid when mode='hard'.

Also adds tests to make sure the attention probabilities are 0 or 1 when mode='hard'.

PiperOrigin-RevId: 174956465
---
 .../kernel_tests/attention_wrapper_test.py    | 37 +++++++++++++++++++
 .../seq2seq/python/ops/attention_wrapper.py   |  6 ++-
 2 files changed, 42 insertions(+), 1 deletion(-)

diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
index 91493302b1a..01a5540121a 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/attention_wrapper_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import variables
 from tensorflow.python.ops import variable_scope as vs
@@ -589,6 +590,24 @@ class AttentionWrapperTest(test.TestCase):
         expected_final_alignment_history=expected_final_alignment_history,
         name='testBahdanauMonotonicNormalized')
 
+  def testBahdanauMonotonicHard(self):
+    # Run attention mechanism with mode='hard', make sure probabilities are hard
+    b, t, u, d = 10, 20, 30, 40
+    with self.test_session(use_gpu=True) as sess:
+      a = wrapper.BahdanauMonotonicAttention(
+          d,
+          random_ops.random_normal((b, t, u)),
+          mode='hard')
+      # Just feed previous attention as [1, 0, 0, ...]
+      attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+      sess.run(variables.global_variables_initializer())
+      attn_out = attn.eval()
+      # All values should be 0 or 1
+      self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+      # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+      self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+                                           attn_out.sum(axis=1) == 0)))
+
   def testLuongMonotonicNotNormalized(self):
     create_attention_mechanism = functools.partial(
         wrapper.LuongMonotonicAttention, sigmoid_noise=1.0,
@@ -695,6 +714,24 @@ class AttentionWrapperTest(test.TestCase):
         expected_final_alignment_history=expected_final_alignment_history,
         name='testMultiAttention')
 
+  def testLuongMonotonicHard(self):
+    # Run attention mechanism with mode='hard', make sure probabilities are hard
+    b, t, u, d = 10, 20, 30, 40
+    with self.test_session(use_gpu=True) as sess:
+      a = wrapper.LuongMonotonicAttention(
+          d,
+          random_ops.random_normal((b, t, u)),
+          mode='hard')
+      # Just feed previous attention as [1, 0, 0, ...]
+      attn = a(random_ops.random_normal((b, d)), array_ops.one_hot([0]*b, t))
+      sess.run(variables.global_variables_initializer())
+      attn_out = attn.eval()
+      # All values should be 0 or 1
+      self.assertTrue(np.all(np.logical_or(attn_out == 0, attn_out == 1)))
+      # Sum of distributions should be 0 or 1 (0 when all p_choose_i are 0)
+      self.assertTrue(np.all(np.logical_or(attn_out.sum(axis=1) == 1,
+                                           attn_out.sum(axis=1) == 0)))
+
   def testMultiAttentionNoAttentionLayer(self):
     create_attention_mechanisms = (
         wrapper.BahdanauAttention, wrapper.LuongAttention)
diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
index 839df079ee7..87230e33552 100644
--- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
@@ -679,7 +679,11 @@ def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
                                      seed=seed)
     score += sigmoid_noise*noise
   # Compute "choosing" probabilities from the attention scores
-  p_choose_i = math_ops.sigmoid(score)
+  if mode == "hard":
+    # When mode is hard, use a hard sigmoid
+    p_choose_i = math_ops.cast(score > 0, score.dtype)
+  else:
+    p_choose_i = math_ops.sigmoid(score)
   # Convert from choosing probabilities to attention distribution
   return monotonic_attention(p_choose_i, previous_alignments, mode)
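
A quick NumPy sketch of the soft-versus-hard distinction (illustrative values only): in 'hard' mode the choosing probabilities must be exactly 0 or 1, which `cast(score > 0, ...)` provides and the smooth sigmoid does not.

```python
# NumPy stand-in for the two branches above.
import numpy as np

score = np.array([-2.0, -0.5, 0.5, 2.0])
soft = 1.0 / (1.0 + np.exp(-score))      # strictly inside (0, 1)
hard = (score > 0).astype(score.dtype)   # exactly 0 or 1
print(soft)  # approx. [0.119 0.378 0.622 0.881]
print(hard)  # [0. 0. 1. 1.]
```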
 

From fff9b90a3f081b1dd0ca8ce5785f0e67c3557cce Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 22:02:42 -0800
Subject: [PATCH 020/115] Update nsync version---allow compilation with bazel
 on x86_32

The new version of nsync has a BUILD file that detects
x86_32 (which bazel currently calls piii).

PiperOrigin-RevId: 174959924
---
 tensorflow/workspace.bzl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 3081a8d1dcd..dfe332b091e 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -442,11 +442,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
   native.http_archive(
       name = "nsync",
       urls = [
-          "https://mirror.bazel.build/github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz",
-          # "https://github.com/google/nsync/archive/4fc8ff3e7626c5f24bc9674438d8257f0ffc226c.tar.gz",
+          "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz",
+          # "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz",
       ],
-      sha256 = "ffbbe828f3d0bef75462e34801de5cea31d10aa63eaa42a4ed74c46521bdfd58",
-      strip_prefix = "nsync-4fc8ff3e7626c5f24bc9674438d8257f0ffc226c",
+      sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b",
+      strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323",
   )
 
   native.http_archive(

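When a workspace pin like the one above is bumped, all three fields must move together:
the archive URLs point at the new commit, strip_prefix tracks the same hash, and sha256
is recomputed over the new tarball. A minimal sketch for recomputing the checksum,
using only the standard library (the URL is the one from the hunk):

import hashlib
import urllib.request  # Python 3; urllib2 serves the same role on Python 2

url = ("https://github.com/google/nsync/archive/"
       "93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz")
data = urllib.request.urlopen(url).read()
# The digest must match the sha256 field in workspace.bzl.
print(hashlib.sha256(data).hexdigest())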
From f0b1e65b0ac9e587c117485a96a0eaf40675c518 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 22:33:54 -0800
Subject: [PATCH 021/115] Automated g4 rollback of changelist 174912490

PiperOrigin-RevId: 174961746
---
 .../contrib/cmake/tf_core_kernels.cmake       |   1 -
 tensorflow/contrib/cmake/tf_core_ops.cmake    |   1 -
 tensorflow/contrib/cmake/tf_python.cmake      |   2 -
 tensorflow/contrib/data/BUILD                 |  13 +-
 tensorflow/contrib/data/__init__.py           |   2 +-
 tensorflow/contrib/data/ops/dataset_ops.cc    | 232 ---------
 .../python/kernel_tests/iterator_ops_test.py  |   2 +-
 .../kernel_tests/range_dataset_op_test.py     |   2 +-
 .../kernel_tests/reader_dataset_ops_test.py   |   2 +-
 tensorflow/contrib/data/python/ops/BUILD      |  40 +-
 .../contrib/data/python/ops/batching.py       |   2 +-
 .../contrib/data/python/ops/dataset_ops.py    |   8 +-
 .../contrib/data/python/ops/error_ops.py      |   2 +-
 .../contrib/data/python/ops/grouping.py       |   2 +-
 .../contrib/data/python/ops/interleave_ops.py |   2 +-
 .../contrib/data/python/ops/iterator_ops.py   |   2 +-
 tensorflow/contrib/data/python/ops/readers.py |   2 +-
 .../contrib/data/python/ops/scan_ops.py       |   2 +-
 .../core/ops/compat/ops_history.v1.pbtxt      | 452 ++++++++++++++++++
 tensorflow/core/ops/dataset_ops.cc            | 197 ++++++++
 .../python/kernel_tests/iterator_ops_test.py  |  62 +++
 .../kernel_tests/range_dataset_op_test.py     | 330 +++++++++++++
 .../kernel_tests/reader_dataset_ops_test.py   | 298 ++++++++++++
 23 files changed, 1366 insertions(+), 292 deletions(-)
 delete mode 100644 tensorflow/contrib/data/ops/dataset_ops.cc

diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 5b62598aa58..f978c8ccd5a 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -70,7 +70,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
       "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
-      "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc"
       "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 03c168795cc..4a61ed7a354 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -81,7 +81,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t
 GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
-GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
 GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index a14b7331585..7636e9ba6e4 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -776,8 +776,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py)
 GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py)
-GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops"
-  DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py)
 GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops"
   DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py)
 GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 7bcf5a5f4dc..eaede0e00ec 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -35,19 +35,8 @@ tf_custom_op_library(
     ],
 )
 
-# TODO(mrry): Move the kernels out of the core library into this library.
-tf_custom_op_library(
-    name = "_dataset_ops.so",
-    srcs = [
-        "ops/dataset_ops.cc",
-    ],
-)
-
 tf_gen_op_libs(
-    op_lib_names = [
-        "dataset_ops",
-        "prefetching_ops",
-    ],
+    op_lib_names = ["prefetching_ops"],
 )
 
 filegroup(
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 0c7e7936892..824ac4298f8 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -41,8 +41,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 # pylint: disable=unused-import
+
 from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
 from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
 from tensorflow.contrib.data.python.ops.batching import unbatch
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
deleted file mode 100644
index 1574384cb2b..00000000000
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ /dev/null
@@ -1,232 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def_builder.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-// --------------------------------------------------------------------------
-
-// The ops in this section can be composed to define an input
-// pipeline. Each op produces a DT_VARIANT tensor that represents
-// a DAG of "dataset" objects. An "dataset" object can be converted
-// to a stateful "iterator" by passing the "dataset" to the
-// "MakeIterator" op.
-//
-// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
-// not presently serializable. To avoid issues with constant folding, ensure
-// that any "source dataset" ops (i.e. ops that output a dataset and do not
-// take one as input) are marked "stateful".
-
-REGISTER_OP("IgnoreErrorsDataset")
-    .Input("input_dataset: variant")
-    .Output("handle: variant")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
-
-REGISTER_OP("MapAndBatchDataset")
-    .Input("input_dataset: variant")
-    .Input("other_arguments: Targuments")
-    .Input("batch_size: int64")
-    .Input("num_parallel_batches: int64")
-    .Output("handle: variant")
-    .Attr("f: func")
-    .Attr("Targuments: list(type) >= 0")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset` and then
-batches `batch_size` of them.
-
-Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
-to `batch_size * num_parallel_batches` copies of `f` in parallel.
-
-batch_size: A scalar representing the number of elements to accumulate in a
-  batch. It determines the number of concurrent invocations of `f` that process
-  elements from `input_dataset` in parallel.
-num_parallel_batches: A scalar representing the number of batches to create in
-  parallel. Processing multiple batches in parallel benefits workloads prone to
-  stragglers.
-)doc");
-
-REGISTER_OP("ScanDataset")
-    .Input("input_dataset: variant")
-    .Input("initial_state: Tstate")
-    .Input("other_arguments: Targuments")
-    .Output("handle: variant")
-    .Attr("f: func")
-    .Attr("Tstate: list(type) >= 1")
-    .Attr("Targuments: list(type) >= 0")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset successively reduces `f` over the elements of `input_dataset`.
-)doc");
-
-REGISTER_OP("ParallelInterleaveDataset")
-    .Input("input_dataset: variant")
-    .Input("other_arguments: Targuments")
-    .Input("cycle_length: int64")
-    .Input("block_length: int64")
-    .Input("sloppy: bool")
-    .Output("handle: variant")
-    .Attr("f: func")
-    .Attr("Targuments: list(type) >= 0")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset`.
-
-The resulting dataset is similar to the `InterleaveDataset`, with the exception
-that if retrieving the next value from a dataset would cause the requester to
-block, it will skip that input dataset. This dataset is especially useful
-when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
-allows the training step to proceed so long as some data is available.
-
-!! WARNING !! This dataset is not deterministic!
-
-f: A function mapping elements of `input_dataset`, concatenated with
-   `other_arguments`, to a Dataset variant that contains elements matching
-   `output_types` and `output_shapes`.
-)doc");
-
-REGISTER_OP("GroupByWindowDataset")
-    .Input("input_dataset: variant")
-    .Input("key_func_other_arguments: Tkey_func_other_arguments")
-    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
-    .Input(
-        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
-    .Output("handle: variant")
-    .Attr("key_func: func")
-    .Attr("reduce_func: func")
-    .Attr("window_size_func: func")
-    .Attr("Tkey_func_other_arguments: list(type) >= 0")
-    .Attr("Treduce_func_other_arguments: list(type) >= 0")
-    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that computes a windowed group-by on `input_dataset`.
-
-// TODO(mrry): Support non-int64 keys.
-
-key_func: A function mapping an element of `input_dataset`, concatenated
-  with `key_func_other_arguments` to a scalar value of type DT_INT64.
-)doc");
-
-REGISTER_OP("DenseToSparseBatchDataset")
-    .Input("input_dataset: variant")
-    .Input("batch_size: int64")
-    .Input("row_shape: int64")
-    .Output("handle: variant")
-    // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
-    .Attr("output_types: list(type) >= 1")
-    // NOTE(mrry): the 1st and 2nd elements will be vectors.
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that yields a SparseTensor for each element of the input.
-
-input_dataset: A handle to an input dataset. Must have a single component.
-batch_size: A scalar representing the number of elements to accumulate in a
-  batch.
-row_shape: A vector representing the dense shape of each row in the produced
-  SparseTensor. The shape may be partially specified, using `-1` to indicate
-  that a particular dimension should use the maximum size of all batch elements.
-)doc");
-
-REGISTER_OP("SqlDataset")
-    .Input("driver_name: string")
-    .Input("data_source_name: string")
-    .Input("query: string")
-    .Output("handle: variant")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
-                      // stateful to inhibit constant folding.
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that executes a SQL query and emits rows of the result set.
-
-driver_name: The database type. Currently, the only supported type is 'sqlite'.
-data_source_name: A connection string to connect to the database.
-query: A SQL query to execute.
-)doc");
-
-REGISTER_OP("DatasetToSingleElement")
-    .Input("dataset: variant")
-    .Output("components: output_types")
-    .Attr("output_types: list(type) >= 1")
-    .Attr("output_shapes: list(shape) >= 1")
-    .SetShapeFn([](shape_inference::InferenceContext* c) {
-      shape_inference::ShapeHandle unused;
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
-      std::vector<PartialTensorShape> output_shapes;
-      TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
-      if (output_shapes.size() != c->num_outputs()) {
-        return errors::InvalidArgument(
-            "`output_shapes` must be the same length as `output_types` (",
-            output_shapes.size(), " vs. ", c->num_outputs());
-      }
-      for (size_t i = 0; i < output_shapes.size(); ++i) {
-        shape_inference::ShapeHandle output_shape_handle;
-        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
-            output_shapes[i], &output_shape_handle));
-        c->set_output(static_cast<int>(i), output_shape_handle);
-      }
-      return Status::OK();
-    })
-    .Doc(R"doc(
-Outputs the single element from the given dataset.
-
-dataset: A handle to a dataset that contains a single element.
-components: The components of the single element of `input`.
-)doc");
-
-REGISTER_OP("SerializeIterator")
-    .Input("resource_handle: resource")
-    .Output("serialized: variant")
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Converts the given `resource_handle` representing an iterator to a variant tensor.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
-  resource.
-)doc");
-
-REGISTER_OP("DeserializeIterator")
-    .Input("resource_handle: resource")
-    .Input("serialized: variant")
-    .SetShapeFn(shape_inference::NoOutputs)
-    .Doc(R"doc(
-Converts the given variant tensor to an iterator and stores it in the given resource.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
-  resource.
-)doc");
-
-}  // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 271d80a54b5..bda9a2a4a37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -21,7 +21,6 @@ import os
 import numpy as np
 
 from tensorflow.contrib.data.python.ops import dataset_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.contrib.data.python.ops import readers
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 329dc80ba5a..f59ac760dc8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -21,7 +21,6 @@ import os
 
 from tensorflow.contrib.data.python.ops import dataset_ops
 from tensorflow.contrib.data.python.ops import enumerate_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
@@ -30,6 +29,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 8033f1d3880..3ae8f71d77f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -21,7 +21,6 @@ import gzip
 import os
 import zlib
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
 from tensorflow.contrib.data.python.ops import readers
 from tensorflow.core.example import example_pb2
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
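This one-line import move is the common thread through the contrib test and op files in
this rollback: the generated wrappers for the dataset ops now come from core rather than
from a contrib op library. Both lines below appear verbatim in the hunks; they are shown
together only to make the pattern explicit, not to be run as-is:

# Before: wrapper generated from the contrib op library (removed here).
from tensorflow.contrib.data.python.ops import gen_dataset_ops
# After: wrapper generated from tensorflow/core/ops/dataset_ops.cc.
from tensorflow.python.ops import gen_dataset_ops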
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 727c5d1c38b..1b81cf5be91 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -11,6 +11,20 @@ load(
 )
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
 
+py_library(
+    name = "dataset_ops",
+    srcs = [
+        "dataset_ops.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":transformation_ops",
+        "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/data/util:nest",
+    ],
+)
+
 py_library(
     name = "iterator_ops",
     srcs = [
@@ -59,7 +73,6 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
-        ":gen_dataset_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:dataset_ops_gen",
@@ -115,31 +128,6 @@ tf_custom_op_py_library(
     ],
 )
 
-tf_gen_op_wrapper_py(
-    name = "gen_dataset_ops",
-    out = "gen_dataset_ops.py",
-    deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"],
-)
-
-tf_custom_op_py_library(
-    name = "dataset_ops",
-    srcs = ["dataset_ops.py"],
-    dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
-    kernels = [
-        "//tensorflow/contrib/data:dataset_ops_op_lib",
-    ],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":gen_dataset_ops",
-        ":transformation_ops",
-        "//tensorflow/contrib/util:util_py",
-        "//tensorflow/python:platform",
-        "//tensorflow/python:util",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/util:nest",
-    ],
-)
-
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index e6e5f716b62..abc9212a875 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import dtypes
@@ -25,6 +24,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import math_ops
 
 
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index c4c4426809a..45d6dbe7438 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -20,21 +20,15 @@ from __future__ import print_function
 from tensorflow.contrib.data.python.ops import batching
 from tensorflow.contrib.data.python.ops import enumerate_ops
 from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.contrib.data.python.ops import grouping
 
-from tensorflow.contrib.util import loader
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gen_io_ops
-from tensorflow.python.platform import resource_loader
 from tensorflow.python.util import deprecation
 
 
-_dataset_ops = loader.load_op_library(
-    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
-
-
 class Dataset(dataset_ops.Dataset):
   """Represents a potentially large set of elements.
 
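The deleted lines above are the contrib custom-op loading pattern that this file no
longer needs: a shared library shipped next to the Python package is loaded at import
time to register the ops before their generated wrappers can be used. For reference, a
minimal sketch of that pattern; the module path and .so name are the ones from the
removed code:

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader

# Loading the .so registers the custom ops with the TensorFlow runtime;
# the returned handle is kept alive for the lifetime of the module.
_dataset_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))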
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 51a27910723..238bb52b020 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
 
 
 def ignore_errors():
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 1c7c94b3c84..6df7b22fb69 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
 
 
 def group_by_window(key_func,
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index ce23e95697c..74a919c1fff 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.util import deprecation
 
 
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 32d2f42c935..d736029fb03 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.training import saver
 
 
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index f22298b757c..2e1c3153ca7 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
 from tensorflow.python.data.util import nest
@@ -26,6 +25,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 87bbbb7d19b..5acaed48a3d 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -19,11 +19,11 @@ from __future__ import print_function
 
 import collections
 
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
 
 
 class _ScanDataset(dataset_ops.Dataset):
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 8b8251f84be..a4b5ca16af7 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -8270,6 +8270,29 @@ op {
     }
   }
 }
+op {
+  name: "DatasetToSingleElement"
+  input_arg {
+    name: "dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "components"
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "DebugGradientIdentity"
   input_arg {
@@ -9248,6 +9271,69 @@ op {
     }
   }
 }
+op {
+  name: "DenseToSparseBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "row_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "DenseToSparseBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "row_shape"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "DenseToSparseSetOperation"
   input_arg {
@@ -9741,6 +9827,18 @@ op {
     }
   }
 }
+op {
+  name: "DeserializeIterator"
+  input_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "serialized"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
 op {
   name: "DeserializeManySparse"
   input_arg {
@@ -13494,6 +13592,131 @@ op {
     }
   }
 }
+op {
+  name: "GroupByWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "window_size_func_other_arguments"
+    type_list_attr: "Twindow_size_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "window_size_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Twindow_size_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "GroupByWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "window_size_func_other_arguments"
+    type_list_attr: "Twindow_size_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "window_size_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Twindow_size_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "HSVToRGB"
   input_arg {
@@ -13914,6 +14137,53 @@ op {
     }
   }
 }
+op {
+  name: "IgnoreErrorsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
+op {
+  name: "IgnoreErrorsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "Imag"
   input_arg {
@@ -15818,6 +16088,50 @@ op {
   }
   is_stateful: true
 }
+op {
+  name: "MapAndBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "batch_size"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_batches"
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "MapClear"
   attr {
@@ -20556,6 +20870,54 @@ op {
     type: "type"
   }
 }
+op {
+  name: "ParallelInterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sloppy"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "ParallelMapDataset"
   input_arg {
@@ -30146,6 +30508,52 @@ op {
     }
   }
 }
+op {
+  name: "ScanDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+}
 op {
   name: "ScatterAdd"
   input_arg {
@@ -31861,6 +32269,18 @@ op {
     }
   }
 }
+op {
+  name: "SerializeIterator"
+  input_arg {
+    name: "resource_handle"
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "serialized"
+    type: DT_VARIANT
+  }
+  is_stateful: true
+}
 op {
   name: "SerializeManySparse"
   input_arg {
@@ -37265,6 +37685,38 @@ op {
     }
   }
 }
+op {
+  name: "SqlDataset"
+  input_arg {
+    name: "driver_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data_source_name"
+    type: DT_STRING
+  }
+  input_arg {
+    name: "query"
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  is_stateful: true
+}
 op {
   name: "Sqrt"
   input_arg {
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 8f5d8308a3d..f5122139645 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -141,6 +141,16 @@ count: A scalar representing the number of elements from the `input_dataset`
   that should be skipped.  If count is -1, skips everything.
 )doc");
 
+REGISTER_OP("IgnoreErrorsDataset")
+    .Input("input_dataset: variant")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+)doc");
+
 REGISTER_OP("MapDataset")
     .Input("input_dataset: variant")
     .Input("other_arguments: Targuments")
@@ -174,6 +184,32 @@ num_parallel_calls: The number of concurrent invocations of `f` that process
   elements from `input_dataset` in parallel.
 )doc");
 
+REGISTER_OP("MapAndBatchDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Input("batch_size: int64")
+    .Input("num_parallel_batches: int64")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+
+batch_size: A scalar representing the number of elements to accumulate in a
+  batch. It determines the number of concurrent invocations of `f` that process
+  elements from `input_dataset` in parallel.
+num_parallel_batches: A scalar representing the number of batches to create in
+  parallel. Processing multiple batches in parallel benefits workloads prone to
+  stragglers.
+)doc");
+
 REGISTER_OP("PrefetchDataset")
     .Input("input_dataset: variant")
     .Input("buffer_size: int64")
@@ -188,6 +224,21 @@ buffer_size: The maximum number of elements to buffer in an iterator over
   this dataset.
 )doc");
 
+REGISTER_OP("ScanDataset")
+    .Input("input_dataset: variant")
+    .Input("initial_state: Tstate")
+    .Input("other_arguments: Targuments")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Tstate: list(type) >= 1")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that successively reduces `f` over the elements of `input_dataset`.
+)doc");
+
 REGISTER_OP("FlatMapDataset")
     .Input("input_dataset: variant")
     .Input("other_arguments: Targuments")
@@ -234,6 +285,59 @@ f: A function mapping elements of `input_dataset`, concatenated with
   `output_types` and `output_shapes`.
 )doc");
 
+REGISTER_OP("ParallelInterleaveDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Input("cycle_length: int64")
+    .Input("block_length: int64")
+    .Input("sloppy: bool")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+   `other_arguments`, to a Dataset variant that contains elements matching
+   `output_types` and `output_shapes`.
+)doc");
+
+REGISTER_OP("GroupByWindowDataset")
+    .Input("input_dataset: variant")
+    .Input("key_func_other_arguments: Tkey_func_other_arguments")
+    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+    .Input(
+        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
+    .Output("handle: variant")
+    .Attr("key_func: func")
+    .Attr("reduce_func: func")
+    .Attr("window_size_func: func")
+    .Attr("Tkey_func_other_arguments: list(type) >= 0")
+    .Attr("Treduce_func_other_arguments: list(type) >= 0")
+    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that computes a windowed group-by on `input_dataset`.
+
+// TODO(mrry): Support non-int64 keys.
+
+key_func: A function mapping an element of `input_dataset`, concatenated
+  with `key_func_other_arguments` to a scalar value of type DT_INT64.
+)doc");
+
 REGISTER_OP("FilterDataset")
     .Input("input_dataset: variant")
     .Input("other_arguments: Targuments")
@@ -304,6 +408,27 @@ padding_values: A list of scalars containing the padding value to use for
   each of the outputs.
 )doc");
 
+REGISTER_OP("DenseToSparseBatchDataset")
+    .Input("input_dataset: variant")
+    .Input("batch_size: int64")
+    .Input("row_shape: int64")
+    .Output("handle: variant")
+    // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
+    .Attr("output_types: list(type) >= 1")
+    // NOTE(mrry): the 1st and 2nd elements will be vectors.
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that yields a SparseTensor for each element of the input.
+
+input_dataset: A handle to an input dataset. Must have a single component.
+batch_size: A scalar representing the number of elements to accumulate in a
+  batch.
+row_shape: A vector representing the dense shape of each row in the produced
+  SparseTensor. The shape may be partially specified, using `-1` to indicate
+  that a particular dimension should use the maximum size of all batch elements.
+)doc");
+
 REGISTER_OP("RangeDataset")
     .Input("start: int64")
     .Input("stop: int64")
@@ -389,6 +514,24 @@ compression_type: A scalar containing either (i) the empty string (no
 buffer_size: A scalar containing the number of bytes to buffer.
 )doc");
 
+REGISTER_OP("SqlDataset")
+    .Input("driver_name: string")
+    .Input("data_source_name: string")
+    .Input("query: string")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetIsStateful()  // TODO(b/65524810): Source dataset ops must be marked
+                      // stateful to inhibit constant folding.
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that executes a SQL query and emits rows of the result set.
+
+driver_name: The database type. Currently, the only supported type is 'sqlite'.
+data_source_name: A connection string to connect to the database.
+query: A SQL query to execute.
+)doc");
+
 REGISTER_OP("FixedLengthRecordDataset")
     .Input("filenames: string")
     .Input("header_bytes: int64")
@@ -519,6 +662,36 @@ REGISTER_OP("IteratorGetNext")
 Gets the next output from the given iterator.
 )doc");
 
+REGISTER_OP("DatasetToSingleElement")
+    .Input("dataset: variant")
+    .Output("components: output_types")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+      std::vector<PartialTensorShape> output_shapes;
+      TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+      if (output_shapes.size() != c->num_outputs()) {
+        return errors::InvalidArgument(
+            "`output_shapes` must be the same length as `output_types` (",
+            output_shapes.size(), " vs. ", c->num_outputs(), ")");
+      }
+      for (size_t i = 0; i < output_shapes.size(); ++i) {
+        shape_inference::ShapeHandle output_shape_handle;
+        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+            output_shapes[i], &output_shape_handle));
+        c->set_output(static_cast<int>(i), output_shape_handle);
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Outputs the single element from the given dataset.
+
+dataset: A handle to a dataset that contains a single element.
+components: The components of the single element of `dataset`.
+)doc");
+
 REGISTER_OP("IteratorToStringHandle")
     .Input("resource_handle: resource")
     .Output("string_handle: string")
@@ -547,4 +720,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
   element produced by the resulting iterator.
 )doc");
 
+REGISTER_OP("SerializeIterator")
+    .Input("resource_handle: resource")
+    .Output("serialized: variant")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+  resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+    .Input("resource_handle: resource")
+    .Input("serialized: variant")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+  resource.
+)doc");
+
 }  // namespace tensorflow
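SerializeIterator and DeserializeIterator give iterator checkpointing an out-of-band
path: the iterator state becomes an ordinary variant tensor that can be written to and
read back from a file. A condensed sketch of the round trip, using the generated
wrappers exactly as the tests below do (the path is illustrative, and
`_iterator_resource` is the private handle the tests also reach into):

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops

path = "/tmp/iterator_state"  # illustrative
iterator = dataset_ops.Dataset.range(10).make_initializable_iterator()
# Serialize the iterator state into a variant tensor and persist it.
state = gen_dataset_ops.serialize_iterator(iterator._iterator_resource)
save_op = io_ops.write_file(path, parsing_ops.serialize_tensor(state))
# Read the variant back and restore it into the iterator resource.
restored = parsing_ops.parse_tensor(io_ops.read_file(path), dtypes.variant)
restore_op = gen_dataset_ops.deserialize_iterator(
    iterator._iterator_resource, restored)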
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index 60a44b5b14a..2128ef4ae17 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -17,12 +17,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -31,7 +33,9 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
@@ -533,6 +537,64 @@ class IteratorTest(test.TestCase):
                 target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
             })
 
+  def testIncorrectIteratorRestore(self):
+
+    def _path():
+      return os.path.join(self.get_temp_dir(), "iterator")
+
+    def _save_op(iterator_resource):
+      iterator_state_variant = gen_dataset_ops.serialize_iterator(
+          iterator_resource)
+      save_op = io_ops.write_file(
+          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+      return save_op
+
+    def _restore_op(iterator_resource):
+      iterator_state_variant = parsing_ops.parse_tensor(
+          io_ops.read_file(_path()), dtypes.variant)
+      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                        iterator_state_variant)
+      return restore_op
+
+    def _build_range_dataset_graph():
+      start = 1
+      stop = 10
+      iterator = dataset_ops.Dataset.range(start,
+                                           stop).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    def _build_reader_dataset_graph():
+      filenames = ["test"]  # Does not exist but we don't care in this test.
+      iterator = readers.FixedLengthRecordDataset(
+          filenames, 1, 0, 0).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next_op = iterator.get_next()
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
+      return init_op, get_next_op, save_op, restore_op
+
+    # Saving iterator for RangeDataset graph.
+    with ops.Graph().as_default() as g:
+      init_op, _, save_op, _ = _build_range_dataset_graph()
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(save_op)
+
+    # Attempt to restore the saved iterator into an IteratorResource of
+    # incompatible type. An iterator of RangeDataset has output type int64,
+    # while an iterator of FixedLengthRecordDataset has output type string.
+    # So an InvalidArgumentError should be raised by
+    # IteratorResource::set_iterator.
+    with ops.Graph().as_default() as g:
+      _, _, _, restore_op = _build_reader_dataset_graph()
+      with self.test_session(graph=g) as sess:
+        with self.assertRaises(errors.InvalidArgumentError):
+          sess.run(restore_op)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 3c1685c951f..0c530522b83 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -17,15 +17,32 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
 
 
 class RangeDatasetTest(test.TestCase):
 
+  def tearDown(self):
+    # Remove all checkpoint files.
+    prefix = self._iterator_checkpoint_prefix()
+    pattern = prefix + "*"
+    files = gfile.Glob(pattern)
+    # Use an explicit loop: map() would be a lazy no-op under Python 3.
+    for f in files:
+      gfile.Remove(f)
+
   def testStop(self):
     stop = array_ops.placeholder(dtypes.int64, shape=[])
     iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
@@ -151,6 +168,319 @@ class RangeDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  def _iterator_checkpoint_prefix(self):
+    return os.path.join(self.get_temp_dir(), "iterator")
+
+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_prefix(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
+  def testSaveRestore(self):
+
+    def _build_graph(start, stop):
+      iterator = dataset_ops.Dataset.range(start,
+                                           stop).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    break_point = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+    # Saving and restoring in same session.
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testRestoreWithoutBuildingDatasetGraph(self):
+
+    def _build_graph(start, stop, num_epochs):
+      dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
+      iterator = dataset.make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    num_epochs = 5
+    break_point = 5
+    break_epoch = 3
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for _ in range(break_epoch):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      # Create an empty IteratorResource and restore the Iterator into it.
+      output_types = dtypes.int64
+      output_shapes = tensor_shape.scalar()
+      iterator = iterator_ops.Iterator.from_structure(output_types,
+                                                      output_shapes)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      get_next = iterator.get_next()
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        for _ in range(break_epoch + 1, num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testRestoreInModifiedGraph(self):
+
+    def _build_graph(start, stop):
+      dataset = dataset_ops.Dataset.range(start, stop)
+      iterator = dataset.make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    stop_1 = 8
+    break_point = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      # Intentionally build a graph with a different value for stop to make sure
+      # the original dataset graph is actually getting loaded.
+      init_op, get_next, _, restore_op = _build_graph(start, stop_1)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testInitThenRestore(self):
+    # Note: Calling init_op before restore_op is redundant. This test just makes
+    # sure we do not fail if restore is called on an already initialized
+    # iterator resource.
+
+    def _build_graph(start, stop):
+      dataset = dataset_ops.Dataset.range(start, stop)
+      iterator = dataset.make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    # Saving and restoring in different sessions.
+    start = 2
+    stop = 10
+    break_point = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for i in range(break_point, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testMultipleSaves(self):
+
+    def _build_graph(start, stop):
+      iterator = dataset_ops.Dataset.range(start,
+                                           stop).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    break_point1 = 5
+    break_point2 = 7
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, _ = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        for i in range(start, break_point1):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for i in range(break_point1, break_point2):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for i in range(break_point2, stop):
+          self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testSaveRestoreWithRepeat(self):
+
+    def _build_graph(start, stop, num_epochs):
+      iterator = dataset_ops.Dataset.range(
+          start, stop).repeat(num_epochs).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    num_epochs = 5
+    break_range = 5
+    break_epoch = 3
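+    # break_epoch - 1 full epochs plus a partial epoch up to break_range are
+    # consumed before saving; the remainder is consumed after restoring.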
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(
+          start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for _ in range(break_epoch - 1):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        for i in range(start, break_range):
+          self.assertEqual(i, sess.run(get_next))
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for i in range(break_range, stop):
+          self.assertEqual(i, sess.run(get_next))
+        for _ in range(break_epoch, num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
+  def testSaveRestoreExhaustedIterator(self):
+
+    def _build_graph(start, stop, num_epochs):
+      iterator = dataset_ops.Dataset.range(
+          start, stop).repeat(num_epochs).make_initializable_iterator()
+      init_op = iterator.initializer
+      get_next = iterator.get_next()
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
+      return init_op, get_next, save_op, restore_op
+
+    start = 2
+    stop = 10
+    num_epochs = 5
+    with ops.Graph().as_default() as g:
+      init_op, get_next, save_op, restore_op = _build_graph(
+          start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(variables.global_variables_initializer())
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for _ in range(num_epochs):
+          for i in range(start, stop):
+            self.assertEqual(i, sess.run(get_next))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index 70b6ce442ea..c8e7333b4b9 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -26,8 +26,14 @@ from tensorflow.python.data.ops import readers
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
@@ -267,6 +272,299 @@ class FixedLengthRecordReaderTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(iterator.get_next())
 
+  def _iterator_checkpoint_path(self):
+    return os.path.join(self.get_temp_dir(), "iterator")
+
+  def _save_op(self, iterator_resource):
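+    # Serialize the iterator state into a variant tensor and write it to the
+    # checkpoint file as a serialized TensorProto.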
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_path(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
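+    # Read the file back, parse the serialized TensorProto into a variant
+    # tensor, and use it to restore the iterator state in the given resource.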
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
+  def _build_iterator_graph(self, num_epochs):
+    filenames = self._createFiles()
+    dataset = (readers.FixedLengthRecordDataset(
+        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
+               .repeat(num_epochs))
+    iterator = dataset.make_initializable_iterator()
+    init_op = iterator.initializer
+    get_next_op = iterator.get_next()
+    save_op = self._save_op(iterator._iterator_resource)
+    restore_op = self._restore_op(iterator._iterator_resource)
+    return init_op, get_next_op, save_op, restore_op
+
+  def _restore_iterator(self):
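+    # Build an iterator from structure alone (no dataset graph); its state
+    # is populated entirely by restore_op from the checkpoint.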
+    output_types = dtypes.string
+    output_shapes = tensor_shape.scalar()
+    iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
+    get_next = iterator.get_next()
+    restore_op = self._restore_op(iterator._iterator_resource)
+    return restore_op, get_next
+
+  def testSaveRestore(self):
+    num_epochs = 10
+    epoch_break = 5
+    file_break = self._num_files // 2
+    record_break = self._num_records // 2
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
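+        # The else/continue/break ladder below breaks out of all three nested
+        # loops once save_op has run: a for-loop's `else` clause executes only
+        # when the loop completes without hitting `break`.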
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch == epoch_break and f == file_break and
+                  r == record_break):
+                sess.run(save_op)
+                break
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+            else:
+              continue
+            break
+          else:
+            continue
+          break
+        else:
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch < epoch_break or
+                  (epoch == epoch_break and f < file_break) or
+                  (epoch == epoch_break and f == file_break and
+                   r < record_break)):
+                continue
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testInitThenRestore(self):
+    # Note: Calling init_op before restore_op is redundant. This test just makes
+    # sure we do not fail if restore is called on an already initialized
+    # iterator resource.
+    num_epochs = 10
+    epoch_break = 5
+    file_break = self._num_files // 2
+    record_break = self._num_records // 2
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch == epoch_break and f == file_break and
+                  r == record_break):
+                sess.run(save_op)
+                break
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+            else:
+              continue
+            break
+          else:
+            continue
+          break
+        else:
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch < epoch_break or
+                  (epoch == epoch_break and f < file_break) or
+                  (epoch == epoch_break and f == file_break and
+                   r < record_break)):
+                continue
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testRestoreInModifiedGraph(self):
+    num_epochs = 10
+    num_epochs_1 = 20
+    epoch_break = 5
+    file_break = self._num_files // 2
+    record_break = self._num_records // 2
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch == epoch_break and f == file_break and
+                  r == record_break):
+                sess.run(save_op)
+                break
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+            else:
+              continue
+            break
+          else:
+            continue
+          break
+        else:
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs_1)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch < epoch_break or
+                  (epoch == epoch_break and f < file_break) or
+                  (epoch == epoch_break and f == file_break and
+                   r < record_break)):
+                continue
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testRestoreWithoutBuildingDatasetGraph(self):
+    num_epochs = 10
+    epoch_break = 5
+    file_break = self._num_files // 2
+    record_break = self._num_records // 2
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch == epoch_break and f == file_break and
+                  r == record_break):
+                sess.run(save_op)
+                break
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+            else:
+              continue
+            break
+          else:
+            continue
+          break
+        else:
+          with self.assertRaises(errors.OutOfRangeError):
+            sess.run(get_next_op)
+
+    with ops.Graph().as_default() as g:
+      restore_op, get_next_op = self._restore_iterator()
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for epoch in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              if (epoch < epoch_break or
+                  (epoch == epoch_break and f < file_break) or
+                  (epoch == epoch_break and f == file_break and
+                   r < record_break)):
+                continue
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testRestoreUnusedIterator(self):
+    num_epochs = 10
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        # Save unused iterator.
+        sess.run(save_op)
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        for _ in range(num_epochs * self._num_files * self._num_records):
+          sess.run(get_next_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
+  def testRestoreExhaustedIterator(self):
+    num_epochs = 10
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(init_op)
+        # Note: There is no checkpoint saved currently, so a NotFoundError is
+        # raised.
+        with self.assertRaises(errors.NotFoundError):
+          sess.run(restore_op)
+        for _ in range(num_epochs):
+          for f in range(self._num_files):
+            for r in range(self._num_records):
+              self.assertEqual(self._record(f, r), sess.run(get_next_op))
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+        sess.run(save_op)
+
+    with ops.Graph().as_default() as g:
+      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+          num_epochs=num_epochs)
+      with self.test_session(graph=g) as sess:
+        sess.run(restore_op)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(get_next_op)
+
 
 class TFRecordDatasetTest(test.TestCase):
 

From 5199923383856f9e3bdee40b6f7f976328b42e09 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 22:41:22 -0800
Subject: [PATCH 022/115] Go: Update generated wrapper functions for TensorFlow
 ops.

PiperOrigin-RevId: 174962378
---
 tensorflow/go/op/wrappers.go | 411 +++++++++++++++++++++++------------
 1 file changed, 277 insertions(+), 134 deletions(-)

diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 4e5d17f76fd..bdfad485673 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3983,41 +3983,6 @@ func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value t
 	return op.Output(0)
 }
 
-// Identity op for gradient debugging.
-//
-// This op is hidden from public in Python. It is used by TensorFlow Debugger to
-// register gradient tensors for gradient debugging.
-func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "DebugGradientIdentity",
-		Input: []tf.Input{
-			input,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
-// Deprecated. Use TensorArrayGradV3
-func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"source": source}
-	opspec := tf.OpSpec{
-		Type: "TensorArrayGradV2",
-		Input: []tf.Input{
-			handle, flow_in,
-		},
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // Get the current size of the TensorArray.
 //
 // Arguments:
@@ -4551,31 +4516,6 @@ func QueueCloseV2(scope *Scope, handle tf.Output, optional ...QueueCloseV2Attr)
 	return scope.AddOperation(opspec)
 }
 
-// Concatenates tensors along one dimension.
-//
-// Arguments:
-//	values: List of `N` Tensors to concatenate. Their ranks and types must match,
-// and their sizes must match in all dimensions except `concat_dim`.
-//	axis: 0-D.  The dimension along which to concatenate.  Must be in the
-// range [-rank(values), rank(values)).
-//
-// Returns A `Tensor` with the concatenation of values stacked along the
-// `concat_dim` dimension.  This tensor's shape matches that of `values` except
-// in `concat_dim` where it has the sum of the sizes.
-func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	opspec := tf.OpSpec{
-		Type: "ConcatV2",
-		Input: []tf.Input{
-			tf.OutputList(values), axis,
-		},
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // QueueDequeueUpToV2Attr is an optional argument to QueueDequeueUpToV2.
 type QueueDequeueUpToV2Attr func(optionalAttr)
 
@@ -4992,80 +4932,6 @@ func PriorityQueueV2(scope *Scope, shapes []tf.Shape, optional ...PriorityQueueV
 	return op.Output(0)
 }
 
-// FIFOQueueV2Attr is an optional argument to FIFOQueueV2.
-type FIFOQueueV2Attr func(optionalAttr)
-
-// FIFOQueueV2Shapes sets the optional shapes attribute to value.
-//
-// value: The shape of each component in a value. The length of this attr must
-// be either 0 or the same as the length of component_types. If the length of
-// this attr is 0, the shapes of queue elements are not constrained, and
-// only one element may be dequeued at a time.
-// If not specified, defaults to <>
-//
-// REQUIRES: len(value) >= 0
-func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr {
-	return func(m optionalAttr) {
-		m["shapes"] = value
-	}
-}
-
-// FIFOQueueV2Capacity sets the optional capacity attribute to value.
-//
-// value: The upper bound on the number of elements in this queue.
-// Negative numbers mean no limit.
-// If not specified, defaults to -1
-func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr {
-	return func(m optionalAttr) {
-		m["capacity"] = value
-	}
-}
-
-// FIFOQueueV2Container sets the optional container attribute to value.
-//
-// value: If non-empty, this queue is placed in the given container.
-// Otherwise, a default container is used.
-// If not specified, defaults to ""
-func FIFOQueueV2Container(value string) FIFOQueueV2Attr {
-	return func(m optionalAttr) {
-		m["container"] = value
-	}
-}
-
-// FIFOQueueV2SharedName sets the optional shared_name attribute to value.
-//
-// value: If non-empty, this queue will be shared under the given name
-// across multiple sessions.
-// If not specified, defaults to ""
-func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr {
-	return func(m optionalAttr) {
-		m["shared_name"] = value
-	}
-}
-
-// A queue that produces elements in first-in first-out order.
-//
-// Arguments:
-//	component_types: The type of each component in a value.
-//
-// Returns The handle to the queue.
-func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) {
-	if scope.Err() != nil {
-		return
-	}
-	attrs := map[string]interface{}{"component_types": component_types}
-	for _, a := range optional {
-		a(attrs)
-	}
-	opspec := tf.OpSpec{
-		Type: "FIFOQueueV2",
-
-		Attrs: attrs,
-	}
-	op := scope.AddOperation(opspec)
-	return op.Output(0)
-}
-
 // StridedSliceAttr is an optional argument to StridedSlice.
 type StridedSliceAttr func(optionalAttr)
 
@@ -5445,6 +5311,101 @@ func DynamicStitch(scope *Scope, indices []tf.Output, data []tf.Output) (merged
 	return op.Output(0)
 }
 
+// FIFOQueueV2Attr is an optional argument to FIFOQueueV2.
+type FIFOQueueV2Attr func(optionalAttr)
+
+// FIFOQueueV2Shapes sets the optional shapes attribute to value.
+//
+// value: The shape of each component in a value. The length of this attr must
+// be either 0 or the same as the length of component_types. If the length of
+// this attr is 0, the shapes of queue elements are not constrained, and
+// only one element may be dequeued at a time.
+// If not specified, defaults to <>
+//
+// REQUIRES: len(value) >= 0
+func FIFOQueueV2Shapes(value []tf.Shape) FIFOQueueV2Attr {
+	return func(m optionalAttr) {
+		m["shapes"] = value
+	}
+}
+
+// FIFOQueueV2Capacity sets the optional capacity attribute to value.
+//
+// value: The upper bound on the number of elements in this queue.
+// Negative numbers mean no limit.
+// If not specified, defaults to -1
+func FIFOQueueV2Capacity(value int64) FIFOQueueV2Attr {
+	return func(m optionalAttr) {
+		m["capacity"] = value
+	}
+}
+
+// FIFOQueueV2Container sets the optional container attribute to value.
+//
+// value: If non-empty, this queue is placed in the given container.
+// Otherwise, a default container is used.
+// If not specified, defaults to ""
+func FIFOQueueV2Container(value string) FIFOQueueV2Attr {
+	return func(m optionalAttr) {
+		m["container"] = value
+	}
+}
+
+// FIFOQueueV2SharedName sets the optional shared_name attribute to value.
+//
+// value: If non-empty, this queue will be shared under the given name
+// across multiple sessions.
+// If not specified, defaults to ""
+func FIFOQueueV2SharedName(value string) FIFOQueueV2Attr {
+	return func(m optionalAttr) {
+		m["shared_name"] = value
+	}
+}
+
+// A queue that produces elements in first-in first-out order.
+//
+// Arguments:
+//	component_types: The type of each component in a value.
+//
+// Returns The handle to the queue.
+func FIFOQueueV2(scope *Scope, component_types []tf.DataType, optional ...FIFOQueueV2Attr) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"component_types": component_types}
+	for _, a := range optional {
+		a(attrs)
+	}
+	opspec := tf.OpSpec{
+		Type: "FIFOQueueV2",
+
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Converts the given `resource_handle` representing an iterator to a variant tensor.
+//
+// Arguments:
+//	resource_handle: A handle to an iterator resource.
+//
+// Returns A variant tensor storing the state of the iterator contained in the
+// resource.
+func SerializeIterator(scope *Scope, resource_handle tf.Output) (serialized tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "SerializeIterator",
+		Input: []tf.Input{
+			resource_handle,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Return a tensor with the same shape and contents as the input tensor or value.
 func Identity(scope *Scope, input tf.Output) (output tf.Output) {
 	if scope.Err() != nil {
@@ -5576,6 +5537,39 @@ func IteratorToStringHandle(scope *Scope, resource_handle tf.Output) (string_han
 	return op.Output(0)
 }
 
+// Outputs the single element from the given dataset.
+//
+// Arguments:
+//	dataset: A handle to a dataset that contains a single element.
+//
+//
+//
+// Returns The components of the single element of `dataset`.
+func DatasetToSingleElement(scope *Scope, dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "DatasetToSingleElement",
+		Input: []tf.Input{
+			dataset,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	if scope.Err() != nil {
+		return
+	}
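+	// The op produces a variable number of outputs (one per entry in
+	// output_types), so unpack them into a slice with makeOutputList.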
+	var idx int
+	var err error
+	if components, idx, err = makeOutputList(op, idx, "components"); err != nil {
+		scope.UpdateErr("DatasetToSingleElement", err)
+		return
+	}
+	return components
+}
+
 // Gets the next output from the given iterator.
 func IteratorGetNext(scope *Scope, iterator tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (components []tf.Output) {
 	if scope.Err() != nil {
@@ -5696,6 +5690,30 @@ func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf
 	return op.Output(0)
 }
 
+// Creates a dataset that executes a SQL query and emits rows of the result set.
+//
+// Arguments:
+//	driver_name: The database type. Currently, the only supported type is 'sqlite'.
+//	data_source_name: A connection string to connect to the database.
+//	query: A SQL query to execute.
+//
+//
+func SqlDataset(scope *Scope, driver_name tf.Output, data_source_name tf.Output, query tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "SqlDataset",
+		Input: []tf.Input{
+			driver_name, data_source_name, query,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // PlaceholderAttr is an optional argument to Placeholder.
 type PlaceholderAttr func(optionalAttr)
 
@@ -5766,6 +5784,68 @@ func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, out
 	return op.Output(0)
 }
 
+// Identity op for gradient debugging.
+//
+// This op is hidden from public in Python. It is used by TensorFlow Debugger to
+// register gradient tensors for gradient debugging.
+func DebugGradientIdentity(scope *Scope, input tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "DebugGradientIdentity",
+		Input: []tf.Input{
+			input,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Deprecated. Use TensorArrayGradV3
+func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"source": source}
+	opspec := tf.OpSpec{
+		Type: "TensorArrayGradV2",
+		Input: []tf.Input{
+			handle, flow_in,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Creates a dataset that yields a SparseTensor for each element of the input.
+//
+// Arguments:
+//	input_dataset: A handle to an input dataset. Must have a single component.
+//	batch_size: A scalar representing the number of elements to accumulate in a
+// batch.
+//	row_shape: A vector representing the dense shape of each row in the produced
+// SparseTensor. The shape may be partially specified, using `-1` to indicate
+// that a particular dimension should use the maximum size of all batch elements.
+//
+//
+func DenseToSparseBatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, row_shape tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "DenseToSparseBatchDataset",
+		Input: []tf.Input{
+			input_dataset, batch_size, row_shape,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Creates a dataset that batches and pads `batch_size` elements from the input.
 //
 // Arguments:
@@ -5826,6 +5906,69 @@ func TensorArrayConcatV2(scope *Scope, handle tf.Output, flow_in tf.Output, dtyp
 	return op.Output(0), op.Output(1)
 }
 
+// Converts the given variant tensor to an iterator and stores it in the given resource.
+//
+// Arguments:
+//	resource_handle: A handle to an iterator resource.
+//	serialized: A variant tensor storing the state of the iterator contained in the
+// resource.
+//
+// Returns the created operation.
+func DeserializeIterator(scope *Scope, resource_handle tf.Output, serialized tf.Output) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "DeserializeIterator",
+		Input: []tf.Input{
+			resource_handle, serialized,
+		},
+	}
+	return scope.AddOperation(opspec)
+}
+
+// Concatenates tensors along one dimension.
+//
+// Arguments:
+//	values: List of `N` Tensors to concatenate. Their ranks and types must match,
+// and their sizes must match in all dimensions except `concat_dim`.
+//	axis: 0-D.  The dimension along which to concatenate.  Must be in the
+// range [-rank(values), rank(values)).
+//
+// Returns A `Tensor` with the concatenation of values stacked along the
+// `concat_dim` dimension.  This tensor's shape matches that of `values` except
+// in `concat_dim` where it has the sum of the sizes.
+func ConcatV2(scope *Scope, values []tf.Output, axis tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "ConcatV2",
+		Input: []tf.Input{
+			tf.OutputList(values), axis,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+func IgnoreErrorsDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+	opspec := tf.OpSpec{
+		Type: "IgnoreErrorsDataset",
+		Input: []tf.Input{
+			input_dataset,
+		},
+		Attrs: attrs,
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
 // Creates a dataset that concatenates `input_dataset` with `another_dataset`.
 func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
 	if scope.Err() != nil {

From eb49f78c38ef106f806f7698b374f4b28130025f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 7 Nov 2017 23:19:03 -0800
Subject: [PATCH 023/115] Update ops-related pbtxt files.

PiperOrigin-RevId: 174964560
---
 tensorflow/core/ops/ops.pbtxt | 362 ++++++++++++++++++++++++++++++++++
 1 file changed, 362 insertions(+)

diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d35decc1823..8353b45e225 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -6058,6 +6058,32 @@ op {
   summary: "Compute the cumulative sum of the tensor `x` along `axis`."
   description: "By default, this op performs an inclusive cumsum, which means that the first\nelement of the input is identical to the first element of the output:\n\n```python\ntf.cumsum([a, b, c])  # => [a, a + b, a + b + c]\n```\n\nBy setting the `exclusive` kwarg to `True`, an exclusive cumsum is\nperformed instead:\n\n```python\ntf.cumsum([a, b, c], exclusive=True)  # => [0, a, a + b]\n```\n\nBy setting the `reverse` kwarg to `True`, the cumsum is performed in the\nopposite direction:\n\n```python\ntf.cumsum([a, b, c], reverse=True)  # => [a + b + c, b + c, c]\n```\n\nThis is more efficient than using separate `tf.reverse` ops.\n\nThe `reverse` and `exclusive` kwargs can also be combined:\n\n```python\ntf.cumsum([a, b, c], exclusive=True, reverse=True)  # => [b + c, c, 0]\n```"
 }
+op {
+  name: "DatasetToSingleElement"
+  input_arg {
+    name: "dataset"
+    description: "A handle to a dataset that contains a single element."
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "components"
+    description: "The components of the single element of `input`."
+    type_list_attr: "output_types"
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Outputs the single element from the given dataset."
+}
 op {
   name: "DebugGradientIdentity"
   input_arg {
@@ -6689,6 +6715,41 @@ op {
   summary: "Applies set operation along last dimension of 2 `Tensor` inputs."
   description: "See SetOperationOp::SetOperationFromContext for values of `set_operation`.\n\nOutput `result` is a `SparseTensor` represented by `result_indices`,\n`result_values`, and `result_shape`. For `set1` and `set2` ranked `n`, this\nhas rank `n` and the same 1st `n-1` dimensions as `set1` and `set2`. The `nth`\ndimension contains the result of `set_operation` applied to the corresponding\n`[0...n-1]` dimension of `set`."
 }
+op {
+  name: "DenseToSparseBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    description: "A handle to an input dataset. Must have a single component."
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "batch_size"
+    description: "A scalar representing the number of elements to accumulate in a\nbatch."
+    type: DT_INT64
+  }
+  input_arg {
+    name: "row_shape"
+    description: "A vector representing the dense shape of each row in the produced\nSparseTensor. The shape may be partially specified, using `-1` to indicate\nthat a particular dimension should use the maximum size of all batch elements."
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that yields a SparseTensor for each element of the input."
+}
 op {
   name: "DenseToSparseSetOperation"
   input_arg {
@@ -7028,6 +7089,21 @@ op {
   summary: "Dequantize the \'input\' tensor into a float Tensor."
   description: "[min_range, max_range] are scalar floats that specify the range for\nthe \'input\' data. The \'mode\' attribute controls exactly which calculations are\nused to convert the float values to their quantized equivalents.\n\nIn \'MIN_COMBINED\' mode, each value of the tensor will undergo the following:\n\n```\nif T == qint8, in[i] += (range(T) + 1)/ 2.0\nout[i] = min_range + (in[i]* (max_range - min_range) / range(T))\n```\nhere `range(T) = numeric_limits<T>::max() - numeric_limits<T>::min()`\n\n*MIN_COMBINED Mode Example*\n\nIf the input comes from a QuantizedRelu6, the output type is\nquint8 (range of 0-255) but the possible range of QuantizedRelu6 is\n0-6.  The min_range and max_range values are therefore 0.0 and 6.0.\nDequantize on quint8 will take each value, cast to float, and multiply\nby 6 / 255.\nNote that if quantizedtype is qint8, the operation will additionally add\neach value by 128 prior to casting.\n\nIf the mode is \'MIN_FIRST\', then this approach is used:\n\n```c++\nnum_discrete_values = 1 << (# of bits in T)\nrange_adjust = num_discrete_values / (num_discrete_values - 1)\nrange = (range_max - range_min) * range_adjust\nrange_scale = range / num_discrete_values\nconst double offset_input = static_cast<double>(input) - lowest_quantized;\nresult = range_min + ((input - numeric_limits<T>::min()) * range_scale)\n```\n\n*SCALED mode Example*\n\n`SCALED` mode matches the quantization approach used in\n`QuantizeAndDequantize{V2|V3}`.\n\nIf the mode is `SCALED`, we do not use the full range of the output type,\nchoosing to elide the lowest possible value for symmetry (e.g., output range is\n-127 to 127, not -128 to 127 for signed 8 bit quantization), so that 0.0 maps to\n0.\n\nWe first find the range of values in our tensor. The\nrange we use is always centered on 0, so we find m such that\n```c++\n  m = max(abs(input_min), abs(input_max))\n```\n\nOur input tensor range is then `[-m, m]`.\n\nNext, we choose our fixed-point quantization buckets, `[min_fixed, max_fixed]`.\nIf T is signed, this is\n```\n  num_bits = sizeof(T) * 8\n  [min_fixed, max_fixed] =\n      [-(1 << (num_bits - 1) - 1), (1 << (num_bits - 1)) - 1]\n```\n\nOtherwise, if T is unsigned, the fixed-point range is\n```\n  [min_fixed, max_fixed] = [0, (1 << num_bits) - 1]\n```\n\nFrom this we compute our scaling factor, s:\n```c++\n  s = (2 * m) / (max_fixed - min_fixed)\n```\n\nNow we can dequantize the elements of our tensor:\n```c++\nresult = input * s\n```"
 }
+op {
+  name: "DeserializeIterator"
+  input_arg {
+    name: "resource_handle"
+    description: "A handle to an iterator resource."
+    type: DT_RESOURCE
+  }
+  input_arg {
+    name: "serialized"
+    description: "A variant tensor storing the state of the iterator contained in the\nresource."
+    type: DT_VARIANT
+  }
+  summary: "Converts the given variant tensor to an iterator and stores it in the given resource."
+  is_stateful: true
+}
 op {
   name: "DeserializeManySparse"
   input_arg {
@@ -10142,6 +10218,71 @@ op {
   summary: "Returns the truth value of (x >= y) element-wise."
   description: "*NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting\n[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)"
 }
+op {
+  name: "GroupByWindowDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "key_func_other_arguments"
+    type_list_attr: "Tkey_func_other_arguments"
+  }
+  input_arg {
+    name: "reduce_func_other_arguments"
+    type_list_attr: "Treduce_func_other_arguments"
+  }
+  input_arg {
+    name: "window_size_func_other_arguments"
+    type_list_attr: "Twindow_size_func_other_arguments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "key_func"
+    type: "func"
+    description: "A function mapping an element of `input_dataset`, concatenated\nwith `key_func_other_arguments` to a scalar value of type DT_INT64."
+  }
+  attr {
+    name: "reduce_func"
+    type: "func"
+  }
+  attr {
+    name: "window_size_func"
+    type: "func"
+  }
+  attr {
+    name: "Tkey_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Treduce_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "Twindow_size_func_other_arguments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that computes a windowed group-by on `input_dataset`."
+  description: "// TODO(mrry): Support non-int64 keys."
+}
 op {
   name: "HSVToRGB"
   input_arg {
@@ -10602,6 +10743,30 @@ op {
   summary: "Compute the upper regularized incomplete Gamma function `Q(a, x)`."
   description: "The upper regularized incomplete Gamma function is defined as:\n\n\\\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\\\)\n\nwhere\n\n\\\\(Gamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\\\\)\n\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function."
 }
+op {
+  name: "IgnoreErrorsDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that contains the elements of `input_dataset` ignoring errors."
+}
 op {
   name: "Imag"
   input_arg {
@@ -12373,6 +12538,54 @@ op {
   description: "This operation may be executed multiple times. Each execution will reset the\niterator in `iterator` to the first element of `dataset`."
   is_stateful: true
 }
+op {
+  name: "MapAndBatchDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "batch_size"
+    description: "A scalar representing the number of elements to accumulate in a\nbatch. It determines the number of concurrent invocations of `f` that process\nelements from `input_dataset` in parallel."
+    type: DT_INT64
+  }
+  input_arg {
+    name: "num_parallel_batches"
+    description: "A scalar representing the number of batches to create in\nparallel. Processing multiple batches in parallel benefits workloads prone to\nstragglers."
+    type: DT_INT64
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset` and then"
+  description: "batches `batch_size` of them.\n\nUnlike a \"MapDataset\", which applies `f` sequentially, this dataset invokes up\nto `batch_size * num_parallel_batches` copies of `f` in parallel."
+}
 op {
   name: "MapClear"
   attr {
@@ -16043,6 +16256,57 @@ op {
   summary: "Interleave the values from the `data` tensors into a single tensor."
   description: "Builds a merged tensor such that\n\n```python\n    merged[indices[m][i, ..., j], ...] = data[m][i, ..., j, ...]\n```\n\nFor example, if each `indices[m]` is scalar or vector, we have\n\n```python\n    # Scalar indices:\n    merged[indices[m], ...] = data[m][...]\n\n    # Vector indices:\n    merged[indices[m][i], ...] = data[m][i, ...]\n```\n\nEach `data[i].shape` must start with the corresponding `indices[i].shape`,\nand the rest of `data[i].shape` must be constant w.r.t. `i`.  That is, we\nmust have `data[i].shape = indices[i].shape + constant`.  In terms of this\n`constant`, the output shape is\n\n    merged.shape = [max(indices)] + constant\n\nValues may be merged in parallel, so if an index appears in both `indices[m][i]`\nand `indices[n][j]`, the result may be invalid. This differs from the normal\nDynamicStitch operator that defines the behavior in that case.\n\nFor example:\n\n```python\n    indices[0] = 6\n    indices[1] = [4, 1]\n    indices[2] = [[5, 2], [0, 3]]\n    data[0] = [61, 62]\n    data[1] = [[41, 42], [11, 12]]\n    data[2] = [[[51, 52], [21, 22]], [[1, 2], [31, 32]]]\n    merged = [[1, 2], [11, 12], [21, 22], [31, 32], [41, 42],\n              [51, 52], [61, 62]]\n```\n\nThis method can be used to merge partitions created by `dynamic_partition`\nas illustrated on the following example:\n\n```python\n    # Apply function (increments x_i) on elements for which a certain condition\n    # apply (x_i != -1 in this example).\n    x=tf.constant([0.1, -1., 5.2, 4.3, -1., 7.4])\n    condition_mask=tf.not_equal(x,tf.constant(-1.))\n    partitioned_data = tf.dynamic_partition(\n        x, tf.cast(condition_mask, tf.int32) , 2)\n    partitioned_data[1] = partitioned_data[1] + 1.0\n    condition_indices = tf.dynamic_partition(\n        tf.range(tf.shape(x)[0]), tf.cast(condition_mask, tf.int32) , 2)\n    x = tf.dynamic_stitch(condition_indices, partitioned_data)\n    # Here x=[1.1, -1., 6.2, 5.3, -1, 8.4], the -1. values remain\n    # unchanged.\n```\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/DynamicStitch.png\" alt>\n</div>"
 }
+op {
+  name: "ParallelInterleaveDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  input_arg {
+    name: "cycle_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "block_length"
+    type: DT_INT64
+  }
+  input_arg {
+    name: "sloppy"
+    type: DT_BOOL
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+    description: "A function mapping elements of `input_dataset`, concatenated with\n`other_arguments`, to a Dataset variant that contains elements matching\n`output_types` and `output_shapes`."
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+  description: "The resulting dataset is similar to the `InterleaveDataset`, with the exception\nthat if retrieving the next value from a dataset would cause the requester to\nblock, it will skip that input dataset. This dataset is especially useful\nwhen loading data from a variable-latency datastores (e.g. HDFS, GCS), as it\nallows the training step to proceed so long as some data is available.\n\n!! WARNING !! This dataset is not deterministic!"
+}
 op {
   name: "ParallelMapDataset"
   input_arg {
@@ -23850,6 +24114,53 @@ op {
   summary: "Outputs a `Summary` protocol buffer with scalar values."
   description: "The input `tags` and `values` must have the same shape.  The generated summary\nhas a summary value for each tag-value pair in `tags` and `values`."
 }
+op {
+  name: "ScanDataset"
+  input_arg {
+    name: "input_dataset"
+    type: DT_VARIANT
+  }
+  input_arg {
+    name: "initial_state"
+    type_list_attr: "Tstate"
+  }
+  input_arg {
+    name: "other_arguments"
+    type_list_attr: "Targuments"
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "f"
+    type: "func"
+  }
+  attr {
+    name: "Tstate"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "Targuments"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset successively reduces `f` over the elements of `input_dataset`."
+}
 op {
   name: "ScatterAdd"
   input_arg {
@@ -25044,6 +25355,21 @@ op {
   }
   summary: "Computes gradients for the scaled exponential linear (Selu) operation."
 }
+op {
+  name: "SerializeIterator"
+  input_arg {
+    name: "resource_handle"
+    description: "A handle to an iterator resource."
+    type: DT_RESOURCE
+  }
+  output_arg {
+    name: "serialized"
+    description: "A variant tensor storing the state of the iterator contained in the\nresource."
+    type: DT_VARIANT
+  }
+  summary: "Converts the given `resource_handle` representing an iterator to a variant tensor."
+  is_stateful: true
+}
 op {
   name: "SerializeManySparse"
   input_arg {
@@ -28954,6 +29280,42 @@ op {
   }
   summary: "Splits a tensor into `num_split` tensors along one dimension."
 }
+op {
+  name: "SqlDataset"
+  input_arg {
+    name: "driver_name"
+    description: "The database type. Currently, the only supported type is \'sqlite\'."
+    type: DT_STRING
+  }
+  input_arg {
+    name: "data_source_name"
+    description: "A connection string to connect to the database."
+    type: DT_STRING
+  }
+  input_arg {
+    name: "query"
+    description: "A SQL query to execute."
+    type: DT_STRING
+  }
+  output_arg {
+    name: "handle"
+    type: DT_VARIANT
+  }
+  attr {
+    name: "output_types"
+    type: "list(type)"
+    has_minimum: true
+    minimum: 1
+  }
+  attr {
+    name: "output_shapes"
+    type: "list(shape)"
+    has_minimum: true
+    minimum: 1
+  }
+  summary: "Creates a dataset that executes a SQL query and emits rows of the result set."
+  is_stateful: true
+}
 op {
   name: "Sqrt"
   input_arg {

From f901742e656c9959e7d8a82d5713d24d96122058 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 02:25:01 -0800
Subject: [PATCH 024/115] Also register string types if __ANDROID_TYPES_FULL__
 is defined

PiperOrigin-RevId: 174979678
---
 tensorflow/contrib/makefile/tf_op_files.txt | 18 +++++++++++++++
 tensorflow/core/framework/register_types.h  |  5 +++--
 tensorflow/core/kernels/BUILD               | 25 +++++++++++++++++++++
 tensorflow/core/kernels/concat_lib_cpu.cc   |  9 +++++---
 4 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 8b77c99cb57..5f06106c1dc 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -8,6 +8,7 @@ tensorflow/core/kernels/xent_op.cc
 tensorflow/core/kernels/where_op.cc
 tensorflow/core/kernels/variable_ops.cc
 tensorflow/core/kernels/unpack_op.cc
+tensorflow/core/kernels/unique_op.cc
 tensorflow/core/kernels/transpose_op.cc
 tensorflow/core/kernels/transpose_functor_cpu.cc
 tensorflow/core/kernels/training_op_helpers.cc
@@ -41,6 +42,9 @@ tensorflow/core/kernels/spectrogram_op.cc
 tensorflow/core/kernels/spectrogram.cc
 tensorflow/core/kernels/sparse_to_dense_op.cc
 tensorflow/core/kernels/sparse_matmul_op.cc
+tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+tensorflow/core/kernels/sparse_reshape_op.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
 tensorflow/core/kernels/softsign_op.cc
 tensorflow/core/kernels/softplus_op.cc
 tensorflow/core/kernels/softmax_op.cc
@@ -109,6 +113,10 @@ tensorflow/core/kernels/maxpooling_op.cc
 tensorflow/core/kernels/matmul_op.cc
 tensorflow/core/kernels/lrn_op.cc
 tensorflow/core/kernels/logging_ops.cc
+tensorflow/core/kernels/initializable_lookup_table.cc
+tensorflow/core/kernels/lookup_table_init_op.cc
+tensorflow/core/kernels/lookup_table_op.cc
+tensorflow/core/kernels/lookup_util.cc
 tensorflow/core/kernels/inplace_ops.cc
 tensorflow/core/kernels/in_topk_op.cc
 tensorflow/core/kernels/immutable_constant_op.cc
@@ -116,10 +124,18 @@ tensorflow/core/kernels/identity_op.cc
 tensorflow/core/kernels/identity_n_op.cc
 tensorflow/core/kernels/gather_op.cc
 tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_nd_op.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
 tensorflow/core/kernels/fused_batch_norm_op.cc
 tensorflow/core/kernels/function_ops.cc
 tensorflow/core/kernels/fill_functor.cc
 tensorflow/core/kernels/fifo_queue.cc
+tensorflow/core/kernels/fifo_queue_op.cc
 tensorflow/core/kernels/fake_quant_ops.cc
 tensorflow/core/kernels/example_parsing_ops.cc
 tensorflow/core/kernels/encode_wav_op.cc
@@ -166,6 +182,8 @@ tensorflow/core/kernels/cwise_op_floor.cc
 tensorflow/core/kernels/cwise_op_exp.cc
 tensorflow/core/kernels/cwise_op_equal_to_2.cc
 tensorflow/core/kernels/cwise_op_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
 tensorflow/core/kernels/cwise_op_div.cc
 tensorflow/core/kernels/cwise_op_bitwise_xor.cc
 tensorflow/core/kernels/cwise_op_bitwise_or.cc
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index c31ab18cc12..4bb37e4f6ed 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -87,7 +87,8 @@ limitations under the License.
 
 #elif defined(__ANDROID_TYPES_FULL__)
 
-// Only half, float, int32, int64, bool, and quantized types are supported.
+// Only string, half, float, int32, int64, bool, and quantized types
+// are supported.
 #define TF_CALL_float(m) m(float)
 #define TF_CALL_double(m)
 #define TF_CALL_int32(m) m(::tensorflow::int32)
@@ -96,7 +97,7 @@ limitations under the License.
 #define TF_CALL_int16(m)
 
 #define TF_CALL_int8(m)
-#define TF_CALL_string(m)
+#define TF_CALL_string(m) m(string)
 #define TF_CALL_resource(m)
 #define TF_CALL_variant(m)
 #define TF_CALL_complex64(m)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 34cd51ba66e..6206963251f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4420,6 +4420,15 @@ filegroup(
         "fill_functor.h",
         "function_ops.cc",
         "gather_functor.h",
+        "gather_nd_op.cc",
+        "gather_nd_op.h",
+        "gather_nd_op_cpu_impl.h",
+        "gather_nd_op_cpu_impl_0.cc",
+        "gather_nd_op_cpu_impl_1.cc",
+        "gather_nd_op_cpu_impl_2.cc",
+        "gather_nd_op_cpu_impl_3.cc",
+        "gather_nd_op_cpu_impl_4.cc",
+        "gather_nd_op_cpu_impl_5.cc",
         "gather_op.cc",
         "identity_n_op.cc",
         "identity_n_op.h",
@@ -4513,6 +4522,10 @@ filegroup(
         "fused_batch_norm_op.h",
         "gemm_functors.h",
         "image_resizer_state.h",
+        "initializable_lookup_table.h",
+        "lookup_table_init_op.h",
+        "lookup_table_op.h",
+        "lookup_util.h",
         "maxpooling_op.h",
         "mfcc.h",
         "mfcc_dct.h",
@@ -4529,6 +4542,7 @@ filegroup(
         "resize_nearest_neighbor_op.h",
         "reverse_op.h",
         "save_restore_tensor.h",
+        "segment_reduction_ops.h",
         "softplus_op.h",
         "softsign_op.h",
         "spacetobatch_functor.h",
@@ -4578,6 +4592,8 @@ filegroup(
         "cwise_op_div.cc",
         "cwise_op_equal_to_1.cc",
         "cwise_op_equal_to_2.cc",
+        "cwise_op_not_equal_to_1.cc",
+        "cwise_op_not_equal_to_2.cc",
         "cwise_op_exp.cc",
         "cwise_op_floor.cc",
         "cwise_op_floor_div.cc",
@@ -4619,6 +4635,7 @@ filegroup(
         "encode_wav_op.cc",
         "fake_quant_ops.cc",
         "fifo_queue.cc",
+        "fifo_queue_op.cc",
         "fused_batch_norm_op.cc",
         "population_count_op.cc",
         "population_count_op.h",
@@ -4642,7 +4659,11 @@ filegroup(
         "depthtospace_op.cc",
         "dynamic_stitch_op.cc",
         "in_topk_op.cc",
+        "initializable_lookup_table.cc",
         "logging_ops.cc",
+        "lookup_table_init_op.cc",
+        "lookup_table_op.cc",
+        "lookup_util.cc",
         "lrn_op.cc",
         "maxpooling_op.cc",
         "mfcc.cc",
@@ -4677,12 +4698,15 @@ filegroup(
         "save_op.cc",
         "save_restore_tensor.cc",
         "save_restore_v2_ops.cc",
+        "segment_reduction_ops.cc",
         "session_ops.cc",
         "softplus_op.cc",
         "softsign_op.cc",
         "spacetobatch_functor.cc",
         "spacetobatch_op.cc",
         "spacetodepth_op.cc",
+        "sparse_fill_empty_rows_op.cc",
+        "sparse_reshape_op.cc",
         "sparse_to_dense_op.cc",
         "spectrogram.cc",
         "spectrogram_op.cc",
@@ -4705,6 +4729,7 @@ filegroup(
         "training_ops.cc",
         "transpose_functor_cpu.cc",
         "transpose_op.cc",
+        "unique_op.cc",
         "warn_about_ints.cc",
         "where_op.cc",
         "xent_op.cc",
diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc
index 258ce154560..b0bec0c5dcd 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.cc
+++ b/tensorflow/core/kernels/concat_lib_cpu.cc
@@ -74,11 +74,14 @@ REGISTER(qint16)
 REGISTER(qint32)
 REGISTER(bfloat16)
 
-#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
-// Primarily used for SavedModel support on mobile.
+#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION) && \
+    !defined(__ANDROID_TYPES_FULL__)
+// Primarily used for SavedModel support on mobile. Registered here only if
+// __ANDROID_TYPES_FULL__ is not defined, as that already registers strings.
 REGISTER(string);
 #endif  // defined(IS_MOBILE_PLATFORM) &&
-        // !defined(SUPPORT_SELECTIVE_REGISTRATION)
+        // !defined(SUPPORT_SELECTIVE_REGISTRATION) &&
+        // !defined(__ANDROID_TYPES_FULL__)
 
 #ifdef TENSORFLOW_USE_SYCL
 template <typename T>

From af9c4ea6be5589cad66b8cb1159a58d7ec19ca7e Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 03:09:46 -0800
Subject: [PATCH 025/115] Check GPU availability after creating test session.

PiperOrigin-RevId: 174983466
---
 tensorflow/contrib/nccl/python/ops/nccl_ops_test.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
index 0b13e3595e3..bad0abd44cc 100644
--- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
+++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py
@@ -72,14 +72,15 @@ class NcclTestCase(test.TestCase):
           two.
       device_sets: Tuple of virtual devices to run test on.
     """
-    if not test.is_gpu_available():
-      return  # Test requires access to a GPU
-
     for dtype in [np.float32, np.int32, np.int64, np.float64]:
       # Create session inside outer loop to test use of
       # same communicator across multiple sessions.
       with self.test_session(use_gpu=True) as sess:
 
+        # Check GPU availability *after* creating test session, see b/68975239.
+        if not test.is_gpu_available():
+          return  # Test requires access to a GPU
+
         for devices in device_sets:
           shape = (3, 4)
           random = (np.random.random_sample(shape) - .5) * 1024

From 59ea341a1fd0a4badc6c3cfec7a578195a3bf623 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Wed, 8 Nov 2017 06:48:29 -0800
Subject: [PATCH 026/115] tfdbg: Add test for loading DebugDumpDir with a
 relative path

PiperOrigin-RevId: 174999937
---
 .../python/debug/wrappers/dumping_wrapper_test.py  | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index d987ba84b55..eda5ecf5087 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -111,6 +111,20 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
     self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
     self.assertEqual(repr(None), dump.run_feed_keys_info)
 
+  def testDumpingOnASingleRunWorksWithRelativePathForDebugDumpDir(self):
+    sess = dumping_wrapper.DumpingDebugWrapperSession(
+        self.sess, session_root=self.session_root, log_usage=False)
+    sess.run(self.inc_v)
+    dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
+    cwd = os.getcwd()
+    try:
+      os.chdir(self.session_root)
+      dump = debug_data.DebugDumpDir(
+          os.path.relpath(dump_dirs[0], self.session_root))
+      self.assertAllClose([10.0], dump.get_tensors("v", 0, "DebugIdentity"))
+    finally:
+      os.chdir(cwd)
+
   def testDumpingOnASingleRunWithFeedDictWorks(self):
     sess = dumping_wrapper.DumpingDebugWrapperSession(
         self.sess, session_root=self.session_root, log_usage=False)

From ac0ba5bd041f3287bb2a4f12c2ef43a3264f6073 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Wed, 8 Nov 2017 07:36:56 -0800
Subject: [PATCH 027/115] tfdbg: Fix a test bug hidden in a child thread

PiperOrigin-RevId: 175004323
---
 tensorflow/python/debug/wrappers/dumping_wrapper_test.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index eda5ecf5087..acea9433e22 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -364,12 +364,14 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
         thread_name_filter=r"MainThread$")
 
     self.assertAllClose(1.0, sess.run(self.delta))
+    child_thread_result = []
     def child_thread_job():
-      sess.run(sess.run(self.eta))
+      child_thread_result.append(sess.run(self.eta))
 
     thread = threading.Thread(name="ChildThread", target=child_thread_job)
     thread.start()
     thread.join()
+    self.assertAllClose([-1.4], child_thread_result)
 
     dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
     self.assertEqual(1, len(dump_dirs))

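The fix works because an exception raised in a child thread dies with that
thread and never fails the test; collecting the result and asserting on the
main thread does. A generic sketch of the pattern (compute() and EXPECTED
are hypothetical stand-ins):

    import threading

    results = []

    def job():
        # Only collect here; a failed assertion inside the thread would
        # be swallowed, so the main thread does the checking.
        results.append(compute())

    thread = threading.Thread(name="ChildThread", target=job)
    thread.start()
    thread.join()
    assert results == [EXPECTED]
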
From 500e0aa5eb3fb0ed08b717fc34b8a8f2a2bd0907 Mon Sep 17 00:00:00 2001
From: Igor Ganichev <iga@google.com>
Date: Wed, 8 Nov 2017 10:09:23 -0800
Subject: [PATCH 028/115] Fix incomplete spec of EagerTensor.numpy()

PiperOrigin-RevId: 175023039
---
 tensorflow/python/framework/ops.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 6469aca3ec2..b256af2182a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -617,15 +617,16 @@ class _EagerTensorBase(Tensor):
     return dtypes._INTERN_TABLE[self._datatype_enum()]  # pylint: disable=protected-access
 
   def numpy(self):
-    """Returns a numpy array with the same contents as the Tensor.
+    """Returns a numpy array or a scalar with the same contents as the Tensor.
 
     TODO(ashankar,agarwal): Perhaps this should NOT reference the underlying
     buffer but instead always explicitly copy? Note that currently it may or may
     not copy based on whether the numpy data is properly aligned or not.
 
     Returns:
-      A numpy array that may share memory with the Tensor object. Any changes
-      to one may be reflected in the other.
+      A numpy array or a scalar. The numpy array may share memory with the
+      Tensor object. Any changes to one may be reflected in the other. A
+      scalar value is returned when self has rank 0.
 
     Raises:
       ValueError: if the type of this Tensor is not representable in numpy.

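A quick sketch of the distinction the docstring now draws, assuming eager
execution is enabled through contrib.eager as was done in this era:

    import numpy as np
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()

    vector = tf.constant([1.0, 2.0])
    assert isinstance(vector.numpy(), np.ndarray)  # rank >= 1: an ndarray
                                                   # that may alias the buffer

    scalar = tf.constant(3.0)
    assert np.isscalar(scalar.numpy())             # rank 0: a plain scalar
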
From 0211cb2f83b620ff899b6876e6e11ac08bc853b2 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 10:19:01 -0800
Subject: [PATCH 029/115] Do not return a mutable HloComputation* from an
 entry_computation() on a const HloModule*.

PiperOrigin-RevId: 175024608
---
 tensorflow/compiler/xla/service/buffer_assignment.cc   | 10 +++++-----
 tensorflow/compiler/xla/service/hlo_module.h           |  6 +++++-
 .../compiler/xla/service/interpreter/executable.cc     |  2 +-
 3 files changed, 11 insertions(+), 7 deletions(-)

diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index b422b22df9c..c74f050f775 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -497,19 +497,19 @@ Status GatherComputationsByAllocationType(
     std::vector<const HloComputation*>* global_computations) {
   // Create a worklist of computations paired with whether the allocation must
   // be thread-local.
-  std::deque<std::pair<HloComputation*, bool>> worklist;
+  std::deque<std::pair<const HloComputation*, bool>> worklist;
   worklist.push_back(std::make_pair(module->entry_computation(),
                                     /*is_thread_local*/ false));
 
   // Sets for quickly checking membership. Computations are returned in vectors
   // for stable iteration.
-  FlatSet<HloComputation*> thread_local_set;
-  FlatSet<HloComputation*> global_set;
+  FlatSet<const HloComputation*> thread_local_set;
+  FlatSet<const HloComputation*> global_set;
 
   while (!worklist.empty()) {
     auto worklist_front = worklist.front();
     worklist.pop_front();
-    HloComputation* computation = worklist_front.first;
+    const HloComputation* computation = worklist_front.first;
     bool is_thread_local = worklist_front.second;
     bool in_thread_local_set = thread_local_set.count(computation) > 0;
     bool in_global_set = global_set.count(computation) > 0;
@@ -653,7 +653,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
   }
 
   if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
-    HloComputation* entry_computation =
+    const HloComputation* entry_computation =
         assignment->module_->entry_computation();
     for (auto param : entry_computation->parameter_instructions()) {
       for (auto& param_buffer :
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
index 6469851791d..5141e7bc8d4 100644
--- a/tensorflow/compiler/xla/service/hlo_module.h
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -85,7 +85,11 @@ class HloModule {
   std::unique_ptr<HloModule> Clone(const string& suffix = "clone") const;
 
   // Return a pointer to the entry computation of the module.
-  HloComputation* entry_computation() const {
+  const HloComputation* entry_computation() const {
+    CHECK_NE(nullptr, entry_computation_);
+    return entry_computation_;
+  }
+  HloComputation* entry_computation() {
     CHECK_NE(nullptr, entry_computation_);
     return entry_computation_;
   }
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 86dee8462fd..96f937caf96 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -89,7 +89,7 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream(
 
   uint64 start_micros = tensorflow::Env::Default()->NowMicros();
 
-  HloComputation* computation = module().entry_computation();
+  const HloComputation* computation = module().entry_computation();
   if (computation->num_parameters() != arguments.size()) {
     return tensorflow::errors::Internal(
         "Mismatch between argument count and graph parameter count.");

From a2853e3011bfae8a75fc04ed1ef2f4ff0fd7cf59 Mon Sep 17 00:00:00 2001
From: Thomas Schumm <fwiffo@google.com>
Date: Wed, 8 Nov 2017 10:43:48 -0800
Subject: [PATCH 030/115] HParams.set_hparam doesn't fully check types,
 contrary to its docstring.

PiperOrigin-RevId: 175028981
---
 .../training/python/training/hparam.py        | 58 +++++++++++++++++--
 .../training/python/training/hparam_test.py   | 31 +++++++++-
 2 files changed, 84 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py
index 391899b34f9..7db625cdd59 100644
--- a/tensorflow/contrib/training/python/training/hparam.py
+++ b/tensorflow/contrib/training/python/training/hparam.py
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import json
+import numbers
 import re
 
 import six
@@ -76,7 +77,7 @@ def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
       function.
 
   Raises:
-    ValueError: If the name has already been sued.
+    ValueError: If the name has already been used.
   """
   try:
     parsed_value = parse_fn(m_dict['val'])
@@ -138,6 +139,54 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values,
     _parse_fail(name, var_type, m_dict['vals'], values)
 
 
+def _cast_to_type_if_compatible(name, param_type, value):
+  """Cast hparam to the provided type, if compatible.
+
+  Args:
+    name: Name of the hparam to be cast.
+    param_type: The type of the hparam.
+    value: The value to be cast, if compatible.
+
+  Returns:
+    The result of casting `value` to `param_type`.
+
+  Raises:
+    ValueError: If the type of `value` is not compatible with param_type.
+      * If `param_type` is a string type, but `value` is not.
+      * If `param_type` is a boolean, but `value` is not, or vice versa.
+      * If `param_type` is an integer type, but `value` is not.
+      * If `param_type` is a float type, but `value` is not a numeric type.
+  """
+  fail_msg = (
+      "Could not cast hparam '%s' of type '%s' from value %r" %
+      (name, param_type, value))
+
+  # Some callers use None, for which we can't do any casting/checking. :(
+  if issubclass(param_type, type(None)):
+    return value
+
+  # Avoid converting a non-string type to a string.
+  if (issubclass(param_type, (six.string_types, six.binary_type)) and
+      not isinstance(value, (six.string_types, six.binary_type))):
+    raise ValueError(fail_msg)
+
+  # Avoid converting a number or string type to a boolean or vice versa.
+  if issubclass(param_type, bool) != isinstance(value, bool):
+    raise ValueError(fail_msg)
+
+  # Avoid converting float to an integer (the reverse is fine).
+  if (issubclass(param_type, numbers.Integral) and
+      not isinstance(value, numbers.Integral)):
+    raise ValueError(fail_msg)
+
+  # Avoid converting a non-numeric type to a numeric type.
+  if (issubclass(param_type, numbers.Number) and
+      not isinstance(value, numbers.Number)):
+    raise ValueError(fail_msg)
+
+  return param_type(value)
+
+
 def parse_values(values, type_map):
   """Parses hyperparameter values from a string into a python map.
 
@@ -438,17 +487,18 @@ class HParams(object):
     Raises:
       ValueError: If there is a type mismatch.
     """
-    _, is_list = self._hparam_types[name]
+    param_type, is_list = self._hparam_types[name]
     if isinstance(value, list):
       if not is_list:
         raise ValueError(
             'Must not pass a list for single-valued parameter: %s' % name)
-      setattr(self, name, value)
+      setattr(self, name, [
+          _cast_to_type_if_compatible(name, param_type, v) for v in value])
     else:
       if is_list:
         raise ValueError(
             'Must pass a list for multi-valued parameter: %s.' % name)
-      setattr(self, name, value)
+      setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
 
   def parse(self, values):
     """Override hyperparameter values, parsing new values from a string.
diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py
index f54514cefd3..949c262f5bb 100644
--- a/tensorflow/contrib/training/python/training/hparam_test.py
+++ b/tensorflow/contrib/training/python/training/hparam_test.py
@@ -318,13 +318,42 @@ class HParamsTest(test.TestCase):
     self.assertEqual(3.0, hparams.b)
     self.assertEqual('relu4', hparams.c_c)
 
-  def testSetHParamTypeMismatch(self):
+  def testSetHParamListNonListMismatch(self):
     hparams = hparam.HParams(a=1, b=[2.0, 3.0])
     with self.assertRaisesRegexp(ValueError, r'Must not pass a list'):
       hparams.set_hparam('a', [1.0])
     with self.assertRaisesRegexp(ValueError, r'Must pass a list'):
       hparams.set_hparam('b', 1.0)
 
+  def testSetHParamTypeMismatch(self):
+    hparams = hparam.HParams(
+        int_=1, str_='str', bool_=True, float_=1.1, list_int=[1, 2], none=None)
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('str_', 2.2)
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('int_', False)
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('bool_', 1)
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('int_', 2.2)
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('list_int', [2, 3.3])
+
+    with self.assertRaises(ValueError):
+      hparams.set_hparam('int_', '2')
+
+    # Casting int to float is OK
+    hparams.set_hparam('float_', 1)
+
+    # Getting stuck with NoneType :(
+    hparams.set_hparam('none', '1')
+    self.assertEqual('1', hparams.none)
+
   def testNonProtoFails(self):
     with self.assertRaisesRegexp(AssertionError, ''):
       hparam.HParams(hparam_def=1)

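To illustrate the casting rules introduced above, a small sketch against
tf.contrib.training.HParams (the hyperparameter names are illustrative):

    from tensorflow.contrib.training import HParams

    hparams = HParams(learning_rate=0.1, num_layers=2)

    # Widening int -> float is compatible, so the value is cast.
    hparams.set_hparam('learning_rate', 1)
    assert hparams.learning_rate == 1.0

    # float -> int would lose information and now raises ValueError.
    try:
        hparams.set_hparam('num_layers', 2.5)
    except ValueError as err:
        print(err)  # Could not cast hparam 'num_layers' of type 'int' ...
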
From 2884da93d7da90f7532643a8e3f1fa0f2a1d6bbe Mon Sep 17 00:00:00 2001
From: Justin Lebar <jlebar@google.com>
Date: Wed, 8 Nov 2017 10:52:49 -0800
Subject: [PATCH 031/115] [XLA] Print constant literals of size <= 8 elements.

Previously we'd only print scalars.  But if you have a constant with
just a few values, what the heck, show the whole thing.

PiperOrigin-RevId: 175030210
---
 .../compiler/xla/service/hlo_graph_dumper.cc   | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index fd162622ce2..1c063c973dc 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -761,12 +761,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
 string HloDotDumper::GetInstructionNodeInlinedOperands(
     const HloInstruction* instr) {
   auto stringify_constant = [](const HloInstruction* constant) {
-    if (ShapeUtil::IsEffectiveScalar(constant->shape())) {
-      auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
-          constant->shape(), /*linear_index=*/0);
-      return Printf("%s (%s)", constant->literal().GetAsString(elem_idx),
+    const auto& shape = constant->shape();
+
+    // Print the literal value of constants with <= K elements.
+    optional<int64> elem_count;
+    if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
+      elem_count = 1;
+      for (int64 dim : shape.dimensions()) {
+        *elem_count *= dim;
+      }
+    }
+    if (elem_count.has_value() && *elem_count <= 8) {
+      return Printf("%s (%s)", constant->literal().ToString(),
                     ShapeUtil::HumanString(constant->shape()));
     }
+
+    // Otherwise, print e.g. "%constant.42 (s32[100])".
     string constant_name;
     if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) {
       constant_name = constant->name();

From 729c8c1bb36656c4528d7ff306fbbbd7856733ea Mon Sep 17 00:00:00 2001
From: Justine Tunney <jart@google.com>
Date: Wed, 8 Nov 2017 10:55:48 -0800
Subject: [PATCH 032/115] Add database writer ops to contrib/summary

PiperOrigin-RevId: 175030602
---
 .../contrib/cmake/tf_core_framework.cmake     |   3 +
 tensorflow/contrib/summary/BUILD              |   6 +
 tensorflow/contrib/summary/summary.py         |   2 +
 tensorflow/contrib/summary/summary_ops.py     | 125 ++++++++++++++++--
 .../contrib/summary/summary_ops_test.py       | 110 +++++++++++++++
 tensorflow/contrib/tensorboard/db/BUILD       |   2 +
 .../tensorboard/db/summary_db_writer.cc       |  34 ++++-
 .../tensorboard/db/summary_db_writer_test.cc  |  56 +++++++-
 tensorflow/core/kernels/BUILD                 |   3 +
 tensorflow/core/kernels/summary_interface.cc  |   4 +-
 tensorflow/core/kernels/summary_kernels.cc    |  50 +++++++
 tensorflow/core/ops/summary_ops.cc            |  41 ++++++
 12 files changed, 419 insertions(+), 17 deletions(-)

diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index c3dc8531bb9..c607546f4a5 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -301,6 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs
     "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc"
     "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc"
     "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h"
     "${tensorflow_source_dir}/public/*.h"
 )
 
@@ -314,6 +316,7 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs
     "${tensorflow_source_dir}/tensorflow/core/util/*test*.h"
     "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc"
     "${tensorflow_source_dir}/tensorflow/core/util/*main.cc"
+    "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc"
 )
 
 list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs})
diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD
index da23f1c3806..3c60d2bb565 100644
--- a/tensorflow/contrib/summary/BUILD
+++ b/tensorflow/contrib/summary/BUILD
@@ -26,12 +26,18 @@ py_test(
     deps = [
         ":summary_ops",
         ":summary_test_util",
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:errors",
+        "//tensorflow/python:framework",
         "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:ops",
         "//tensorflow/python:platform",
+        "//tensorflow/python:state_ops",
         "//tensorflow/python:training",
         "//tensorflow/python/eager:function",
         "//tensorflow/python/eager:test",
+        "@six_archive//:six",
     ],
 )
 
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index ca82ea094c4..813e8b2b09d 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -28,11 +28,13 @@ from __future__ import print_function
 from tensorflow.contrib.summary.summary_ops import all_summary_ops
 from tensorflow.contrib.summary.summary_ops import always_record_summaries
 from tensorflow.contrib.summary.summary_ops import audio
+from tensorflow.contrib.summary.summary_ops import create_summary_db_writer
 from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
 from tensorflow.contrib.summary.summary_ops import eval_dir
 from tensorflow.contrib.summary.summary_ops import generic
 from tensorflow.contrib.summary.summary_ops import histogram
 from tensorflow.contrib.summary.summary_ops import image
+from tensorflow.contrib.summary.summary_ops import import_event
 from tensorflow.contrib.summary.summary_ops import never_record_summaries
 from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps
 from tensorflow.contrib.summary.summary_ops import scalar
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 9238671c4a2..f6be99f6ae8 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -19,7 +19,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import getpass
 import os
+import re
+import time
+
+import six
 
 from tensorflow.contrib.summary import gen_summary_ops
 from tensorflow.python.eager import context
@@ -42,6 +47,10 @@ _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries"
 _SUMMARY_COLLECTION_NAME = "_SUMMARY_V2"
 _SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2"
 
+_EXPERIMENT_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,256}$")
+_RUN_NAME_PATTERNS = re.compile(r"^[^\x00-\x1F<>]{0,512}$")
+_USER_NAME_PATTERNS = re.compile(r"^[a-z]([-a-z0-9]{0,29}[a-z0-9])?$", re.I)
+
 
 def should_record_summaries():
   """Returns boolean Tensor which is true if summaries should be recorded."""
@@ -132,7 +141,8 @@ def create_summary_file_writer(logdir,
      flush once the queue gets bigger than this.
     flush_millis: the largest interval between flushes.
     filename_suffix: optional suffix for the event file name.
-    name: name for the summary writer.
+    name: Shared name for this SummaryWriter resource stored in the default
+      Graph.
 
   Returns:
     Either a summary writer or an empty object which can be used as a
@@ -147,14 +157,81 @@ def create_summary_file_writer(logdir,
       flush_millis = constant_op.constant(2 * 60 * 1000)
     if filename_suffix is None:
       filename_suffix = constant_op.constant("")
-    resource = gen_summary_ops.summary_writer(shared_name=name)
-    # TODO(apassos) ensure the initialization op runs when in graph mode;
-    # consider calling session.run here.
-    ops.add_to_collection(
-        _SUMMARY_WRITER_INIT_COLLECTION_NAME,
-        gen_summary_ops.create_summary_file_writer(
-            resource, logdir, max_queue, flush_millis, filename_suffix))
-    return SummaryWriter(resource)
+    return _make_summary_writer(
+        name,
+        gen_summary_ops.create_summary_file_writer,
+        logdir=logdir,
+        max_queue=max_queue,
+        flush_millis=flush_millis,
+        filename_suffix=filename_suffix)
+
+
+def create_summary_db_writer(db_uri,
+                             experiment_name=None,
+                             run_name=None,
+                             user_name=None,
+                             name=None):
+  """Creates a summary database writer in the current context.
+
+  This can be used to write tensors from the execution graph directly
+  to a database. Only SQLite is supported right now. This function
+  will create the schema if it doesn't exist. Entries in the Users,
+  Experiments, and Runs tables will be created automatically if they
+  don't already exist.
+
+  Args:
+    db_uri: For example "file:/tmp/foo.sqlite".
+    experiment_name: Defaults to YYYY-MM-DD in local time if None.
+      Empty string means the Run will not be associated with an
+      Experiment. Can't contain ASCII control characters or <>. Case
+      sensitive.
+    run_name: Defaults to HH:MM:SS in local time if None. Empty string
+      means a Tag will not be associated with any Run. Can't contain
+      ASCII control characters or <>. Case sensitive.
+    user_name: Defaults to system username if None. Empty means the
+      Experiment will not be associated with a User. Must be valid as
+      both a DNS label and Linux username.
+    name: Shared name for this SummaryWriter resource stored in the default
+      Graph.
+
+  Returns:
+    A new SummaryWriter instance.
+  """
+  with ops.device("cpu:0"):
+    if experiment_name is None:
+      experiment_name = time.strftime("%Y-%m-%d", time.localtime(time.time()))
+    if run_name is None:
+      run_name = time.strftime("%H:%M:%S", time.localtime(time.time()))
+    if user_name is None:
+      user_name = getpass.getuser()
+    experiment_name = _cleanse_string(
+        "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name)
+    run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name)
+    user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name)
+    return _make_summary_writer(
+        name,
+        gen_summary_ops.create_summary_db_writer,
+        db_uri=db_uri,
+        experiment_name=experiment_name,
+        run_name=run_name,
+        user_name=user_name)
+
+
+def _make_summary_writer(name, factory, **kwargs):
+  resource = gen_summary_ops.summary_writer(shared_name=name)
+  # TODO(apassos): Consider doing this instead.
+  # node = factory(resource, **kwargs)
+  # if not context.in_eager_mode():
+  #   ops.get_default_session().run(node)
+  ops.add_to_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME,
+                        factory(resource, **kwargs))
+  return SummaryWriter(resource)
+
+
+def _cleanse_string(name, pattern, value):
+  if isinstance(value, six.string_types) and pattern.search(value) is None:
+    raise ValueError("%s (%s) must match %s" % (name, value, pattern.pattern))
+  return ops.convert_to_tensor(value, dtypes.string)
 
 
 def _nothing():
@@ -206,16 +283,22 @@ def summary_writer_function(name, tensor, function, family=None):
   return op
 
 
-def generic(name, tensor, metadata, family=None, global_step=None):
+def generic(name, tensor, metadata=None, family=None, global_step=None):
   """Writes a tensor summary if possible."""
   if global_step is None:
     global_step = training_util.get_global_step()
   def function(tag, scope):
+    if metadata is None:
+      serialized_metadata = constant_op.constant("")
+    elif hasattr(metadata, "SerializeToString"):
+      serialized_metadata = constant_op.constant(metadata.SerializeToString())
+    else:
+      serialized_metadata = metadata
     # Note the identity to move the tensor to the CPU.
     return gen_summary_ops.write_summary(
         context.context().summary_writer_resource,
         global_step, array_ops.identity(tensor),
-        tag, metadata, name=scope)
+        tag, serialized_metadata, name=scope)
   return summary_writer_function(name, tensor, function, family=family)
 
 
@@ -284,6 +367,26 @@ def audio(name, tensor, sample_rate, max_outputs, family=None,
   return summary_writer_function(name, tensor, function, family=family)
 
 
+def import_event(tensor, name=None):
+  """Writes a tf.Event binary proto.
+
+  When using create_summary_db_writer(), this can be used alongside
+  tf.TFRecordReader to load event logs into the database. Please note
+  that this is lower level than the other summary functions and will
+  ignore any conditions set by methods like should_record_summaries().
+
+  Args:
+    tensor: A `Tensor` of type `string` containing a serialized `Event`
+      proto.
+    name: A name for the operation (optional).
+
+  Returns:
+    The created Operation.
+  """
+  return gen_summary_ops.import_event(
+      context.context().summary_writer_resource, tensor, name=name)
+
+
 def eval_dir(model_dir, name=None):
   """Construct a logdir for an eval summary writer."""
   return os.path.join(model_dir, "eval" if not name else "eval_" + name)
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 466e1940969..6e1a746815f 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -17,14 +17,22 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
+import os
 import tempfile
 
+import six
+import sqlite3
+
 from tensorflow.contrib.summary import summary_ops
 from tensorflow.contrib.summary import summary_test_util
 from tensorflow.python.eager import function
 from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.platform import gfile
 from tensorflow.python.training import training_util
 
@@ -99,5 +107,107 @@ class TargetTest(test_util.TensorFlowTestCase):
       self.assertEqual(len(events), 2)
       self.assertEqual(events[1].summary.value[0].tag, 'scalar')
 
+
+class DbTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
+    if os.path.exists(self.db_path):
+      os.unlink(self.db_path)
+    self.db = sqlite3.connect(self.db_path)
+    self.create_summary_db_writer = functools.partial(
+        summary_ops.create_summary_db_writer,
+        db_uri=self.db_path,
+        experiment_name='experiment',
+        run_name='run',
+        user_name='user')
+
+  def tearDown(self):
+    self.db.close()
+
+  def testIntegerSummaries(self):
+    step = training_util.create_global_step()
+
+    def adder(x, y):
+      state_ops.assign_add(step, 1)
+      summary_ops.generic('x', x)
+      summary_ops.generic('y', y)
+      sum_ = x + y
+      summary_ops.generic('sum', sum_)
+      return sum_
+
+    with summary_ops.always_record_summaries():
+      with self.create_summary_db_writer().as_default():
+        self.assertEqual(5, adder(int64(2), int64(3)).numpy())
+
+    six.assertCountEqual(self, [1, 1, 1],
+                         get_all(self.db, 'SELECT step FROM Tensors'))
+    six.assertCountEqual(self, ['x', 'y', 'sum'],
+                         get_all(self.db, 'SELECT tag_name FROM Tags'))
+    x_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "x"')
+    y_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "y"')
+    sum_id = get_one(self.db, 'SELECT tag_id FROM Tags WHERE tag_name = "sum"')
+
+    with summary_ops.always_record_summaries():
+      with self.create_summary_db_writer().as_default():
+        self.assertEqual(9, adder(int64(4), int64(5)).numpy())
+
+    six.assertCountEqual(self, [1, 1, 1, 2, 2, 2],
+                         get_all(self.db, 'SELECT step FROM Tensors'))
+    six.assertCountEqual(self, [x_id, y_id, sum_id],
+                         get_all(self.db, 'SELECT tag_id FROM Tags'))
+    self.assertEqual(2, get_tensor(self.db, x_id, 1))
+    self.assertEqual(3, get_tensor(self.db, y_id, 1))
+    self.assertEqual(5, get_tensor(self.db, sum_id, 1))
+    self.assertEqual(4, get_tensor(self.db, x_id, 2))
+    self.assertEqual(5, get_tensor(self.db, y_id, 2))
+    self.assertEqual(9, get_tensor(self.db, sum_id, 2))
+    six.assertCountEqual(
+        self, ['experiment'],
+        get_all(self.db, 'SELECT experiment_name FROM Experiments'))
+    six.assertCountEqual(self, ['run'],
+                         get_all(self.db, 'SELECT run_name FROM Runs'))
+    six.assertCountEqual(self, ['user'],
+                         get_all(self.db, 'SELECT user_name FROM Users'))
+
+  def testBadExperimentName(self):
+    with self.assertRaises(ValueError):
+      self.create_summary_db_writer(experiment_name='\0')
+
+  def testBadRunName(self):
+    with self.assertRaises(ValueError):
+      self.create_summary_db_writer(run_name='\0')
+
+  def testBadUserName(self):
+    with self.assertRaises(ValueError):
+      self.create_summary_db_writer(user_name='-hi')
+    with self.assertRaises(ValueError):
+      self.create_summary_db_writer(user_name='hi-')
+    with self.assertRaises(ValueError):
+      self.create_summary_db_writer(user_name='@')
+
+
+def get_one(db, q, *p):
+  return db.execute(q, p).fetchone()[0]
+
+
+def get_all(db, q, *p):
+  return unroll(db.execute(q, p).fetchall())
+
+
+def get_tensor(db, tag_id, step):
+  return get_one(
+      db, 'SELECT tensor FROM Tensors WHERE tag_id = ? AND step = ?', tag_id,
+      step)
+
+
+def int64(x):
+  return array_ops.constant(x, dtypes.int64)
+
+
+def unroll(list_of_tuples):
+  return sum(list_of_tuples, ())
+
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD
index d8bbf87d2ce..068e862650d 100644
--- a/tensorflow/contrib/tensorboard/db/BUILD
+++ b/tensorflow/contrib/tensorboard/db/BUILD
@@ -45,10 +45,12 @@ cc_library(
 
 tf_cc_test(
     name = "summary_db_writer_test",
+    size = "small",
     srcs = ["summary_db_writer_test.cc"],
     deps = [
         ":summary_db_writer",
         "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/lib/db:sqlite",
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
index df64e363055..a26ad616603 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc
@@ -15,10 +15,12 @@ limitations under the License.
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 
 #include "tensorflow/contrib/tensorboard/db/schema.h"
+#include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/lib/db/sqlite.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/snappy.h"
+#include "tensorflow/core/util/event.pb.h"
 
 namespace tensorflow {
 namespace {
@@ -86,13 +88,19 @@ class SummaryDbWriter : public SummaryWriterInterface {
         TF_RETURN_IF_ERROR(BindTensor(t));
         break;
     }
-    TF_RETURN_IF_ERROR(insert_tensor_.StepAndReset());
-    return Status::OK();
+    return insert_tensor_.StepAndReset();
   }
 
   Status WriteEvent(std::unique_ptr<Event> e) override {
-    // TODO(@jart): This will be used to load event logs.
-    return errors::Unimplemented("WriteEvent");
+    mutex_lock ml(mu_);
+    TF_RETURN_IF_ERROR(InitializeParents());
+    if (e->what_case() == Event::WhatCase::kSummary) {
+      const Summary& summary = e->summary();
+      for (int i = 0; i < summary.value_size(); ++i) {
+        TF_RETURN_IF_ERROR(WriteSummary(e.get(), summary.value(i)));
+      }
+    }
+    return Status::OK();
   }
 
   Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
@@ -247,6 +255,24 @@ class SummaryDbWriter : public SummaryWriterInterface {
     return Status::OK();
   }
 
+  Status WriteSummary(const Event* e, const Summary::Value& summary)
+      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    int64 tag_id;
+    TF_RETURN_IF_ERROR(GetTagId(run_id_, summary.tag(), &tag_id));
+    insert_tensor_.BindInt(1, tag_id);
+    insert_tensor_.BindInt(2, e->step());
+    insert_tensor_.BindDouble(3, e->wall_time());
+    switch (summary.value_case()) {
+      case Summary::Value::ValueCase::kSimpleValue:
+        insert_tensor_.BindDouble(4, summary.simple_value());
+        break;
+      default:
+        // TODO(@jart): Handle the rest.
+        return Status::OK();
+    }
+    return insert_tensor_.StepAndReset();
+  }
+
   mutex mu_;
   Env* env_;
   std::shared_ptr<Sqlite> db_ GUARDED_BY(mu_);
diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
index d32904f97c4..c1af51e7b7a 100644
--- a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
+++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc
@@ -14,14 +14,19 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 
+#include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/db/sqlite.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/event.pb.h"
 
 namespace tensorflow {
 namespace {
 
+const float kTolerance = 1e-5;
+
 Tensor MakeScalarInt64(int64 x) {
   Tensor t(DT_INT64, TensorShape({}));
   t.scalar<int64>()() = x;
@@ -41,7 +46,7 @@ class FakeClockEnv : public EnvWrapper {
 
 class SummaryDbWriterTest : public ::testing::Test {
  protected:
-  void SetUp() override { db_ = Sqlite::Open("file::memory:").ValueOrDie(); }
+  void SetUp() override { db_ = Sqlite::Open(":memory:").ValueOrDie(); }
 
   void TearDown() override {
     if (writer_ != nullptr) {
@@ -158,5 +163,54 @@ TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) {
       QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty());
 }
 
+TEST_F(SummaryDbWriterTest, EmptyParentNames_NoParentsCreated) {
+  TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+  TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy",
+                                    "this-is-metaaa"));
+  TF_ASSERT_OK(writer_->Flush());
+  ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users"));
+  ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments"));
+  ASSERT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs"));
+  ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags"));
+  ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+}
+
+TEST_F(SummaryDbWriterTest, WriteEvent_Scalar) {
+  TF_ASSERT_OK(CreateSummaryDbWriter(db_, "", "", "", &env_, &writer_));
+  std::unique_ptr<Event> e{new Event};
+  e->set_step(7);
+  e->set_wall_time(123.456);
+  Summary::Value* s = e->mutable_summary()->add_value();
+  s->set_tag("π");
+  s->set_simple_value(3.14f);
+  s = e->mutable_summary()->add_value();
+  s->set_tag("φ");
+  s->set_simple_value(1.61f);
+  TF_ASSERT_OK(writer_->WriteEvent(std::move(e)));
+  TF_ASSERT_OK(writer_->Flush());
+  ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tags"));
+  ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors"));
+  int64 tag1_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'π'");
+  int64 tag2_id = QueryInt("SELECT tag_id FROM Tags WHERE tag_name = 'φ'");
+  EXPECT_GT(tag1_id, 0LL);
+  EXPECT_GT(tag2_id, 0LL);
+  EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
+                         "SELECT computed_time FROM Tensors WHERE tag_id = ",
+                         tag1_id, " AND step = 7")));
+  EXPECT_EQ(123.456, QueryDouble(strings::StrCat(
+                         "SELECT computed_time FROM Tensors WHERE tag_id = ",
+                         tag2_id, " AND step = 7")));
+  EXPECT_NEAR(3.14,
+              QueryDouble(strings::StrCat(
+                  "SELECT tensor FROM Tensors WHERE tag_id = ", tag1_id,
+                  " AND step = 7")),
+              kTolerance);  // Summary::simple_value is float
+  EXPECT_NEAR(1.61,
+              QueryDouble(strings::StrCat(
+                  "SELECT tensor FROM Tensors WHERE tag_id = ", tag2_id,
+                  " AND step = 7")),
+              kTolerance);
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 6206963251f..4169e842da6 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6243,8 +6243,11 @@ tf_kernel_library(
     srcs = ["summary_kernels.cc"],
     deps = [
         ":summary_interface",
+        "//tensorflow/contrib/tensorboard/db:summary_db_writer",
         "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:summary_ops_op_lib",
+        "//tensorflow/core/lib/db:sqlite",
     ],
 )
 
diff --git a/tensorflow/core/kernels/summary_interface.cc b/tensorflow/core/kernels/summary_interface.cc
index 313137ae495..cd366f8c137 100644
--- a/tensorflow/core/kernels/summary_interface.cc
+++ b/tensorflow/core/kernels/summary_interface.cc
@@ -257,7 +257,9 @@ class SummaryWriterImpl : public SummaryWriterInterface {
     Summary::Value* v = e->mutable_summary()->add_value();
     t.AsProtoTensorContent(v->mutable_tensor());
     v->set_tag(tag);
-    v->mutable_metadata()->ParseFromString(serialized_metadata);
+    if (!serialized_metadata.empty()) {
+      v->mutable_metadata()->ParseFromString(serialized_metadata);
+    }
     return WriteEvent(std::move(e));
   }
 
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index cfa707de715..1fe2fc5b666 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -13,9 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/kernels/summary_interface.h"
+#include "tensorflow/core/lib/db/sqlite.h"
+#include "tensorflow/core/platform/protobuf.h"
 
 namespace tensorflow {
 
@@ -46,6 +49,32 @@ class CreateSummaryFileWriterOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("CreateSummaryFileWriter").Device(DEVICE_CPU),
                         CreateSummaryFileWriterOp);
 
+class CreateSummaryDbWriterOp : public OpKernel {
+ public:
+  explicit CreateSummaryDbWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor* tmp;
+    OP_REQUIRES_OK(ctx, ctx->input("db_uri", &tmp));
+    const string db_uri = tmp->scalar<string>()();
+    OP_REQUIRES_OK(ctx, ctx->input("experiment_name", &tmp));
+    const string experiment_name = tmp->scalar<string>()();
+    OP_REQUIRES_OK(ctx, ctx->input("run_name", &tmp));
+    const string run_name = tmp->scalar<string>()();
+    OP_REQUIRES_OK(ctx, ctx->input("user_name", &tmp));
+    const string user_name = tmp->scalar<string>()();
+    SummaryWriterInterface* s;
+    auto db = Sqlite::Open(db_uri);
+    OP_REQUIRES_OK(ctx, db.status());
+    OP_REQUIRES_OK(
+        ctx, CreateSummaryDbWriter(std::move(db.ValueOrDie()), experiment_name,
+                                   run_name, user_name, ctx->env(), &s));
+    OP_REQUIRES_OK(ctx, CreateResource(ctx, HandleFromInput(ctx, 0), s));
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("CreateSummaryDbWriter").Device(DEVICE_CPU),
+                        CreateSummaryDbWriterOp);
+
 class FlushSummaryWriterOp : public OpKernel {
  public:
   explicit FlushSummaryWriterOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -98,6 +127,27 @@ class WriteSummaryOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
                         WriteSummaryOp);
 
+class ImportEventOp : public OpKernel {
+ public:
+  explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    SummaryWriterInterface* s;
+    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
+    core::ScopedUnref unref(s);
+    const Tensor* t;
+    OP_REQUIRES_OK(ctx, ctx->input("event", &t));
+    std::unique_ptr<Event> event{new Event};
+    if (!ParseProtoUnlimited(event.get(), t->scalar<string>()())) {
+      ctx->CtxFailureWithWarning(
+          errors::DataLoss("Bad tf.Event binary proto tensor string"));
+      return;
+    }
+    OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
+  }
+};
+REGISTER_KERNEL_BUILDER(Name("ImportEvent").Device(DEVICE_CPU), ImportEventOp);
+
 class WriteScalarSummaryOp : public OpKernel {
  public:
   explicit WriteScalarSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
diff --git a/tensorflow/core/ops/summary_ops.cc b/tensorflow/core/ops/summary_ops.cc
index f778b487972..5efbac7ad76 100644
--- a/tensorflow/core/ops/summary_ops.cc
+++ b/tensorflow/core/ops/summary_ops.cc
@@ -49,6 +49,33 @@ flush_millis: How often, in milliseconds, to flush the pending events and
 filename_suffix: Every event file's name is suffixed with this suffix.
 )doc");
 
+REGISTER_OP("CreateSummaryDbWriter")
+    .Input("writer: resource")
+    .Input("db_uri: string")
+    .Input("experiment_name: string")
+    .Input("run_name: string")
+    .Input("user_name: string")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Creates a summary database writer accessible by the given resource handle.
+
+This can be used to write tensors from the execution graph directly
+to a database. Only SQLite is supported right now. This function
+will create the schema if it doesn't exist. Entries in the Users,
+Experiments, and Runs tables will be created automatically if they
+don't already exist.
+
+writer: Handle to SummaryWriter resource to overwrite.
+db_uri: For example "file:/tmp/foo.sqlite".
+experiment_name: Can't contain ASCII control characters or <>. Case
+  sensitive. If empty, then the Run will not be associated with any
+  Experiment.
+run_name: Can't contain ASCII control characters or <>. Case sensitive.
+  If empty, then each Tag will not be associated with any Run.
+user_name: Must be valid as both a DNS label and Linux username. If
+  empty, then the Experiment will not be associated with any User.
+)doc");
+
 REGISTER_OP("FlushSummaryWriter")
     .Input("writer: resource")
     .SetShapeFn(shape_inference::NoOutputs)
@@ -89,6 +116,20 @@ summary_metadata: Serialized SummaryMetadata protocol buffer containing
  plugin-related metadata for this summary.
 )doc");
 
+REGISTER_OP("ImportEvent")
+    .Input("writer: resource")
+    .Input("event: string")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Writes a `tf.Event` protocol buffer to a summary writer.
+
+When CreateSummaryDbWriter is being used, this op can be useful for
+importing data from event logs.
+
+writer: A handle to a summary writer.
+event: A string containing a binary-encoded tf.Event proto.
+)doc");
+
 REGISTER_OP("WriteScalarSummary")
     .Input("writer: resource")
     .Input("global_step: int64")

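End to end, the new ops let summaries land in SQLite instead of event files.
A minimal sketch assuming eager mode and the contrib.summary API added above
(the database path and experiment/run/user names are illustrative):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe
    from tensorflow.contrib import summary

    tfe.enable_eager_execution()
    step = tf.train.get_or_create_global_step()

    writer = summary.create_summary_db_writer(
        db_uri='/tmp/demo.sqlite',  # schema is created if it doesn't exist
        experiment_name='demo',
        run_name='run-1',
        user_name='alice')

    with writer.as_default(), summary.always_record_summaries():
        tf.assign_add(step, 1)
        summary.scalar('loss', 0.25)  # one row in Tensors, one tag in Tags
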
From 83d9635669c60fa75910999ceb0c18341a08843a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 11:27:04 -0800
Subject: [PATCH 033/115] Add comment describing how to get optimized builds in
 Dockerfile.

PiperOrigin-RevId: 175036186
---
 tensorflow/tools/docker/Dockerfile.devel | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index 20e1dcd0854..1a0145b0785 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -83,6 +83,11 @@ ENV CI_BUILD_PYTHON python
 
 RUN tensorflow/tools/ci_build/builds/configured CPU \
     bazel build -c opt --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" \
+        # For optimized builds appropriate for the hardware platform of your choosing, uncomment below...
+        # For ivy-bridge or sandy-bridge
+        # --copt=-march="ivybridge" \
+        # For haswell, broadwell, or skylake
+        # --copt=-march="haswell" \
         tensorflow/tools/pip_package:build_pip_package && \
     bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/pip && \
     pip --no-cache-dir install --upgrade /tmp/pip/tensorflow-*.whl && \

From 5856a5cef9cd9cfdf16add7024ba4910949c2604 Mon Sep 17 00:00:00 2001
From: Chris Leary <leary@google.com>
Date: Wed, 8 Nov 2017 11:28:21 -0800
Subject: [PATCH 034/115] [XLA] More diagnostic information in the reshape
 shape inference error.

PiperOrigin-RevId: 175036413
---
 tensorflow/compiler/xla/service/shape_inference.cc | 5 ++++-
 tensorflow/compiler/xla/tests/reshape_test.cc      | 5 +++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 9c7dc2185e3..dcd726f22c7 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1948,7 +1948,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
       !std::is_permutation(dimensions.begin(), dimensions.end(),
                            indices.begin())) {
     return InvalidArgument(
-        "Reshape dimensions not a permutation of the operand dimensions.");
+        "Reshape dimensions [%s] are not a permutation of the operand "
+        "dimensions (operand shape is %s).",
+        tensorflow::str_util::Join(dimensions, ",").c_str(),
+        ShapeUtil::HumanString(operand).c_str());
   }
 
   return inferred_shape;
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
index 72c68f24a0a..d235b9a1580 100644
--- a/tensorflow/compiler/xla/tests/reshape_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -431,8 +431,9 @@ XLA_TEST_F(ReshapeTest, ToScalar) {
 XLA_TEST_F(ReshapeTest, BadDimensions) {
   ComputationBuilder b(client_, TestName());
   b.Reshape(b.ConstantR1<int32>({1}), {}, {});
-  EXPECT_THAT(ExecuteToString(&b, {}),
-              ::testing::HasSubstr("dimensions not a permutation"));
+  EXPECT_THAT(
+      ExecuteToString(&b, {}),
+      ::testing::HasSubstr("not a permutation of the operand dimensions"));
 }
 
 XLA_TEST_F(ReshapeTest, BadNewSizes) {

From f95d6a01d341231d18bb969b12e615a9cb066e00 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 11:30:23 -0800
Subject: [PATCH 035/115] Minor docstring fixes

PiperOrigin-RevId: 175036743
---
 tensorflow/python/ops/ctc_ops.py | 30 +++++++++++++++++-------------
 1 file changed, 17 insertions(+), 13 deletions(-)

diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index 477c0d1cb49..f037767cf40 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -22,8 +22,8 @@ from __future__ import print_function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 
-from tensorflow.python.ops import gen_ctc_ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_ctc_ops
 from tensorflow.python.ops.nn_grad import _BroadcastMul
 
 
@@ -38,7 +38,8 @@ def ctc_loss(labels, inputs, sequence_length,
 
   [A. Graves, S. Fernandez, F. Gomez, J. Schmidhuber.
   Connectionist Temporal Classification: Labeling Unsegmented Sequence Data
-  with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA, pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
+  with Recurrent Neural Networks. ICML 2006, Pittsburgh, USA,
+  pp. 369-376.](http://www.cs.toronto.edu/~graves/icml_2006.pdf)
 
   Input requirements:
 
@@ -108,9 +109,9 @@ def ctc_loss(labels, inputs, sequence_length,
       See `core/ops/ctc_ops.cc` for more details.
     inputs: 3-D `float` `Tensor`.
       If time_major == False, this will be a `Tensor` shaped:
-        `[batch_size x max_time x num_classes]`.
+        `[batch_size, max_time, num_classes]`.
       If time_major == True (default), this will be a `Tensor` shaped:
-        `[max_time x batch_size x num_classes]`.
+        `[max_time, batch_size, num_classes]`.
       The logits.
     sequence_length: 1-D `int32` vector, size `[batch_size]`.
       The sequence lengths.
@@ -120,15 +121,18 @@ def ctc_loss(labels, inputs, sequence_length,
     ignore_longer_outputs_than_inputs: Boolean. Default: False.
       If True, sequences with longer outputs than inputs will be ignored.
     time_major: The shape format of the `inputs` Tensors.
-      If True, these `Tensors` must be shaped `[max_time, batch_size, num_classes]`.
-      If False, these `Tensors` must be shaped `[batch_size, max_time, num_classes]`.
-      Using `time_major = True` (default) is a bit more efficient because it avoids
-      transposes at the beginning of the ctc_loss calculation.  However, most
-      TensorFlow data is batch-major, so by this function also accepts inputs
-      in batch-major form.
+      If True, these `Tensors` must be shaped `[max_time, batch_size,
+      num_classes]`.
+      If False, these `Tensors` must be shaped `[batch_size, max_time,
+      num_classes]`.
+      Using `time_major = True` (default) is a bit more efficient because it
+      avoids transposes at the beginning of the ctc_loss calculation.  However,
+      most TensorFlow data is batch-major, so this function also accepts
+      inputs in batch-major form.
 
   Returns:
-    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log probabilities.
+    A 1-D `float` `Tensor`, size `[batch]`, containing the negative log
+      probabilities.
 
   Raises:
     TypeError: if labels is not a `SparseTensor`.
@@ -198,7 +202,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
 
   Args:
     inputs: 3-D `float` `Tensor` sized
-      `[max_time x batch_size x num_classes]`.  The logits.
+      `[max_time, batch_size, num_classes]`.  The logits.
     sequence_length: 1-D `int32` vector containing sequence lengths,
       having size `[batch_size]`.
     merge_repeated: Boolean.  Default: True.
@@ -207,7 +211,7 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
     A tuple `(decoded, neg_sum_logits)` where
     decoded: A single-element list. `decoded[0]`
       is a `SparseTensor` containing the decoded outputs s.t.:
-      `decoded.indices`: Indices matrix `(total_decoded_outputs x 2)`.
+      `decoded.indices`: Indices matrix `(total_decoded_outputs, 2)`.
         The rows store: `[batch, time]`.
       `decoded.values`: Values vector, size `(total_decoded_outputs)`.
         The vector stores the decoded classes.

From 9c9dbe9740cb3ec385a3c9c6eb0ec57229486e90 Mon Sep 17 00:00:00 2001
From: Sanjoy Das <sanjoy@google.com>
Date: Wed, 8 Nov 2017 11:32:03 -0800
Subject: [PATCH 036/115] [XLA:CPU] Implement single threaded Matrix-Vector
 products in LLVM IR

Right now we're always doing an 8x8 tiling on the matrix. This can probably be
tuned further.

There are some other follow-up items that I did not want to put in this already
large CL:

 - Eigen has some smarts to avoid issuing unaligned vector loads and stores,
   which the current CL does not have.  We need to investigate whether being
   smart about alignment is worth it.

 - Prevent LLVM from vectorizing the epilogue.  In fact we should disable loop
   vectorization for all the loops we've explicitly vectorized.

 - Cache the kernels by their shape to reduce code size impact.

 - Add aliasing information to the loads and stores emitted by the
   VectorSupportLibrary.  This is probably not super critical since we've
   already vectorized the code, but we should do this for completeness.

PiperOrigin-RevId: 175036991
---
 tensorflow/compiler/xla/service/cpu/BUILD     |   2 +
 .../xla/service/cpu/dot_op_emitter.cc         | 564 +++++++++++++++++-
 .../compiler/xla/service/cpu/dot_op_emitter.h |  28 +
 .../xla/service/cpu/ir_emission_utils.cc      |  17 +-
 .../xla/service/cpu/ir_emission_utils.h       |  11 +-
 .../xla/service/cpu/layout_assignment.cc      |   4 +-
 tensorflow/compiler/xla/service/llvm_ir/BUILD |  24 +
 .../service/llvm_ir/kernel_support_library.cc |  63 ++
 .../service/llvm_ir/kernel_support_library.h  | 124 ++++
 .../compiler/xla/service/llvm_ir/llvm_util.cc |   8 +
 .../compiler/xla/service/llvm_ir/llvm_util.h  |   2 +
 .../service/llvm_ir/vector_support_library.cc | 150 +++++
 .../service/llvm_ir/vector_support_library.h  | 174 ++++++
 .../compiler/xla/tests/dot_operation_test.cc  |  80 +++
 14 files changed, 1233 insertions(+), 18 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
 create mode 100644 tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
 create mode 100644 tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc
 create mode 100644 tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h

diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 6213baee2fa..10ec677e2f2 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -290,8 +290,10 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_module_config",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+        "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library",
         "//tensorflow/core:lib",
         "@llvm//:core",
     ],
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index e57d49172b1..1cbd4094a35 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -25,7 +25,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -38,6 +40,450 @@ using llvm_ir::SetToFirstInsertPoint;
 
 namespace cpu {
 
+namespace {
+// Loads a tile of values from a 2D tensor.
+class TileLoader {
+ public:
+  // Constructs a TileLoader that will load a tile consisting of
+  // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at
+  // `major_dim_offset` in the major dimension.  The tile size along the minor
+  // dimension is the vector size, and that is implicitly determined by `vsl`.
+  TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder,
+             llvm::Value* matrix, int64 matrix_size_along_minor_dim,
+             llvm::Value* major_dim_offset, int64 tile_size_along_major_dim)
+      : vsl_(vsl) {
+    pointers_.reserve(tile_size_along_major_dim);
+    for (int64 i = 0; i < tile_size_along_major_dim; i++) {
+      llvm::Value* total_offset = ir_builder->CreateMul(
+          ir_builder->getInt64(matrix_size_along_minor_dim),
+          ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset));
+      pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset));
+    }
+  }
+
+  // Load a tile consisting of `tile_size_along_major_dim` vectors starting at
+  // `major_dim_offset` in the major dimension and `minor_dim_offset` in the
+  // minor dimension.
+  std::vector<llvm::Value*> LoadTile(llvm::Value* minor_dim_offset) const {
+    std::vector<llvm::Value*> result;
+    result.reserve(pointers_.size());
+    for (const auto& pointer : pointers_) {
+      result.push_back(vsl_->LoadVector(pointer, minor_dim_offset));
+    }
+    return result;
+  }
+
+ private:
+  VectorSupportLibrary* vsl_;
+  std::vector<llvm::Value*> pointers_;
+};
+
+// Computes the dot product of an "[M,K]{0,1} lhs" with a [K,1] vector (the
+// layout of the vector does not matter).  This implementation uses a tiling
+// scheme to improve performance.
+//
+// We logically separate the LHS matrix into four segments:
+//
+//   +----------------------+---+
+//   |                      |   |
+//   |                      |   |
+//   |         A            | B |
+//   |                      |   |
+//   |                      |   |
+//   |                      |   |
+//   +----------------------+---+
+//   |         C            | D |
+//   +----------------------+---+
+//
+// where A is the largest submatrix of the LHS that can be evenly divided into
+// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
+//
+//   +---+---+---+---+       +--+--+--+--+
+//   |M00|M10|M20|M30|       |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M02|M12|M22|M32|       |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M03|M13|M23|M33|       |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//
+// (Legend: rows are horizontal and columns are vertical; and each column is one
+// llvm::Value of a vector type)
+//
+// where:
+//
+//   a. The left tile is from the column major left matrix.
+//   b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3]
+//      vector loaded from the RHS vector.
+//
+// As we iterate through the column dimension, we compute the change to the
+// result vector by an elementwise multiplication between the two tiles above
+// followed by a reduction along the major dimension:
+//
+//                     +-----------------------------------+
+//                     | M00*V0 + M10*V1 + M20*V2 + M30*V3 |
+//                     +-----------------------------------+
+//                     | M01*V0 + M11*V1 + M21*V2 + M31*V3 |
+// Result[R:R+4] +=    +-----------------------------------+
+//                     | M02*V0 + M12*V1 + M22*V2 + M32*V3 |
+//                     +-----------------------------------+
+//                     | M03*V0 + M13*V1 + M23*V2 + M33*V3 |
+//                     +-----------------------------------+
+//
+// Where R is the starting row for the tile.
+//
+// We have an inner epilogue loop to deal with the "C" submatrix and an outer
+// epilogue loop to deal with the B,D submatrix.
+//
+// TODO(sanjoy): We should investigate whether gather loads and scatter stores
+// can be used here to get the same inner loop for both column-major and
+// row-major matrix-vector products.
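+//
+// For exposition only, the emitted computation is equivalent to this
+// reference (untiled, unvectorized) loop nest over the column major LHS:
+//
+//   for (int64 col = 0; col < k; col++)
+//     for (int64 row = 0; row < m; row++)
+//       result[row] += lhs[col * m + row] * rhs[col];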
+class ColumnMajorMatrixVectorProductEmitter {
+ public:
+  ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type,
+                                        int64 tile_rows, int64 tile_cols,
+                                        int64 m, int64 k, llvm::Value* lhs,
+                                        llvm::Value* rhs, llvm::Value* result,
+                                        llvm::IRBuilder<>* ir_builder)
+      : scalar_type_(scalar_type),
+        tile_rows_(tile_rows),
+        tile_cols_(tile_cols),
+        m_(m),
+        k_(k),
+        lhs_(lhs),
+        rhs_(rhs),
+        result_(result),
+        ir_builder_(ir_builder),
+        ksl_(ir_builder_),
+        vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
+    CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_rows_)));
+  }
+
+  void Emit();
+
+ private:
+  void EmitOuterLoopBody(llvm::Value* column, int64 column_count,
+                         bool is_first_column);
+
+  TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) {
+    return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
+                      /*matrix_size_along_minor_dim=*/m_,
+                      /*major_dim_offset=*/column_start,
+                      /*tile_size_along_major_dim=*/column_count);
+  }
+
+  // Load a tile of values from the RHS.  For the RHS a "tile" is a contiguous
+  // sequence of `count` values, each one broadcasted to the vector width.
+  std::vector<llvm::Value*> LoadRhsTile(llvm::Value* offset, int64 count) {
+    llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset);
+    std::vector<llvm::Value*> result;
+    result.reserve(count);
+    for (int64 i = 0; i < count; i++) {
+      result.push_back(vsl_.LoadBroadcast(base_pointer, i));
+    }
+    return result;
+  }
+
+  void EmitInnerLoopTiled(TileLoader* lhs_tile_loader,
+                          const std::vector<llvm::Value*>& rhs_tile,
+                          int64 columns, bool is_first_column);
+
+  void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns,
+                             bool is_first_tiled_column);
+
+  PrimitiveType scalar_type_;
+  int64 tile_rows_;
+  int64 tile_cols_;
+  int64 m_;
+  int64 k_;
+  llvm::Value* lhs_;
+  llvm::Value* rhs_;
+  llvm::Value* result_;
+  llvm::IRBuilder<>* ir_builder_;
+  KernelSupportLibrary ksl_;
+  VectorSupportLibrary vsl_;
+};
+
+void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody(
+    llvm::Value* column, int64 column_count, bool is_first_column) {
+  TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column,
+                                                /*column_count=*/column_count);
+
+  std::vector<llvm::Value*> rhs_tile =
+      LoadRhsTile(column, /*count=*/column_count);
+  EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile,
+                     /*columns=*/column_count, is_first_column);
+  EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column);
+}
+
+void ColumnMajorMatrixVectorProductEmitter::Emit() {
+  // See the comment on the class declaration for the algorithm used here.
+  int64 column_remainder = k_ % tile_cols_;
+  int64 column_limit = k_ - column_remainder;
+
+  ksl_.For("dot.outer.tiled",
+           /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_,
+           [&](llvm::Value* column, bool is_first_column) {
+             EmitOuterLoopBody(column, tile_cols_, is_first_column);
+           });
+
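+  // Handle the straggler columns (the B and D submatrices in the diagram
+  // from the class comment) with one final, narrower outer-loop body.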
+  if (column_remainder != 0) {
+    EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder,
+                      column_limit == 0);
+  }
+}
+
+void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
+    TileLoader* lhs_tile_loader, const std::vector<llvm::Value*>& rhs_tile,
+    int64 columns, bool is_first_column) {
+  int64 row_limit = m_ - (m_ % tile_rows_);
+
+  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit,
+           /*step=*/tile_rows_, [&](llvm::Value* row) {
+             std::vector<llvm::Value*> lhs_tile =
+                 lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row);
+             llvm::Value* accumulator = is_first_column
+                                            ? vsl_.GetZeroVector()
+                                            : vsl_.LoadVector(result_, row);
+             for (int i = 0; i < columns; i++) {
+               accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator);
+             }
+             vsl_.StoreVector(accumulator, result_, row);
+           });
+}
+
+void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
+    llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) {
+  int64 row_start = m_ - (m_ % tile_rows_);
+  if (row_start == m_) {
+    return;
+  }
+
+  llvm::Value* columns_llvm = ir_builder_->getInt64(columns);
+
+  // for (col = current_tile_col; col < (columns + current_tile_col); col++)
+  //   for (row = row_start; row < m_; row++) {
+  //     result[row] += lhs[row, col] * rhs[col]
+  //     // Also take into account that if col is 0 then result[row] is not
+  //     // initialized.
+  //   }
+
+  ksl_.For(
+      "dot.inner.epilg.outer", /*start=*/current_tile_col,
+      /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col),
+      /*step=*/1, /*peel_first_iteration=*/false,
+      [&](llvm::Value* col, llvm::Value* is_first_scalar_col) {
+        llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col);
+        llvm::Value* total_offset =
+            ir_builder_->CreateMul(col, ir_builder_->getInt64(m_));
+        llvm::Value* lhs_base_pointer =
+            vsl_.ComputeOffsetPointer(lhs_, total_offset);
+        ksl_.For(
+            "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_,
+            /*step=*/1, [&](llvm::Value* scalar_row) {
+              llvm::Value* product = vsl_.Mul(
+                  vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element);
+              llvm::Value* setting_result_first_time = ir_builder_->CreateAnd(
+                  is_first_scalar_col,
+                  ir_builder_->getInt1(is_first_tiled_column));
+              ksl_.If(
+                  setting_result_first_time,
+                  [&]() { vsl_.StoreScalar(product, result_, scalar_row); },
+                  [&]() {
+                    vsl_.StoreScalar(
+                        vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product),
+                        result_, scalar_row);
+                  });
+            });
+      });
+}
+
+// Computes the dot product of an "[M,K]{1,0} lhs" with a [K,1] vector (the
+// layout of the vector does not matter).  This implementation uses a tiling
+// scheme to improve performance.
+//
+// We logically separate the LHS matrix into four segments:
+//
+//   +----------------------+---+
+//   |                      |   |
+//   |                      |   |
+//   |         A            | B |
+//   |                      |   |
+//   |                      |   |
+//   |                      |   |
+//   +----------------------+---+
+//   |         C            | D |
+//   +----------------------+---+
+//
+// where A is the largest submatrix of the LHS that can be evenly divided into
+// tiles.  For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have:
+//
+//   +---+---+---+---+
+//   |M00|M10|M20|M30|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M01|M11|M21|M31| and   |V0|V1|V2|V3|
+//   +---+---+---+---+       +--+--+--+--+
+//   |M02|M12|M22|M32|
+//   +---+---+---+---+
+//   |M03|M13|M23|M33|
+//   +---+---+---+---+
+//
+// (Legend: rows are horizontal and columns are vertical; and each row is one
+// llvm::Value of a vector type)
+//
+// where:
+//
+//   a. The left tile is loaded from the row major left matrix.
+//   b. The right vector is loaded from the RHS vector.
+//
+// We keep 4 vector accumulators accumulating the following four vector
+// expressions as we iterate over the row dimension:
+//
+//   +------+------+------+------+
+//   |M0I*V0|M1I*V1|M2I*V2|M3I*V3|  for I in [0,4)
+//   +------+------+------+------+
+//
+// In the end we do a horizontal reduction over these 4 vector accumulators to
+// get 4 values in the result vector.
+//
+// We have an inner epilogue loop to deal with the "B" submatrix and an outer
+// epilogue loop to deal with the C,D submatrix.
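+//
+// For exposition only, the emitted computation is equivalent to this
+// reference (untiled, unvectorized) loop nest over the row major LHS:
+//
+//   for (int64 row = 0; row < m; row++)
+//     for (int64 col = 0; col < k; col++)
+//       result[row] += lhs[row * k + col] * rhs[col];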
+class RowMajorMatrixVectorProductEmitter {
+ public:
+  RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows,
+                                     int64 tile_cols, int64 m, int64 k,
+                                     llvm::Value* lhs, llvm::Value* rhs,
+                                     llvm::Value* result,
+                                     llvm::IRBuilder<>* ir_builder)
+      : scalar_type_(scalar_type),
+        tile_rows_(tile_rows),
+        tile_cols_(tile_cols),
+        m_(m),
+        k_(k),
+        lhs_(lhs),
+        rhs_(rhs),
+        result_(result),
+        ir_builder_(ir_builder),
+        ksl_(ir_builder_),
+        vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
+    CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
+  }
+
+  void Emit();
+
+ private:
+  TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) {
+    return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_,
+                      /*matrix_size_along_minor_dim=*/k_,
+                      /*major_dim_offset=*/row_start,
+                      /*tile_size_along_major_dim=*/row_count);
+  }
+
+  void EmitOuterLoopBody(llvm::Value* row, int64 row_count);
+
+  void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows,
+                          std::vector<VectorVariable>* vector_accumulators);
+
+  void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows,
+                             std::vector<ScalarVariable>* scalar_accumulators);
+
+  PrimitiveType scalar_type_;
+  int64 tile_rows_;
+  int64 tile_cols_;
+  int64 m_;
+  int64 k_;
+  llvm::Value* lhs_;
+  llvm::Value* rhs_;
+  llvm::Value* result_;
+  llvm::IRBuilder<>* ir_builder_;
+  KernelSupportLibrary ksl_;
+  VectorSupportLibrary vsl_;
+};
+
+void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row,
+                                                           int64 row_count) {
+  TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row,
+                                                /*row_count=*/row_count);
+  std::vector<VectorVariable> vector_accumulators;
+  std::vector<ScalarVariable> scalar_accumulators;
+  for (int i = 0; i < row_count; i++) {
+    vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector());
+    scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar());
+  }
+  EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count,
+                     &vector_accumulators);
+  EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count,
+                        &scalar_accumulators);
+
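+  // For each row in the tile, horizontally reduce its vector accumulator,
+  // fold in its scalar epilogue accumulator, and store the resulting scalar.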
+  for (int i = 0; i < row_count; i++) {
+    llvm::Value* result_value =
+        vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()),
+                 scalar_accumulators[i].Get());
+    llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row);
+    vsl_.StoreScalar(result_value, result_, offset);
+  }
+}
+
+void RowMajorMatrixVectorProductEmitter::Emit() {
+  // See the comment on the class declaration for the algorithm used here.
+  int64 row_remainder = m_ % tile_rows_;
+  int64 row_limit = m_ - row_remainder;
+
+  ksl_.For("dot.outer.tiled",
+           /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_,
+           [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); });
+
+  if (row_remainder != 0) {
+    EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder);
+  }
+}
+
+void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled(
+    TileLoader* lhs_tile_loader, int64 rows,
+    std::vector<VectorVariable>* vector_accumulators) {
+  int64 column_limit = k_ - (k_ % tile_cols_);
+
+  ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit,
+           /*step=*/tile_cols_, [&](llvm::Value* col) {
+             std::vector<llvm::Value*> lhs_tile =
+                 lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col);
+             llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col);
+             for (int i = 0; i < rows; i++) {
+               llvm::Value* old_sum = (*vector_accumulators)[i].Get();
+               (*vector_accumulators)[i].Set(
+                   vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i])));
+             }
+           });
+}
+
+void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue(
+    llvm::Value* current_tile_row, int64 rows,
+    std::vector<ScalarVariable>* scalar_accumulators) {
+  int64 column_start = k_ - (k_ % tile_cols_);
+  if (column_start == k_) {
+    return;
+  }
+
+  for (int r = 0; r < rows; r++) {
+    llvm::Value* total_offset = ir_builder_->CreateMul(
+        ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row),
+        ir_builder_->getInt64(k_));
+    llvm::Value* lhs_base_pointer =
+        vsl_.ComputeOffsetPointer(lhs_, total_offset);
+    ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_,
+             /*step=*/1, [&](llvm::Value* scalar_col) {
+               llvm::Value* product =
+                   vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col),
+                            vsl_.LoadScalar(rhs_, scalar_col));
+               llvm::Value* old_value = (*scalar_accumulators)[r].Get();
+               (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product));
+             });
+  }
+}
+
+}  // namespace
+
 DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
                            bool transpose_rhs,
                            const llvm_ir::IrArray& target_array,
@@ -72,6 +518,88 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
 
 bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }
 
+bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
+  if (dot_.shape().dimensions_size() != 2 ||
+      ProfitableToImplementDotInUntiledLlvmIr(dot_) ==
+          DotInLlvmIrProfitable::kYes) {
+    return false;
+  }
+
+  if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) &&
+      !primitive_util::IsIntegralType(dot_.shape().element_type())) {
+    return false;
+  }
+
+  MatMultDims mat_mult_dims = GetMatMultDims();
+  bool is_column_major_matrix_vector = false;
+  bool is_row_major_matrix_vector = false;
+
+  int64 m, k;
+  bool swap_operands;
+
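+  // When m == 1 the result is a [1,N] row vector; emitting its transpose,
+  // result^T = rhs^T * lhs^T, turns the dot into an [N,1] matrix-vector
+  // product.  Transposing swaps the operands and flips the matrix's effective
+  // layout, which is why the m == 1 case below swaps the operands and derives
+  // row/column majorness from the RHS.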
+  if (mat_mult_dims.m == 1) {
+    bool rhs_effectively_row_major =
+        transpose_rhs_ ^ !mat_mult_dims.rhs_column_major;
+    if (rhs_effectively_row_major) {
+      k = mat_mult_dims.k;
+      m = mat_mult_dims.n;
+      is_column_major_matrix_vector = true;
+      swap_operands = true;
+    } else {
+      k = mat_mult_dims.k;
+      m = mat_mult_dims.n;
+      is_row_major_matrix_vector = true;
+      swap_operands = true;
+    }
+  }
+
+  if (mat_mult_dims.n == 1) {
+    bool lhs_effectively_column_major =
+        transpose_lhs_ ^ mat_mult_dims.lhs_column_major;
+    if (lhs_effectively_column_major) {
+      m = mat_mult_dims.m;
+      k = mat_mult_dims.k;
+      is_column_major_matrix_vector = true;
+      swap_operands = false;
+    } else {
+      m = mat_mult_dims.m;
+      k = mat_mult_dims.k;
+      is_row_major_matrix_vector = true;
+      swap_operands = false;
+    }
+  }
+
+  if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) {
+    return false;
+  }
+
+  if (is_column_major_matrix_vector) {
+    VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
+            << " and k = " << k;
+    ColumnMajorMatrixVectorProductEmitter emitter(
+        dot_.shape().element_type(), 8, 8, m, k,
+        swap_operands ? rhs_array_.GetBasePointer()
+                      : lhs_array_.GetBasePointer(),
+        swap_operands ? lhs_array_.GetBasePointer()
+                      : rhs_array_.GetBasePointer(),
+        target_array_.GetBasePointer(), ir_builder_);
+    emitter.Emit();
+  } else {
+    VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
+            << " and k = " << k;
+    RowMajorMatrixVectorProductEmitter emitter(
+        dot_.shape().element_type(), 8, 8, m, k,
+        swap_operands ? rhs_array_.GetBasePointer()
+                      : lhs_array_.GetBasePointer(),
+        swap_operands ? lhs_array_.GetBasePointer()
+                      : rhs_array_.GetBasePointer(),
+        target_array_.GetBasePointer(), ir_builder_);
+    emitter.Emit();
+  }
+
+  return true;
+}
+
 tensorflow::Status DotOpEmitter::Emit() {
   // The dot operation performs a sum of products over dimension 0 of the left
   // hand side operand and dimension 1 of the right hand side operand.
@@ -105,6 +633,10 @@ tensorflow::Status DotOpEmitter::Emit() {
     return EmitScalarDot();
   }
 
+  if (EmitLlvmIrDotIfProfitable()) {
+    return Status::OK();
+  }
+
   if (PotentiallyImplementedAsEigenDot(dot_)) {
     return EmitCallToRuntime();
   }
@@ -340,22 +872,17 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
   //
   // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
 
-  const Shape& lhs_shape = lhs_array_.GetShape();
-  const Shape& rhs_shape = rhs_array_.GetShape();
+  MatMultDims mat_mult_dims = GetMatMultDims();
 
-  CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout()));
+  CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major);
 
-  int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0);
-  int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1);
-  int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1);
   const llvm_ir::IrArray* lhs = &lhs_array_;
   const llvm_ir::IrArray* rhs = &rhs_array_;
   bool transpose_lhs = transpose_lhs_;
   bool transpose_rhs = transpose_rhs_;
 
-  bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0;
-  if (!is_column_major) {
-    std::swap(m, n);
+  if (!mat_mult_dims.lhs_column_major) {
+    std::swap(mat_mult_dims.m, mat_mult_dims.n);
     std::swap(lhs, rhs);
     std::swap(transpose_lhs, transpose_rhs);
   }
@@ -367,12 +894,27 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
                                   float_ptr_type),
        ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
        ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
-       ir_builder_->getInt64(m), ir_builder_->getInt64(n),
-       ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs),
+       ir_builder_->getInt64(mat_mult_dims.m),
+       ir_builder_->getInt64(mat_mult_dims.n),
+       ir_builder_->getInt64(mat_mult_dims.k),
+       ir_builder_->getInt32(transpose_lhs),
        ir_builder_->getInt32(transpose_rhs)});
   return tensorflow::Status::OK();
 }
 
+DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const {
+  CHECK_EQ(dot_.shape().dimensions_size(), 2);
+
+  const Shape& lhs_shape = lhs_array_.GetShape();
+  const Shape& rhs_shape = rhs_array_.GetShape();
+
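+  // minor_to_major(0) == 0 means dimension 0 is the most minor dimension,
+  // i.e. each column is contiguous in memory and the matrix is column major.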
+  return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0),
+          lhs_shape.dimensions(transpose_lhs_ ? 0 : 1),
+          rhs_shape.dimensions(transpose_rhs_ ? 0 : 1),
+          lhs_shape.layout().minor_to_major(0) == 0,
+          rhs_shape.layout().minor_to_major(0) == 0};
+}
+
 llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
     llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
     int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index cfc10660453..182e1b8c680 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -59,6 +59,10 @@ class DotOpEmitter {
   // LHS and RHS) and store the results in the target.
   tensorflow::Status EmitScalarDot();
 
+  // Emit an LLVM IR implementation of the dot operation if we can.  Returns
+  // true if an LLVM IR implementation was emitted.
+  bool EmitLlvmIrDotIfProfitable();
+
   // Emits a call to the CPU runtime to perform the matrix multiply.
   tensorflow::Status EmitCallToRuntime();
 
@@ -77,6 +81,30 @@ class DotOpEmitter {
   // no padding, and a rank of two.
   bool ShapesAreLegalForRuntimeDot() const;
 
+  // Represents the dimensions of a matrix-matrix multiply operation.
+  struct MatMultDims {
+    // The number of rows in the LHS.
+    int64 m;
+
+    // The number of columns in the LHS, which must also be equal to the
+    // number of rows in the RHS.
+    int64 k;
+
+    // The number of columns in the RHS.
+    int64 n;
+
+    // True if the LHS matrix is column major.
+    bool lhs_column_major;
+
+    // True if the RHS matrix is column major.
+    bool rhs_column_major;
+  };
+
+  // Get the MatMultDims instance for the dot product this DotOpEmitter
+  // represents.  Precondition: the dot is of rank 2 (and thus its operands are
+  // of rank 2 as well).
+  MatMultDims GetMatMultDims() const;
+
   const HloInstruction& dot_;
   const bool transpose_lhs_;
   const bool transpose_rhs_;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index b99b36a55ee..7149a193107 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -105,7 +105,9 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
       return false;
     }
 
-    if (ProfitableToImplementDotInLlvmIr(hlo) == DotInLlvmIrProfitable::kYes) {
+    if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
+            DotInLlvmIrProfitable::kYes ||
+        ProfitableToImplementDotInTiledLlvmIr(hlo)) {
       return false;
     }
 
@@ -136,7 +138,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
   return false;
 }
 
-DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
+DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
     const HloInstruction& dot) {
   if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) {
     const Shape& result_shape = dot.shape();
@@ -178,5 +180,16 @@ DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
   return DotInLlvmIrProfitable::kNo;
 }
 
+bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
+  // Any matrix-vector product of floating point or integral type, or a
+  // transpose-dot fusion of the same, can be lowered to a tiled LLVM IR
+  // implementation.
+  const Shape& shape = dot.shape();
+  return shape.dimensions_size() == 2 &&
+         (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) &&
+         (primitive_util::IsFloatingPointType(shape.element_type()) ||
+          primitive_util::IsIntegralType(shape.element_type()));
+}
+
 }  // namespace cpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
index 66656ed9976..cbe07a7c2b9 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
@@ -29,16 +29,21 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot);
 enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
 
 // Returns a value to indicate whether (and under what conditions) lowering
-// |dot| as a pure LLVM IR dot operation be profitable over calling into Eigen.
-// Possible return values are:
+// |dot| as an untiled LLVM IR dot operation will be profitable over calling
+// into Eigen or emitting a tiled LLVM IR implementation.  Possible return
+// values are:
 //
 //  * DotInLlvmIrProfitable::kYes - always profitable.
 //  * DotInLlvmIrProfitable::kNo - never profitable.
 //  * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
 //    the Rhs layout column major.
-DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr(
+DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
     const HloInstruction& dot);
 
+// Returns true to indicate that we can generate a tiled LLVM IR implementation
+// for |dot|.
+bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot);
+
 }  // namespace cpu
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
index c446b6b792a..b75ca34e0a8 100644
--- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
@@ -51,7 +51,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
   tensorflow::gtl::FlatMap<const HloInstruction*, bool>
       should_make_rhs_col_major_cache;
   auto should_make_rhs_col_major = [&](const HloInstruction& instruction) {
-    if (ProfitableToImplementDotInLlvmIr(instruction) !=
+    if (ProfitableToImplementDotInUntiledLlvmIr(instruction) !=
         DotInLlvmIrProfitable::kWithColumnMajorRhs) {
       return false;
     }
@@ -68,7 +68,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
 
     bool result = std::all_of(
         rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) {
-          return ProfitableToImplementDotInLlvmIr(*user) ==
+          return ProfitableToImplementDotInUntiledLlvmIr(*user) ==
                      DotInLlvmIrProfitable::kWithColumnMajorRhs &&
                  user->operand(0) != rhs;
         });
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 075d4a1ab5e..8f24bb1718a 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -155,6 +155,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "vector_support_library",
+    srcs = ["vector_support_library.cc"],
+    hdrs = ["vector_support_library.h"],
+    deps = [
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "@llvm//:core",
+    ],
+)
+
+cc_library(
+    name = "kernel_support_library",
+    srcs = ["kernel_support_library.cc"],
+    hdrs = ["kernel_support_library.h"],
+    deps = [
+        ":llvm_loop",
+        "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+        "//tensorflow/core:lib",
+        "@llvm//:core",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
new file mode 100644
index 00000000000..123a327d4db
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+
+namespace xla {
+void KernelSupportLibrary::For(
+    tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+    llvm::Value* step,
+    const std::function<void(llvm::Value*, bool)>& for_body_generator) {
+  If(ir_builder_->CreateICmpSLT(start, end), [&]() {
+    for_body_generator(start, /*is_first_iteration=*/true);
+    For(name, ir_builder_->CreateAdd(start, step), end, step,
+        [&](llvm::Value* iv) { for_body_generator(iv, false); });
+  });
+}
+
+void KernelSupportLibrary::For(
+    tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+    llvm::Value* step, bool peel_first_iteration,
+    const std::function<void(llvm::Value*, llvm::Value*)>& for_body_generator) {
+  if (peel_first_iteration) {
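+    // Delegate to the overload taking a `(Value*, bool)` body, which peels
+    // the first iteration and passes the statically known
+    // `is_first_iteration` bit through as an i1 constant.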
+    For(name, start, end, step,
+        [&](llvm::Value* indvar, bool is_first_iteration) {
+          for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration));
+        });
+  } else {
+    std::unique_ptr<llvm_ir::ForLoop> loop = llvm_ir::ForLoop::EmitForLoop(
+        name, start, end, step, ir_builder_, prevent_unrolling_);
+    ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back());
+    for_body_generator(loop->GetIndVarValue(),
+                       /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
+                           loop->GetIndVarValue(), start));
+    llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_);
+  }
+}
+
+void KernelSupportLibrary::If(
+    llvm::Value* condition, const std::function<void()>& true_block_generator,
+    const std::function<void()>& false_block_generator) {
+  llvm_ir::LlvmIfData if_data =
+      llvm_ir::EmitIfThenElse(condition, "", ir_builder_);
+  ir_builder_->SetInsertPoint(&if_data.true_block->back());
+  true_block_generator();
+  ir_builder_->SetInsertPoint(&if_data.false_block->back());
+  false_block_generator();
+  llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
new file mode 100644
index 00000000000..25aa2291a66
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
+
+#include <string>
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+// A thin wrapper around llvm_loop.h to make code generating structured control
+// flow more readable.
+class KernelSupportLibrary {
+ public:
+  // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
+  // If `prevent_unrolling` is true then unrolling is explicitly disabled on
+  // every loop generated by this instance of KernelSupportLibrary.
+  explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
+                                bool prevent_unrolling = true)
+      : ir_builder_(ir_builder), prevent_unrolling_(prevent_unrolling) {}
+
+  // Generates the following control flow structure:
+  //
+  //   if (`start` < `end`) {
+  //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
+  //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
+  //       `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
+  //   }
+  void For(
+      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+      llvm::Value* step,
+      const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
+          for_body_generator);
+
+  void For(
+      tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+      const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
+          for_body_generator) {
+    For(name, /*start=*/ir_builder_->getInt64(start),
+        /*end=*/ir_builder_->getInt64(end),
+        /*step=*/ir_builder_->getInt64(step), for_body_generator);
+  }
+
+  // Generates the following control flow structure if `peel_first_iteration` is
+  // true:
+  //
+  //   if (`start` < `end`) {
+  //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
+  //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
+  //       `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
+  //   }
+  //
+  // and the following if `peel_first_iteration` is false:
+  //
+  //   for (i64 i = `start`; i s< `end`; i += `step`)
+  //     `for_body_generator(/*ind_var=*/i,
+  //                         /*is_first_iteration=*/(i == `start`))`;
+  void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+           llvm::Value* step, bool peel_first_iteration,
+           const std::function<void(llvm::Value* ind_var,
+                                    llvm::Value* is_first_iteration)>&
+               for_body_generator);
+
+  void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+           int64 step, bool peel_first_iteration,
+           const std::function<void(llvm::Value* ind_var,
+                                    llvm::Value* is_first_iteration)>&
+               for_body_generator) {
+    For(name, /*start=*/start, /*end=*/end,
+        /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
+        for_body_generator);
+  }
+
+  void For(
+      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
+      llvm::Value* step,
+      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
+    For(name, start, end, step,
+        /*peel_first_iteration=*/false,
+        [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
+  }
+
+  void For(
+      tensorflow::StringPiece name, int64 start, int64 end, int64 step,
+      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
+    For(name, /*start=*/ir_builder_->getInt64(start),
+        /*end=*/ir_builder_->getInt64(end),
+        /*step=*/ir_builder_->getInt64(step), for_body_generator);
+  }
+
+  // Generates the following control flow structure:
+  //
+  //   if (`condition`)
+  //     `true_block_generator()`;
+  //   else
+  //      `false_block_generator()`;
+  void If(llvm::Value* condition,
+          const std::function<void()>& true_block_generator,
+          const std::function<void()>& false_block_generator = []() {});
+
+ private:
+  llvm::IRBuilder<>* ir_builder_;
+  bool prevent_unrolling_;
+};
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
index 956c0d5f052..d95409e3999 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -537,6 +537,14 @@ void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
   builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
 }
 
+void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
+  if (llvm::Instruction* terminator = blk->getTerminator()) {
+    builder->SetInsertPoint(terminator);
+  } else {
+    builder->SetInsertPoint(blk);
+  }
+}
+
 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                        llvm::IRBuilder<>* builder) {
   auto size = rotand->getType()->getPrimitiveSizeInBits();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
index 304192b58e9..f70d9f88b34 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -243,6 +243,8 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
 
 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
 
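+// Points `builder` at the end of `blk`, immediately before its terminator if
+// it has one.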
+void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
+
 // Create a bitwise rotation of `rotand` by `rotor`.
 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
                        llvm::IRBuilder<>* builder);
diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc
new file mode 100644
index 00000000000..e8c6a83618e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h"
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+
+namespace xla {
+VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
+                                           int64 vector_size,
+                                           llvm::IRBuilder<>* ir_builder,
+                                           std::string name)
+    : vector_size_(vector_size),
+      primitive_type_(primitive_type),
+      ir_builder_(ir_builder),
+      name_(std::move(name)) {
+  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
+      primitive_type, ir_builder_->GetInsertBlock()->getModule());
+  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
+  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
+  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
+}
+
+llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
+  if (scalar_type_->isFloatingPointTy()) {
+    return ir_builder()->CreateFMul(lhs, rhs, name());
+  } else {
+    return ir_builder()->CreateMul(lhs, rhs, name());
+  }
+}
+
+llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
+  if (scalar_type_->isFloatingPointTy()) {
+    return ir_builder()->CreateFAdd(lhs, rhs, name());
+  } else {
+    return ir_builder()->CreateAdd(lhs, rhs, name());
+  }
+}
+
+llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
+    llvm::Value* base_pointer, llvm::Value* offset_elements) {
+  if (base_pointer->getType() != scalar_pointer_type()) {
+    base_pointer = ir_builder()->CreateBitCast(base_pointer,
+                                               scalar_pointer_type(), name());
+  }
+  return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements},
+                                         name());
+}
+
+llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
+  if (pointer->getType() != vector_pointer_type()) {
+    pointer =
+        ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name());
+  }
+  return ir_builder()->CreateAlignedLoad(
+      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
+}
+
+llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
+  if (pointer->getType() != scalar_pointer_type()) {
+    pointer =
+        ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+  }
+  return ir_builder()->CreateAlignedLoad(
+      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
+}
+
+void VectorSupportLibrary::StoreVector(llvm::Value* value,
+                                       llvm::Value* pointer) {
+  if (pointer->getType() != vector_pointer_type()) {
+    pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type());
+  }
+  ir_builder()->CreateAlignedStore(
+      value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+}
+
+void VectorSupportLibrary::StoreScalar(llvm::Value* value,
+                                       llvm::Value* pointer) {
+  if (pointer->getType() != scalar_pointer_type()) {
+    pointer =
+        ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+  }
+  ir_builder()->CreateAlignedStore(
+      value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
+}
+
+llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
+  if (pointer->getType() != scalar_pointer_type()) {
+    pointer =
+        ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
+  }
+  return ir_builder()->CreateVectorSplat(
+      vector_size(), ir_builder()->CreateLoad(pointer), name());
+}
+
+llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
+  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
+  for (unsigned i = vector_size(); i != 1; i >>= 1) {
+    // On every iteration, we shuffle half of the remaining lanes to the top
+    // half of the vector, and then add the old and the shuffled vectors.
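+    //
+    // For example, with vector_size() == 4 the first iteration (i == 4) uses
+    // the mask <2,3,u,u>: lanes 2 and 3 are added into lanes 0 and 1, and the
+    // next iteration adds lane 1 into lane 0, leaving the total in lane 0.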
+
+    for (unsigned j = 0; j < vector_size(); ++j) {
+      if (j < (i / 2)) {
+        mask[j] = ir_builder()->getInt32(i / 2 + j);
+      } else {
+        mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty());
+      }
+    }
+
+    llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector(
+        vector, llvm::UndefValue::get(vector_type()),
+        llvm::ConstantVector::get(mask), "");
+    vector = Add(vector, half_remaining_lanes);
+  }
+
+  return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0),
+                                            name());
+}
+
+llvm::Value* VectorSupportLibrary::GetZeroVector() {
+  return llvm::Constant::getNullValue(vector_type());
+}
+
+llvm::Value* VectorSupportLibrary::GetZeroScalar() {
+  return llvm::Constant::getNullValue(scalar_type());
+}
+
+LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder)
+    : ir_builder_(ir_builder) {
+  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_);
+}
+
+llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); }
+
+void LlvmVariable::Set(llvm::Value* new_value) {
+  ir_builder_->CreateStore(new_value, alloca_);
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h
new file mode 100644
index 00000000000..3072677ab05
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
+
+#include <string>
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+// A thin wrapper around llvm_util.h to make code generating vector math more
+// readable.
+class VectorSupportLibrary {
+ public:
+  // This VectorSupportLibrary instance remembers `primitive_type` and
+  // `vector_size`, and these are implicitly used by the methods on this
+  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
+  // `primitive_type`>).
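+  //
+  // A minimal usage sketch (illustrative; `ir_builder` and the pointers are
+  // assumed to come from the surrounding code generator):
+  //
+  //   VectorSupportLibrary vsl(F32, /*vector_size=*/4, ir_builder, "dot");
+  //   llvm::Value* acc = vsl.GetZeroVector();
+  //   acc = vsl.MulAdd(vsl.LoadVector(lhs_ptr), vsl.LoadBroadcast(rhs_ptr),
+  //                    acc);
+  //   llvm::Value* sum = vsl.AddReduce(acc);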
+  VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
+                       llvm::IRBuilder<>* ir_builder, std::string name);
+
+  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
+    return Mul(ir_builder()->getInt64(lhs), rhs);
+  }
+
+  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
+  llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
+    return Add(ir_builder()->getInt64(lhs), rhs);
+  }
+
+  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
+    return Add(c, Mul(a, b));
+  }
+
+  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
+                                    llvm::Value* offset_elements);
+  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
+                                    int64 offset_elements) {
+    return ComputeOffsetPointer(base_pointer,
+                                ir_builder()->getInt64(offset_elements));
+  }
+
+  llvm::Value* LoadVector(llvm::Value* pointer);
+
+  llvm::Value* LoadVector(llvm::Value* base_pointer,
+                          llvm::Value* offset_elements) {
+    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
+  }
+
+  llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
+    return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements));
+  }
+
+  llvm::Value* LoadScalar(llvm::Value* pointer);
+
+  llvm::Value* LoadScalar(llvm::Value* base_pointer,
+                          llvm::Value* offset_elements) {
+    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
+  }
+
+  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
+    return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements));
+  }
+
+  void StoreVector(llvm::Value* value, llvm::Value* pointer);
+
+  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
+                   llvm::Value* offset_elements) {
+    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
+  }
+
+  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
+                   int64 offset_elements) {
+    StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements));
+  }
+
+  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
+  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
+                   llvm::Value* offset_elements) {
+    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
+  }
+
+  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
+                   int64 offset_elements) {
+    StoreScalar(value, base_pointer, ir_builder()->getInt64(offset_elements));
+  }
+
+  llvm::Value* LoadBroadcast(llvm::Value* pointer);
+  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
+                             llvm::Value* offset_elements) {
+    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
+  }
+  llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
+    return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements));
+  }
+
+  llvm::Value* AddReduce(llvm::Value* vector);
+
+  llvm::Value* GetZeroVector();
+  llvm::Value* GetZeroScalar();
+
+  llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
+  int64 vector_size() const { return vector_size_; }
+  llvm::Type* vector_type() const { return vector_type_; }
+  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
+  llvm::Type* scalar_type() const { return scalar_type_; }
+  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
+
+  const std::string& name() const { return name_; }
+
+ private:
+  int64 vector_size_;
+  PrimitiveType primitive_type_;
+  llvm::IRBuilder<>* ir_builder_;
+  llvm::Type* vector_type_;
+  llvm::Type* vector_pointer_type_;
+  llvm::Type* scalar_type_;
+  llvm::Type* scalar_pointer_type_;
+  std::string name_;
+};
+
+// This wraps an alloca-backed stack variable, which LLVM's SSA construction
+// pass can later convert to an SSA value.
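+//
+// A minimal sketch (illustrative; `vsl` is a VectorSupportLibrary and
+// `increment` an llvm::Value* produced elsewhere):
+//
+//   ScalarVariable acc(&vsl, vsl.GetZeroScalar());
+//   acc.Set(vsl.Add(acc.Get(), increment));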
+class LlvmVariable {
+ public:
+  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder);
+
+  llvm::Value* Get();
+  void Set(llvm::Value* new_value);
+
+ private:
+  llvm::AllocaInst* alloca_;
+  llvm::IRBuilder<>* ir_builder_;
+};
+
+class VectorVariable : public LlvmVariable {
+ public:
+  VectorVariable(VectorSupportLibrary* vector_support,
+                 llvm::Value* initial_value)
+      : LlvmVariable(vector_support->vector_type(),
+                     vector_support->ir_builder()) {
+    Set(initial_value);
+  }
+};
+
+class ScalarVariable : public LlvmVariable {
+ public:
+  ScalarVariable(VectorSupportLibrary* vector_support,
+                 llvm::Value* initial_value)
+      : LlvmVariable(vector_support->scalar_type(),
+                     vector_support->ir_builder()) {
+    Set(initial_value);
+  }
+};
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index cf089d748dc..c4e422b506b 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -277,6 +277,62 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) {
   TestMatrixDot(260, 3, 520, false, false);
 }
 
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) {
+  TestMatrixDot(1, 8, 8, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) {
+  TestMatrixDot(1, 130, 8, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) {
+  TestMatrixDot(1, 8, 130, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) {
+  TestMatrixDot(1, 290, 130, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) {
+  TestMatrixDot(2, 1, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) {
+  TestMatrixDot(8, 8, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) {
+  TestMatrixDot(16, 1, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) {
+  TestMatrixDot(16, 3, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) {
+  TestMatrixDot(3, 3, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) {
+  TestMatrixDot(29, 29, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) {
+  TestMatrixDot(1, 8, 2, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) {
+  TestMatrixDot(1, 2, 8, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) {
+  TestMatrixDot(259, 258, 1, true, true);
+}
+
+XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) {
+  TestMatrixDot(259, 258, 1, false, true);
+}
+
 XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
   constexpr bool kLhsRowMajor = false;
   constexpr bool kRhsRowMajor = false;
@@ -361,6 +417,30 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
   TestNonsquareMatrixDot<complex64>();
 }
 
+XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
+  auto lhs_handle =
+      client_
+          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
+              {{1.0, 2.0, 3.0, -4.0}}, {1, 0}))
+          .ConsumeValueOrDie();
+  auto rhs_handle =
+      client_
+          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
+              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0}))
+          .ConsumeValueOrDie();
+
+  ComputationBuilder builder(client_, TestName());
+  auto prim_type = primitive_util::NativeToPrimitiveType<complex64>();
+  auto result = builder.Dot(
+      builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"),
+      builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs"));
+
+  Array2D<complex64> expected({{30.0, -2.0}});
+
+  ComputeAndCompareR2<complex64>(
+      &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
+}
+
 XLA_TEST_F(DotOperationTest, ConcurrentMatMul) {
   ComputationBuilder builder(client_, TestName());
   auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});

From 505cbf22813dbd17482170562eb91e09d652b835 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 11:35:53 -0800
Subject: [PATCH 037/115] Go: Update generated wrapper functions for TensorFlow
 ops.

PiperOrigin-RevId: 175037663
---
 tensorflow/go/op/wrappers.go | 56 ++++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index bdfad485673..eb79da53840 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -62,6 +62,29 @@ func WriteScalarSummary(scope *Scope, writer tf.Output, global_step tf.Output, t
 	return scope.AddOperation(opspec)
 }
 
+// Writes a serialized `tf.Event` protocol buffer to a summary writer.
+//
+// When CreateSummaryDbWriter is being used, this op can be useful for
+// importing data from existing event logs.
+//
+// Arguments:
+//	writer: A handle to a summary writer.
+//	event: A string containing a binary-encoded tf.Event proto.
+//
+// Returns the created operation.
+func ImportEvent(scope *Scope, writer tf.Output, event tf.Output) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "ImportEvent",
+		Input: []tf.Input{
+			writer, event,
+		},
+	}
+	return scope.AddOperation(opspec)
+}
+
 // Outputs a `Summary` protocol buffer with a tensor.
 //
 // Arguments:
@@ -22454,6 +22477,39 @@ func QuantizedBiasAdd(scope *Scope, input tf.Output, bias tf.Output, min_input t
 	return op.Output(0), op.Output(1), op.Output(2)
 }
 
+// Creates a summary database writer accessible by the given resource handle.
+//
+// This can be used to write tensors from the execution graph directly
+// to a database. Only SQLite is supported right now. This function
+// will create the schema if it doesn't exist. Entries in the Users,
+// Experiments, and Runs tables will be created automatically if they
+// don't already exist.
+//
+// Arguments:
+//	writer: Handle to SummaryWriter resource to overwrite.
+//	db_uri: For example "file:/tmp/foo.sqlite".
+//	experiment_name: Can't contain ASCII control characters or <>. Case
+// sensitive. If empty, then the Run will not be associated with any
+// Experiment.
+//	run_name: Can't contain ASCII control characters or <>. Case sensitive.
+// If empty, then each Tag will not be associated with any Run.
+//	user_name: Must be valid as both a DNS label and Linux username. If
+// empty, then the Experiment will not be associated with any User.
+//
+// Returns the created operation.
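+//
+// A hypothetical usage sketch (the scope and writer are assumed to come from
+// the caller; string inputs are built with Const):
+//
+//	initOp := CreateSummaryDbWriter(scope, writer,
+//		Const(scope, "file:/tmp/foo.sqlite"),
+//		Const(scope, "my_experiment"), Const(scope, "run1"),
+//		Const(scope, "alice"))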
+func CreateSummaryDbWriter(scope *Scope, writer tf.Output, db_uri tf.Output, experiment_name tf.Output, run_name tf.Output, user_name tf.Output) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "CreateSummaryDbWriter",
+		Input: []tf.Input{
+			writer, db_uri, experiment_name, run_name, user_name,
+		},
+	}
+	return scope.AddOperation(opspec)
+}
+
 // HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth.
 type HistogramFixedWidthAttr func(optionalAttr)
 

From 2eb8575a8d7bf7efcceb8283ba420c020ef35457 Mon Sep 17 00:00:00 2001
From: Michael Case <mikecase@google.com>
Date: Wed, 8 Nov 2017 11:54:14 -0800
Subject: [PATCH 038/115] Having with_gcp_support and windows set together
 causes a build error.

Multiple conditions in a select() should not be able to match at the
same time (unless one condition is strictly more 'specific' than
another).
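
For example (illustrative), a Windows build with
--define=with_gcp_support=true used to match two branches of selects
such as:

    select({
        "//tensorflow:windows": [],
        "//tensorflow:with_gcp_support": [
            "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
        ],
        "//conditions:default": [],
    })

The new *_override config_settings match both the define and the
platform, making them strictly more specific, so the select() now
resolves unambiguously.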

PiperOrigin-RevId: 175040618
---
 tensorflow/BUILD                              | 91 ++++++++++++++++---
 .../core/platform/default/build_config.bzl    | 21 +++--
 2 files changed, 91 insertions(+), 21 deletions(-)

diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index f2cdf37dbf6..5a408db94e1 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -110,7 +110,7 @@ config_setting(
 
 config_setting(
     name = "no_tensorflow_py_deps",
-    values = {"define": "no_tensorflow_py_deps=true"},
+    define_values = {"no_tensorflow_py_deps": "true"},
     visibility = ["//visibility:public"],
 )
 
@@ -166,55 +166,116 @@ config_setting(
 # TODO(jhseu): Enable on other platforms other than Linux.
 config_setting(
     name = "with_jemalloc_linux_x86_64",
-    values = {
-        "cpu": "k8",
-        "define": "with_jemalloc=true",
-    },
+    define_values = {"with_jemalloc": "true"},
+    values = {"cpu": "k8"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_jemalloc_linux_ppc64le",
-    values = {
-        "cpu": "ppc",
-        "define": "with_jemalloc=true",
-    },
+    define_values = {"with_jemalloc": "true"},
+    values = {"cpu": "ppc"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_gcp_support",
-    values = {"define": "with_gcp_support=true"},
+    define_values = {"with_gcp_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_hdfs_support",
-    values = {"define": "with_hdfs_support=true"},
+    define_values = {"with_hdfs_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_s3_support",
-    values = {"define": "with_s3_support=true"},
+    define_values = {"with_s3_support": "true"},
+    visibility = ["//visibility:public"],
+)
+
+# Crosses between platforms and file system libraries not supported on those
+# platforms due to limitations in nested select() statements.
+config_setting(
+    name = "with_gcp_support_windows_override",
+    define_values = {"with_gcp_support": "true"},
+    values = {"cpu": "x64_windows"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_hdfs_support_windows_override",
+    define_values = {"with_hdfs_support": "true"},
+    values = {"cpu": "x64_windows"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_s3_support_windows_override",
+    define_values = {"with_s3_support": "true"},
+    values = {"cpu": "x64_windows"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_gcp_support_android_override",
+    define_values = {"with_gcp_support": "true"},
+    values = {"crosstool_top": "//external:android/crosstool"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_hdfs_support_android_override",
+    define_values = {"with_hdfs_support": "true"},
+    values = {"crosstool_top": "//external:android/crosstool"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_s3_support_android_override",
+    define_values = {"with_s3_support": "true"},
+    values = {"crosstool_top": "//external:android/crosstool"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_gcp_support_ios_override",
+    define_values = {"with_gcp_support": "true"},
+    values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_hdfs_support_ios_override",
+    define_values = {"with_hdfs_support": "true"},
+    values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
+    visibility = ["//visibility:public"],
+)
+
+config_setting(
+    name = "with_s3_support_ios_override",
+    define_values = {"with_s3_support": "true"},
+    values = {"crosstool_top": "//tools/osx/crosstool:crosstool"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_xla_support",
-    values = {"define": "with_xla_support=true"},
+    define_values = {"with_xla_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_gdr_support",
-    values = {"define": "with_gdr_support=true"},
+    define_values = {"with_gdr_support": "true"},
     visibility = ["//visibility:public"],
 )
 
 config_setting(
     name = "with_verbs_support",
-    values = {"define": "with_verbs_support=true"},
+    define_values = {"with_verbs_support": "true"},
     visibility = ["//visibility:public"],
 )
 
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 6225c2c705f..5eeb861bddf 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -458,16 +458,25 @@ def tf_additional_lib_deps():
 
 def tf_additional_core_deps():
   return select({
+      "//tensorflow:with_gcp_support_windows_override": [],
+      "//tensorflow:with_gcp_support_android_override": [],
+      "//tensorflow:with_gcp_support_ios_override": [],
       "//tensorflow:with_gcp_support": [
           "//tensorflow/core/platform/cloud:gcs_file_system",
       ],
       "//conditions:default": [],
   }) + select({
+      "//tensorflow:with_hdfs_support_windows_override": [],
+      "//tensorflow:with_hdfs_support_android_override": [],
+      "//tensorflow:with_hdfs_support_ios_override": [],
       "//tensorflow:with_hdfs_support": [
           "//tensorflow/core/platform/hadoop:hadoop_file_system",
       ],
       "//conditions:default": [],
   }) + select({
+      "//tensorflow:with_s3_support_windows_override": [],
+      "//tensorflow:with_s3_support_android_override": [],
+      "//tensorflow:with_s3_support_ios_override": [],
       "//tensorflow:with_s3_support": [
           "//tensorflow/core/platform/s3:s3_file_system",
       ],
@@ -477,9 +486,9 @@ def tf_additional_core_deps():
 # TODO(jart, jhseu): Delete when GCP is default on.
 def tf_additional_cloud_op_deps():
   return select({
-      "//tensorflow:windows": [],
-      "//tensorflow:android": [],
-      "//tensorflow:ios": [],
+      "//tensorflow:with_gcp_support_windows_override": [],
+      "//tensorflow:with_gcp_support_android_override": [],
+      "//tensorflow:with_gcp_support_ios_override": [],
       "//tensorflow:with_gcp_support": [
         "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
       ],
@@ -489,9 +498,9 @@ def tf_additional_cloud_op_deps():
 # TODO(jart, jhseu): Delete when GCP is default on.
 def tf_additional_cloud_kernel_deps():
   return select({
-      "//tensorflow:windows": [],
-      "//tensorflow:android": [],
-      "//tensorflow:ios": [],
+      "//tensorflow:with_gcp_support_windows_override": [],
+      "//tensorflow:with_gcp_support_android_override": [],
+      "//tensorflow:with_gcp_support_ios_override": [],
       "//tensorflow:with_gcp_support": [
         "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
       ],

From b5634b5e071e94d876d52ce7837dae3c5f37c9ba Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 12:03:44 -0800
Subject: [PATCH 039/115] Supports logits as a Tensor in MultiHead.

PiperOrigin-RevId: 175042091
---
 tensorflow/contrib/estimator/BUILD            |   5 +-
 .../estimator/python/estimator/multi_head.py  |  67 +++++--
 .../python/estimator/multi_head_test.py       | 188 +++++++++++++++++-
 3 files changed, 244 insertions(+), 16 deletions(-)

diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 6eb2cfdaca7..bc67ef83541 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -204,10 +204,13 @@ py_library(
     ],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/python:array_ops",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:summary",
         "//tensorflow/python/estimator:head",
+        "//tensorflow/python/estimator:metric_keys",
         "//tensorflow/python/estimator:model_fn",
         "//tensorflow/python/saved_model:signature_constants",
         "@six_archive//:six",
@@ -229,7 +232,7 @@ py_test(
         "//tensorflow/python:string_ops",
         "//tensorflow/python/estimator:metric_keys",
         "//tensorflow/python/estimator:model_fn",
-        "//tensorflow/python/ops/losses",
+        "//tensorflow/python/estimator:prediction_keys",
         "//tensorflow/python/saved_model:signature_constants",
         "//third_party/py/numpy",
         "@six_archive//:six",
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index 69dbfcee62a..73bae5acf9c 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -22,10 +22,13 @@ import six
 
 from tensorflow.python.estimator import model_fn
 from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import metric_keys
 from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.summary import summary
 
 
 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -72,6 +75,23 @@ def multi_head(heads, head_weights=None):
   estimator.train(input_fn=input_fn, steps=100)
   ```
 
+  Also supports `logits` as a `Tensor` of shape
+  `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the
+  last dimension and distribute it appropriately among the heads. E.g.:
+
+  ```python
+  def model_fn(features, labels, mode):
+    # Create simple heads and specify head name.
+    head1 = multi_class_head(n_classes=3, name='head1')
+    head2 = binary_classification_head(name='head2')
+    # Create multi-head from two simple heads.
+    head = multi_head([head1, head2])
+    # Create logits for the multihead.
+    logits = logit_fn(logits_dimension=head.logits_dimension)
+    # Return the merged EstimatorSpec
+    return head.create_estimator_spec(..., logits=logits, ...)
+  ```
+
   Args:
     heads: List or tuple of `_Head` instances. All heads must have `name`
       specified. The first head in the list is the default used at serving time.
@@ -161,18 +181,17 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
 
   def create_loss(self, features, mode, logits, labels):
     """See `Head`."""
-    # TODO(roumposg): Add support for logits as single Tensor (with
-    # _split_logits utility).
-    if not isinstance(logits, dict):
-      raise ValueError('logits must be a dict.  Single Tensor support coming '
-                       'soon.')
+    if isinstance(logits, dict):
+      logits_dict = logits
+    else:
+      logits_dict = self._split_logits(logits)
     weighted_sum_losses = []
     example_weight_sums = []
     labels_by_head = {}
     for head in self._heads:
       (weighted_sum_loss,
        example_weight_sum, processed_labels) = head.create_loss(
-           features, mode, logits[head.name], labels[head.name])
+           features, mode, logits_dict[head.name], labels[head.name])
       weighted_sum_losses.append(weighted_sum_loss)
       example_weight_sums.append(example_weight_sum)
       labels_by_head[head.name] = processed_labels
@@ -205,10 +224,10 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
   def create_estimator_spec(
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `_Head`."""
-    # TODO(roumposg): Add support for logits as single Tensor (with
-    # _split_logits utility).
-    if not isinstance(logits, dict):
-      raise ValueError('logits must be a dict. Given: {}'.format(logits))
+    if isinstance(logits, dict):
+      logits_dict = logits
+    else:
+      logits_dict = self._split_logits(logits)
     if labels and not isinstance(labels, dict):
       raise ValueError('labels must be a dict. Given: {}'.format(labels))
 
@@ -219,22 +238,42 @@ class _MultiHead(head_lib._Head):  # pylint:disable=protected-access
           head.create_estimator_spec(
               features=features,
               mode=mode,
-              logits=logits[head_name],
+              logits=logits_dict[head_name],
               labels=labels[head_name] if labels else None,
               train_op_fn=_no_op_train_fn))
 
-    # TODO(roumposg): Add LOSS and LOSS_MEAN summaries for the total head-
-    # combined loss.
     if mode == model_fn.ModeKeys.TRAIN:
       if train_op_fn is None:
         raise ValueError('train_op_fn can not be None in TRAIN mode.')
-      return self._merge_train(all_estimator_spec, train_op_fn)
+      spec = self._merge_train(all_estimator_spec, train_op_fn)
+      with ops.name_scope(''):
+        summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
+      return spec
     if mode == model_fn.ModeKeys.PREDICT:
       return self._merge_predict(all_estimator_spec)
     if mode == model_fn.ModeKeys.EVAL:
       return self._merge_eval(all_estimator_spec)
     raise ValueError('mode={} unrecognized'.format(mode))
 
+  def _split_logits(self, logits):
+    """Splits logits along the last dimension and returns a dict."""
+    logits_dict = {}
+    with ops.name_scope(None, 'split_logits', values=[logits]):
+      logits = ops.convert_to_tensor(logits)
+      batch_shape = array_ops.shape(logits)[:-1]
+      zeros_like_batch_shape = array_ops.zeros_like(batch_shape)
+      minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape)
+      begin_idx = 0
+      for head in self._heads:
+        begin_tensor = array_ops.concat(
+            [zeros_like_batch_shape, [begin_idx]], axis=0)
+        size_tensor = array_ops.concat(
+            [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0)
+        logits_dict[head.name] = array_ops.slice(
+            logits, begin=begin_tensor, size=size_tensor)
+        begin_idx += head.logits_dimension
+    return logits_dict
+
   def _merge_train(self, all_estimator_spec, train_op_fn):
     """Merges list of `EstimatorSpec` for training.
 
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 16177aebd53..8d51a298b23 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -106,7 +106,8 @@ class MultiHeadTest(test.TestCase):
     multi_head = multi_head_lib.multi_head([head1, head2])
     self.assertEqual('head1_head2', multi_head.name)
 
-  def test_predict_two_heads(self):
+  def test_predict_two_heads_logits_dict(self):
+    """Tests predict with logits as dict."""
     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
     multi_head = multi_head_lib.multi_head([head1, head2])
@@ -158,6 +159,111 @@ class MultiHeadTest(test.TestCase):
           expected_probabilities['head2'],
           sess.run(spec.export_outputs['head2'].scores))
 
+  def test_predict_two_heads_logits_tensor(self):
+    """Tests predict with logits as Tensor."""
+    head1 = head_lib.multi_label_head(n_classes=2, name='head1')
+    head2 = head_lib.multi_label_head(n_classes=3, name='head2')
+    multi_head = multi_head_lib.multi_head([head1, head2])
+
+    logits = np.array(
+        [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32)
+    expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
+    expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]],
+                                dtype=np.float32)
+    expected_probabilities = {
+        'head1': _sigmoid(expected_logits1),
+        'head2': _sigmoid(expected_logits2),
+    }
+
+    spec = multi_head.create_estimator_spec(
+        features={'x': np.array(((42,),), dtype=np.int32)},
+        mode=model_fn.ModeKeys.PREDICT,
+        logits=logits)
+
+    self.assertItemsEqual(
+        (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1',
+         'head2', 'classification/head2', 'predict/head2'),
+        spec.export_outputs.keys())
+
+    # Assert predictions and export_outputs.
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNone(spec.scaffold.summary_op)
+      predictions = sess.run(spec.predictions)
+      self.assertAllClose(
+          expected_logits1,
+          predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])
+      self.assertAllClose(
+          expected_logits2,
+          predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])
+      self.assertAllClose(
+          expected_probabilities['head1'],
+          predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])
+      self.assertAllClose(
+          expected_probabilities['head2'],
+          predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)])
+
+      self.assertAllClose(
+          expected_probabilities['head1'],
+          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
+      self.assertAllClose(
+          expected_probabilities['head1'],
+          sess.run(spec.export_outputs['head1'].scores))
+      self.assertAllClose(
+          expected_probabilities['head2'],
+          sess.run(spec.export_outputs['head2'].scores))
+
+  def test_predict_two_heads_logits_tensor_multi_dim(self):
+    """Tests predict with multi-dimensional logits of shape [2, 2, 5]."""
+    head1 = head_lib.regression_head(label_dimension=2, name='head1')
+    head2 = head_lib.regression_head(label_dimension=3, name='head2')
+    multi_head = multi_head_lib.multi_head([head1, head2])
+
+    logits = np.array(
+        [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
+         [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]],
+        dtype=np.float32)
+    expected_logits1 = np.array(
+        [[[-1., 1.], [-1., 1.]],
+         [[-1.5, 1.], [-1.5, 1.]]],
+        dtype=np.float32)
+    expected_logits2 = np.array(
+        [[[2., -2., 2.], [2., -2., 2.]],
+         [[-3., 2., -2.], [-3., 2., -2.]]],
+        dtype=np.float32)
+
+    spec = multi_head.create_estimator_spec(
+        features={'x': np.array(((42,),), dtype=np.int32)},
+        mode=model_fn.ModeKeys.PREDICT,
+        logits=logits)
+
+    self.assertItemsEqual(
+        (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1',
+         'head2', 'regression/head2', 'predict/head2'),
+        spec.export_outputs.keys())
+
+    # Assert predictions and export_outputs.
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      self.assertIsNone(spec.scaffold.summary_op)
+      predictions = sess.run(spec.predictions)
+      self.assertAllClose(
+          expected_logits1,
+          predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)])
+      self.assertAllClose(
+          expected_logits2,
+          predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)])
+
+      self.assertAllClose(
+          expected_logits1,
+          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value))
+      self.assertAllClose(
+          expected_logits1,
+          sess.run(spec.export_outputs['head1'].value))
+      self.assertAllClose(
+          expected_logits2,
+          sess.run(spec.export_outputs['head2'].value))
+
   def test_eval_two_heads_with_weights(self):
     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
@@ -284,6 +390,84 @@ class MultiHeadTest(test.TestCase):
       # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13
       self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol)
 
+  def test_train_create_loss_logits_tensor(self):
+    """Tests create_loss with logits Tensor."""
+    weights1 = np.array([[1.], [2.]], dtype=np.float32)
+    weights2 = np.array([[2.], [3.]])
+    head1 = head_lib.multi_label_head(n_classes=2, name='head1',
+                                      weight_column='weights1')
+    head2 = head_lib.multi_label_head(n_classes=3, name='head2',
+                                      weight_column='weights2')
+    multi_head = multi_head_lib.multi_head(
+        [head1, head2], head_weights=[1., 2.])
+
+    logits = np.array([[-10., 10., 20., -20., 20.],
+                       [-15., 10., -30., 20., -20.]], dtype=np.float32)
+    labels = {
+        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
+        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
+    }
+    weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss(
+        features={
+            'x': np.array(((42,),), dtype=np.int32),
+            'weights1': weights1,
+            'weights2': weights2
+        },
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels)
+    tol = 1e-3
+    with self.test_session():
+      # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
+      # = [10, 7.5]
+      # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25
+      # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
+      # = [20, 10]
+      # weighted_sum_loss = 2 * 20 + 3 * 10 = 70
+      # head-weighted merge = 1 * 25 + 2 * 70 = 165
+      self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol)
+      # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13
+      self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol)
+
+  def test_train_create_loss_logits_tensor_multi_dim(self):
+    """Tests create_loss with multi-dimensional logits of shape [2, 2, 5]."""
+    head1 = head_lib.regression_head(label_dimension=2, name='head1')
+    head2 = head_lib.regression_head(label_dimension=3, name='head2')
+    multi_head = multi_head_lib.multi_head([head1, head2])
+
+    logits = np.array(
+        [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
+         [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]],
+        dtype=np.float32)
+    labels = {
+        'head1': np.array([[[1., 0.], [1., 0.]],
+                           [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32),
+        'head2': np.array([[[0., 1., 0.], [0., 1., 0.]],
+                           [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32),
+    }
+    # Loss for the first head:
+    # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +
+    #         (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2
+    #       = 28
+    # Loss for the second head:
+    # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +
+    #         (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2
+    #       = 74
+    expected_weighted_sum_loss = 28. + 74.
+
+    weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss(
+        features={},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels)
+    tol = 1e-3
+    with self.test_session():
+      self.assertAllClose(
+          expected_weighted_sum_loss, weighted_sum_loss.eval(),
+          rtol=tol, atol=tol)
+      self.assertAllClose(
+          2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol)
+
   def test_train_one_head(self):
     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
     multi_head = multi_head_lib.multi_head([head1])
@@ -327,6 +511,7 @@ class MultiHeadTest(test.TestCase):
           six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
           train_result)
       _assert_simple_summaries(self, {
+          metric_keys.MetricKeys.LOSS: expected_loss,
           metric_keys.MetricKeys.LOSS + '/head1': expected_loss,
           # Average loss over examples.
           metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2,
@@ -387,6 +572,7 @@ class MultiHeadTest(test.TestCase):
           six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
           train_result)
       _assert_simple_summaries(self, {
+          metric_keys.MetricKeys.LOSS: expected_loss,
           metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1,
           metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
           # Average loss over examples.

From aa3d321213acdbe3a2403c9081a14762b8e9bb36 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 12:21:40 -0800
Subject: [PATCH 040/115] Add padded_batch_and_drop_remainder and factor out
 shared filter_irregular_batches.

PiperOrigin-RevId: 175045241
---
 .../kernel_tests/batch_dataset_op_test.py     | 225 +++++++++++-------
 .../contrib/data/python/ops/batching.py       |  87 +++++--
 2 files changed, 207 insertions(+), 105 deletions(-)

diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 670f622c3c3..951d4bb5f77 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -52,8 +52,9 @@ class BatchDatasetTest(test.TestCase):
     def _map_fn(x, y, z):
       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
-    iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
-                .repeat(count).batch(batch_size).make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+        .repeat(count).batch(batch_size).make_initializable_iterator())
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
@@ -69,7 +70,7 @@ class BatchDatasetTest(test.TestCase):
         result = sess.run(get_next)
         for component, result_component in zip(components, result):
           for j in range(14):
-            self.assertAllEqual(component[(i*14 + j) % 7]**2,
+            self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                                 result_component[j])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -84,12 +85,12 @@ class BatchDatasetTest(test.TestCase):
         result = sess.run(get_next)
         for component, result_component in zip(components, result):
           for j in range(8):
-            self.assertAllEqual(component[(i*8 + j) % 7]**2,
+            self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                                 result_component[j])
       result = sess.run(get_next)
       for component, result_component in zip(components, result):
         for j in range((14 * 7) % 8):
-          self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2,
+          self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
                               result_component[j])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -107,10 +108,10 @@ class BatchDatasetTest(test.TestCase):
     seq_lens = array_ops.placeholder(dtypes.int32, shape=[None])
     padded_shape = array_ops.placeholder(dtypes.int64, shape=[1])
 
-    iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens)
-                .map(lambda x: array_ops.fill([x], x)).padded_batch(
-                    4,
-                    padded_shapes=padded_shape).make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(seq_lens)
+        .map(lambda x: array_ops.fill([x], x)).padded_batch(
+            4, padded_shapes=padded_shape).make_initializable_iterator())
 
     init_op = iterator.initializer
     get_next = iterator.get_next()
@@ -118,35 +119,40 @@ class BatchDatasetTest(test.TestCase):
     with self.test_session() as sess:
       # Test with random sequence lengths, and max padding.
       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
-      sess.run(init_op, feed_dict={padded_shape: [-1],
-                                   seq_lens: random_seq_lens})
+      sess.run(
+          init_op, feed_dict={
+              padded_shape: [-1],
+              seq_lens: random_seq_lens
+          })
       for i in range(8):
         result = sess.run(get_next)
         padded_len = np.max(result)
         self.assertEqual((4, padded_len), result.shape)
         for j in range(4):
-          seq_len = random_seq_lens[(i*4)+j]
+          seq_len = random_seq_lens[(i * 4) + j]
           self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
           self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
       # Test with random sequence lengths, and constant padding.
-      sess.run(init_op, feed_dict={padded_shape: [25],
-                                   seq_lens: random_seq_lens})
+      sess.run(
+          init_op, feed_dict={
+              padded_shape: [25],
+              seq_lens: random_seq_lens
+          })
       for i in range(8):
         result = sess.run(get_next)
         self.assertEqual((4, 25), result.shape)
         for j in range(4):
-          seq_len = random_seq_lens[(i*4)+j]
+          seq_len = random_seq_lens[(i * 4) + j]
           self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len)
           self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
       # Test correct handling of empty tensors.
-      sess.run(init_op, feed_dict={padded_shape: [-1],
-                                   seq_lens: [0, 0, 0, 0]})
+      sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]})
       result = sess.run(get_next)
       self.assertAllEqual([[], [], [], []], result)
       with self.assertRaises(errors.OutOfRangeError):
@@ -154,8 +160,7 @@ class BatchDatasetTest(test.TestCase):
 
       # Test error handling with constant sequence lengths, and
       # too-short padding.
-      sess.run(init_op, feed_dict={padded_shape: [5],
-                                   seq_lens: [6, 5, 5, 5]})
+      sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]})
       with self.assertRaises(errors.DataLossError):
         result = sess.run(get_next)
 
@@ -166,11 +171,13 @@ class BatchDatasetTest(test.TestCase):
     def fill_tuple(x):
       filled = array_ops.fill([x], x)
       return (filled, string_ops.as_string(filled))
-    iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
-                .padded_batch(
-                    4,
-                    padded_shapes=(padded_shape, padded_shape),
-                    padding_values=(-1, "<end>")).make_initializable_iterator())
+
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple)
+        .padded_batch(
+            4,
+            padded_shapes=(padded_shape, padded_shape),
+            padding_values=(-1, "<end>")).make_initializable_iterator())
 
     init_op = iterator.initializer
     get_next = iterator.get_next()
@@ -178,15 +185,18 @@ class BatchDatasetTest(test.TestCase):
     with self.test_session() as sess:
       # Test with random sequence lengths, and max padding.
       random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
-      sess.run(init_op, feed_dict={padded_shape: [-1],
-                                   seq_lens: random_seq_lens})
+      sess.run(
+          init_op, feed_dict={
+              padded_shape: [-1],
+              seq_lens: random_seq_lens
+          })
       for i in range(8):
         result = sess.run(get_next)
         padded_len = np.max(result[0])
         self.assertEqual((4, padded_len), result[0].shape)
         self.assertEqual((4, padded_len), result[1].shape)
         for j in range(4):
-          seq_len = random_seq_lens[(i*4)+j]
+          seq_len = random_seq_lens[(i * 4) + j]
           self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
           self.assertAllEqual(result[0][j, seq_len:],
                               [-1] * (padded_len - seq_len))
@@ -220,20 +230,21 @@ class BatchDatasetTest(test.TestCase):
                        constant_op.constant([-1, -1], dtype=dtypes.int64),
                        constant_op.constant([37], dtype=dtypes.int64)))
 
-    for dataset in [dynamic_padding_from_tensor_shapes,
-                    dynamic_padding_from_lists,
-                    dynamic_padding_from_lists_with_minus_one,
-                    dynamic_padding_from_tensors]:
+    for dataset in [
+        dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists,
+        dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors
+    ]:
       self.assertEqual([None, None], dataset.output_shapes[0].as_list())
       self.assertEqual([None, None, None], dataset.output_shapes[1].as_list())
       self.assertEqual([None, 37], dataset.output_shapes[2].as_list())
 
   def testDenseToSparseBatchDataset(self):
     components = np.random.randint(12, size=(100,)).astype(np.int32)
-    iterator = (dataset_ops.Dataset.from_tensor_slices(components)
-                .map(lambda x: array_ops.fill([x], x)).apply(
-                    batching.dense_to_sparse_batch(4, [12]))
-                .make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(components)
+        .map(lambda x: array_ops.fill([x], x)).apply(
+            batching.dense_to_sparse_batch(4,
+                                           [12])).make_initializable_iterator())
     init_op = iterator.initializer
     get_next = sparse_tensor.SparseTensor(*iterator.get_next())
 
@@ -242,24 +253,26 @@ class BatchDatasetTest(test.TestCase):
 
       for start in range(0, len(components), 4):
         results = sess.run(get_next)
+        self.assertAllEqual([[i, j]
+                             for i, c in enumerate(components[start:start + 4])
+                             for j in range(c)], results.indices)
         self.assertAllEqual(
-            [[i, j] for i, c in enumerate(components[start:start+4])
-             for j in range(c)], results.indices)
-        self.assertAllEqual(
-            [c for c in components[start:start+4] for _ in range(c)],
+            [c for c in components[start:start + 4] for _ in range(c)],
             results.values)
-        self.assertAllEqual(
-            [min(4, len(components) - start), 12], results.dense_shape)
+        self.assertAllEqual([min(4,
+                                 len(components) - start), 12],
+                            results.dense_shape)
 
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
   def testDenseToSparseBatchDatasetWithUnknownShape(self):
     components = np.random.randint(5, size=(40,)).astype(np.int32)
-    iterator = (dataset_ops.Dataset.from_tensor_slices(components)
-                .map(lambda x: array_ops.fill([x, x], x)).apply(
-                    batching.dense_to_sparse_batch(
-                        4, [5, -1])).make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(components)
+        .map(lambda x: array_ops.fill([x, x], x)).apply(
+            batching.dense_to_sparse_batch(
+                4, [5, -1])).make_initializable_iterator())
     init_op = iterator.initializer
     get_next = sparse_tensor.SparseTensor(*iterator.get_next())
 
@@ -268,27 +281,30 @@ class BatchDatasetTest(test.TestCase):
 
       for start in range(0, len(components), 4):
         results = sess.run(get_next)
-        self.assertAllEqual(
-            [[i, j, z] for i, c in enumerate(components[start:start+4])
-             for j in range(c) for z in range(c)], results.indices)
-        self.assertAllEqual(
-            [c for c in components[start:start+4]
-             for _ in range(c) for _ in range(c)],
-            results.values)
-        self.assertAllEqual(
-            [min(4, len(components) - start),
-             5,
-             np.max(components[start:start+4])],
-            results.dense_shape)
+        self.assertAllEqual([[i, j, z]
+                             for i, c in enumerate(components[start:start + 4])
+                             for j in range(c)
+                             for z in range(c)], results.indices)
+        self.assertAllEqual([
+            c
+            for c in components[start:start + 4] for _ in range(c)
+            for _ in range(c)
+        ], results.values)
+        self.assertAllEqual([
+            min(4,
+                len(components) - start), 5,
+            np.max(components[start:start + 4])
+        ], results.dense_shape)
 
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
   def testDenseToSparseBatchDatasetWithInvalidShape(self):
     input_tensor = array_ops.constant([[1]])
-    iterator = (dataset_ops.Dataset.from_tensors(input_tensor)
-                .apply(batching.dense_to_sparse_batch(4, [-2]))
-                .make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensors(input_tensor).apply(
+            batching.dense_to_sparse_batch(4, [-2]))
+        .make_initializable_iterator())
     init_op = iterator.initializer
 
     with self.test_session() as sess:
@@ -298,8 +314,10 @@ class BatchDatasetTest(test.TestCase):
 
   def testDenseToSparseBatchDatasetShapeErrors(self):
     input_tensor = array_ops.placeholder(dtypes.int32)
-    iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply(
-        batching.dense_to_sparse_batch(4, [12])).make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensors(input_tensor).apply(
+            batching.dense_to_sparse_batch(4,
+                                           [12])).make_initializable_iterator())
     init_op = iterator.initializer
     get_next = sparse_tensor.SparseTensor(*iterator.get_next())
 
@@ -356,8 +374,7 @@ class BatchDatasetTest(test.TestCase):
 
   def testUnbatchMultiElementTupleDataset(self):
     data = tuple([(math_ops.range(10 * i, 10 * i + 10),
-                   array_ops.fill([10], "hi"))
-                  for i in range(3)])
+                   array_ops.fill([10], "hi")) for i in range(3)])
     data = dataset_ops.Dataset.from_tensor_slices(data)
     expected_types = ((dtypes.int32, dtypes.string),) * 3
     data = data.batch(2)
@@ -370,9 +387,7 @@ class BatchDatasetTest(test.TestCase):
 
     with self.test_session() as sess:
       for i in range(10):
-        self.assertEqual(((i, b"hi"),
-                          (10 + i, b"hi"),
-                          (20 + i, b"hi")),
+        self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
                          sess.run(op))
 
       with self.assertRaises(errors.OutOfRangeError):
@@ -385,9 +400,10 @@ class BatchDatasetTest(test.TestCase):
 
     batch_size = array_ops.placeholder(dtypes.int64, shape=[])
 
-    iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply(
-        batching.batch_and_drop_remainder(batch_size))
-                .make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(components).apply(
+            batching.batch_and_drop_remainder(batch_size))
+        .make_initializable_iterator())
 
     next_element = iterator.get_next()
 
@@ -404,14 +420,51 @@ class BatchDatasetTest(test.TestCase):
         with self.assertRaises(errors.OutOfRangeError):
           sess.run(next_element)
 
+  def testPaddedBatchAndDropRemainder(self):
+    els = []
+    for length in [3, 6, 9, 4, 12, 10, 2]:
+      els.append((np.array(length), np.arange(length) + 1,
+                  np.array(length * 2)))
+
+    dataset = dataset_ops.Dataset.from_tensors(els[0])
+    for el in els[1:]:
+      dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el))
+
+    batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+    iterator = (
+        dataset.apply(
+            batching.padded_batch_and_drop_remainder(
+                batch_size, ([], [None], []))).make_initializable_iterator())
+
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      for test_batch_size in [1, 3, 7, 10]:
+        sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
+        num_batches = 7 // test_batch_size
+        for i in range(num_batches):
+          result = sess.run(next_element)
+          for component_idx, result_component in enumerate(result):
+            for j in range(test_batch_size):
+              data_idx = i * test_batch_size + j
+              comp = result_component[j]
+              unpadded = comp[comp > 0]
+              if np.isscalar(comp):
+                # The boolean mask indexing above adds a dimension back;
+                # remove it.
+                unpadded = unpadded[0]
+              self.assertAllEqual(els[data_idx][component_idx], unpadded)
+        with self.assertRaises(errors.OutOfRangeError):
+          sess.run(next_element)
+
   def testBatchAndDropRemainderShapeInference(self):
-    components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(
-        dtypes.int32, shape=[None]), array_ops.placeholder(
-            dtypes.int32, shape=[20, 30])))
+    components = (array_ops.placeholder(dtypes.int32),
+                  (array_ops.placeholder(dtypes.int32, shape=[None]),
+                   array_ops.placeholder(dtypes.int32, shape=[20, 30])))
 
     # Test with a statically known batch size.
-    dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply(
-        batching.batch_and_drop_remainder(128)))
+    dataset = (
+        dataset_ops.Dataset.from_tensor_slices(components).apply(
+            batching.batch_and_drop_remainder(128)))
 
     self.assertIs(None, dataset.output_shapes[0].ndims)
     self.assertEqual([128], dataset.output_shapes[1][0].as_list())
@@ -420,8 +473,9 @@ class BatchDatasetTest(test.TestCase):
     # Test with a dynamic batch size: the static shape will be unknown, because
     # `batch_size` is a placeholder.
     batch_size = array_ops.placeholder(dtypes.int64)
-    dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply(
-        batching.batch_and_drop_remainder(batch_size)))
+    dataset = (
+        dataset_ops.Dataset.from_tensor_slices(components).apply(
+            batching.batch_and_drop_remainder(batch_size)))
 
     self.assertIs(None, dataset.output_shapes[0].ndims)
     self.assertEqual([None], dataset.output_shapes[1][0].as_list())
@@ -441,9 +495,10 @@ class BatchDatasetTest(test.TestCase):
     def _map_fn(x, y, z):
       return math_ops.square(x), math_ops.square(y), math_ops.square(z)
 
-    iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count)
-                .apply(batching.map_and_batch(_map_fn, batch_size))
-                .make_initializable_iterator())
+    iterator = (
+        dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
+            batching.map_and_batch(_map_fn, batch_size))
+        .make_initializable_iterator())
     init_op = iterator.initializer
     get_next = iterator.get_next()
 
@@ -459,7 +514,7 @@ class BatchDatasetTest(test.TestCase):
         result = sess.run(get_next)
         for component, result_component in zip(components, result):
           for j in range(14):
-            self.assertAllEqual(component[(i*14 + j) % 7]**2,
+            self.assertAllEqual(component[(i * 14 + j) % 7]**2,
                                 result_component[j])
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -474,7 +529,7 @@ class BatchDatasetTest(test.TestCase):
         result = sess.run(get_next)
         for component, result_component in zip(components, result):
           for j in range(8):
-            self.assertAllEqual(component[(i*8 + j) % 7]**2,
+            self.assertAllEqual(component[(i * 8 + j) % 7]**2,
                                 result_component[j])
       # The last batch should fail with `OutOfRange`.
       with self.assertRaises(errors.OutOfRangeError):
@@ -495,8 +550,9 @@ class BatchDatasetTest(test.TestCase):
         array_ops.check_numerics(
             constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
     batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-    iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
-                .make_initializable_iterator())
+    iterator = (
+        dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+        .make_initializable_iterator())
     init_op = iterator.initializer
     with self.test_session() as sess:
       with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
@@ -504,6 +560,7 @@ class BatchDatasetTest(test.TestCase):
 
   def testBatchAndMapDatasetShapeMismatch(self):
     """Test a dataset that maps a TF function across its input elements."""
+
     def generator():
       yield [1]
       yield [2]
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index abc9212a875..d4ade7adfd2 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -103,6 +103,42 @@ def unbatch():
   return _apply_fn
 
 
+def filter_irregular_batches(batch_size):
+  """Transformation that filters out batches that are not of size batch_size."""
+
+  def _apply_fn(dataset):
+    """Function from `Dataset` to `Dataset` that applies the transformation."""
+    tensor_batch_size = ops.convert_to_tensor(
+        batch_size, dtype=dtypes.int64, name="batch_size")
+
+    flattened = _RestructuredDataset(dataset,
+                                     tuple(nest.flatten(dataset.output_types)))
+
+    def _predicate(*xs):
+      """Return `True` if this element is a full batch."""
+      # Extract the dynamic batch size from the first component of the flattened
+      # batched element.
+      first_component = xs[0]
+      first_component_batch_size = array_ops.shape(
+          first_component, out_type=dtypes.int64)[0]
+
+      return math_ops.equal(first_component_batch_size, tensor_batch_size)
+
+    filtered = flattened.filter(_predicate)
+
+    maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
+
+    def _set_first_dimension(shape):
+      return shape.merge_with(
+          tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
+
+    known_shapes = nest.map_structure(_set_first_dimension,
+                                      dataset.output_shapes)
+    return _RestructuredDataset(filtered, dataset.output_types, known_shapes)
+
+  return _apply_fn
+
+
 def batch_and_drop_remainder(batch_size):
   """A batching transformation that omits the final small batch (if present).
 
@@ -135,34 +171,43 @@ def batch_and_drop_remainder(batch_size):
 
   def _apply_fn(dataset):
     """Function from `Dataset` to `Dataset` that applies the transformation."""
-    tensor_batch_size = ops.convert_to_tensor(
-        batch_size, dtype=dtypes.int64, name="batch_size")
+    batched = dataset.batch(batch_size)
+    return filter_irregular_batches(batch_size)(batched)
 
-    batched = dataset.batch(tensor_batch_size)
-    flattened = _RestructuredDataset(batched,
-                                     tuple(nest.flatten(batched.output_types)))
+  return _apply_fn
 
-    def _predicate(*xs):
-      """Return `True` if this element is a full batch."""
-      # Extract the dynamic batch size from the first component of the flattened
-      # batched element.
-      first_component = xs[0]
-      first_component_batch_size = array_ops.shape(
-          first_component, out_type=dtypes.int64)[0]
 
-      return math_ops.equal(first_component_batch_size, tensor_batch_size)
+def padded_batch_and_drop_remainder(batch_size,
+                                    padded_shapes,
+                                    padding_values=None):
+  """A batching and padding transformation that omits the final small batch.
 
-    filtered = flattened.filter(_predicate)
+  Like @{tf.data.Dataset.padded_batch}, this transformation combines
+  consecutive elements of this dataset into batches. However, if the batch
+  size does not evenly divide the input dataset size, this transformation will
+drop the final smaller batch.
 
-    maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size)
+  See @{tf.contrib.data.batch_and_drop_remainder} for more details.
 
-    def _set_first_dimension(shape):
-      return shape.merge_with(
-          tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:]))
+  Args:
+    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+      consecutive elements of this dataset to combine in a single batch.
+    padded_shapes: A nested structure of `tf.TensorShape` or
+      `tf.int64` vector tensor-like objects. See
+      @{tf.data.Dataset.padded_batch} for details.
+    padding_values: (Optional.) A nested structure of scalar-shaped
+      `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details.
 
-    known_shapes = nest.map_structure(_set_first_dimension,
-                                      batched.output_shapes)
-    return _RestructuredDataset(filtered, batched.output_types, known_shapes)
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}
+  """
+
+  def _apply_fn(dataset):
+    """Function from `Dataset` to `Dataset` that applies the transformation."""
+    batched = dataset.padded_batch(
+        batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
+    return filter_irregular_batches(batch_size)(batched)
 
   return _apply_fn
 

From 3a17aa5ea18e43e6974bbf6d5cef6d02edfada5c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 12:56:45 -0800
Subject: [PATCH 041/115] Support replacing tpu_config.

PiperOrigin-RevId: 175049981
---
 tensorflow/contrib/tpu/python/tpu/tpu_config.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 3965c087a18..097acd5ee73 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -109,3 +109,12 @@ class RunConfig(run_config_lib.RunConfig):
   @property
   def tpu_config(self):
     return self._tpu_config
+
+  def replace(self, **kwargs):
+    if 'tpu_config' not in kwargs:
+      return super(RunConfig, self).replace(**kwargs)
+
+    tpu_config = kwargs.pop('tpu_config')
+    new_instance = super(RunConfig, self).replace(**kwargs)
+    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access
+    return new_instance
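
A hedged usage sketch of the new override (the TPUConfig arguments below are
illustrative only; before this change, passing tpu_config to replace() was
rejected because the base RunConfig does not know that property):

    from tensorflow.contrib.tpu.python.tpu import tpu_config

    config = tpu_config.RunConfig(
        tpu_config=tpu_config.TPUConfig(iterations_per_loop=100))
    new_config = config.replace(
        tpu_config=tpu_config.TPUConfig(iterations_per_loop=200))
    assert new_config.tpu_config.iterations_per_loop == 200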

From 8507c4a122c83fdad7b1855d5d43d51b6bd8009d Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne <skyewm@google.com>
Date: Wed, 8 Nov 2017 13:00:12 -0800
Subject: [PATCH 042/115] Allow Operation.get_attr() to work with all attr
 types with the C API enabled

This is achieved by accessing the AttrValue directly and using the
existing Python code instead of dispatching to the specific C API attr
getter for every type. I started going down the dispatch path, but it
turns out to be a lot of code (spread across Python, C, and SWIG), and
this is likely good enough from a performance standpoint. We can
optimize in the future if necessary.

In addition, this change switches the colocation group logic to use
_set_attr() and get_attr(), and makes _set_attr() work with the C API
disabled. This allows the colocation tests to pass with the C API both
enabled and disabled. Without these additional changes, the "_class"
attribute would be set on the C NodeDef, and the colocation logic would
then try to retrieve it from the Python NodeDef.

PiperOrigin-RevId: 175050473
---
 tensorflow/c/c_api.cc                   |  4 +-
 tensorflow/c/c_api_test.cc              |  4 +-
 tensorflow/python/client/tf_session.i   | 10 ---
 tensorflow/python/framework/ops.py      | 71 +++++++++++----------
 tensorflow/python/framework/ops_test.py | 85 +++++++++++++------------
 tensorflow/python/framework/test_ops.cc | 23 +++++++
 6 files changed, 109 insertions(+), 88 deletions(-)

diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 6dd1b999102..dd638de3c69 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -890,8 +890,8 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper,
                                           TF_Status* status) {
   const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name);
   if (attr == nullptr) {
-    status->status =
-        InvalidArgument("Operation has no attr named '", attr_name, "'.");
+    status->status = InvalidArgument("Operation '", oper->node.name(),
+                                     "' has no attr named '", attr_name, "'.");
   }
   return attr;
 }
diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc
index 05881e619ba..e0057eb51cd 100644
--- a/tensorflow/c/c_api_test.cc
+++ b/tensorflow/c/c_api_test.cc
@@ -383,7 +383,7 @@ TEST(CAPI, Graph) {
   EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
 
   ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
-  EXPECT_EQ(string("Operation has no attr named 'missing'."),
+  EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."),
             string(TF_Message(s)));
 
   // Make a constant oper with the scalar "3".
@@ -1054,7 +1054,7 @@ class CApiColocationTest : public ::testing::Test {
         TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_);
     if (expected.empty()) {
       ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_);
-      EXPECT_EQ(std::string("Operation has no attr named '_class'."),
+      EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."),
                 std::string(TF_Message(s_)));
       return;
     }
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index f45bc13602e..40731aba7d4 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -344,16 +344,6 @@ bool PyTensorListToVector(PyObject* py_tensor_list,
 %rename("_TF_SetConfig") TF_SetConfig;
 %rename("_TF_NewSessionOptions") TF_NewSessionOptions;
 
-// Create temporary int64_t to pass to TF_OperationGetAttrInt
-%typemap(in, numinputs=0) int64_t* value (int64_t val) {
-  $1 = &val;
-}
-
-// Convert value to Python int
-%typemap(argout) int64_t* value {
-  $result = PyInt_FromLong(*$1);
-}
-
 %include "tensorflow/c/c_api.h"
 %include "tensorflow/c/python_api.h"
 
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index b256af2182a..ad2e2993c1b 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1641,13 +1641,15 @@ class Operation(object):
     default_colocation_group = [
         compat.as_bytes("loc:@%s" % self._node_def.name)
     ]
-    if "_class" not in self._node_def.attr:
+    try:
+      class_attr = self.get_attr("_class")
+    except ValueError:
       # This op has no explicit colocation group, so it is itself its
       # own root of a colocation group.
       return default_colocation_group
 
     attr_groups = [
-        class_name for class_name in self.get_attr("_class")
+        class_name for class_name in class_attr
         if class_name.startswith(b"loc:@")
     ]
 
@@ -2062,16 +2064,19 @@ class Operation(object):
 
   def _set_attr(self, attr_name, attr_value):
     """Private method used to set an attribute in the node_def."""
-    if not _USE_C_API:
-      assert "_set_attr not supported with _USE_C_API == False"
-      return
-    buf = c_api.TF_NewBufferFromString(
-        compat.as_bytes(attr_value.SerializeToString()))
-    try:
-      with errors.raise_exception_on_not_ok_status() as status:
-        c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf, status)  # pylint: disable=protected-access
-    finally:
-      c_api.TF_DeleteBuffer(buf)
+    if _USE_C_API:
+      buf = c_api.TF_NewBufferFromString(
+          compat.as_bytes(attr_value.SerializeToString()))
+      try:
+        with errors.raise_exception_on_not_ok_status() as status:
+          # pylint: disable=protected-access
+          c_api.SetAttr(self._graph._c_graph, self._c_op, attr_name, buf,
+                        status)
+          # pylint: enable=protected-access
+      finally:
+        c_api.TF_DeleteBuffer(buf)
+    else:
+      self._node_def.attr[attr_name].CopyFrom(attr_value)
 
   def get_attr(self, name):
     """Returns the value of the attr of this op with the given `name`.
@@ -2085,25 +2090,24 @@ class Operation(object):
     Raises:
       ValueError: If this op does not have an attr with the given `name`.
     """
-    if _USE_C_API:
-      try:
-        # TODO(b/65162920): remove this try/except block when all attrs are
-        # implemented to use the _set_attr method instead of node_def.attr.
-        with errors.raise_exception_on_not_ok_status() as status:
-          metadata = c_api.TF_OperationGetAttrMetadata(self._c_op, name, status)
-        with errors.raise_exception_on_not_ok_status() as status:
-          if metadata.type == c_api.TF_ATTR_INT and metadata.is_list == 0:
-            return c_api.TF_OperationGetAttrInt(self._c_op, name, status)
-      except errors.InvalidArgumentError:
-        # Colocation ops are failing to find attrs begininning with "_*". They
-        # should fall through to the not-CAPI logic until the attribute is set
-        # via the C-API always.
-        pass
-
     fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
-    if name not in self._node_def.attr:
-      raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
-    x = self._node_def.attr[name]
+    if self._c_op:
+      try:
+        with c_api_util.tf_buffer() as buf:
+          with errors.raise_exception_on_not_ok_status() as status:
+            c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf, status)
+          data = c_api.TF_GetBuffer(buf)
+      except errors.InvalidArgumentError as e:
+        # Convert to ValueError for backwards compatibility.
+        raise ValueError(str(e))
+      x = attr_value_pb2.AttrValue()
+      x.ParseFromString(data)
+    else:
+      if name not in self._node_def.attr:
+        raise ValueError(
+            "No attr named '" + name + "' in " + str(self._node_def))
+      x = self._node_def.attr[name]
+
     # Treat an empty oneof value as an empty list.
     if not x.WhichOneof("value"):
       return []
@@ -3103,9 +3107,10 @@ class Graph(object):
             ret._set_device(colocation_op.device)  # pylint: disable=protected-access
 
       all_colocation_groups = sorted(set(all_colocation_groups))
-      ret.node_def.attr["_class"].CopyFrom(
-          attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue(
-              s=all_colocation_groups)))
+      # pylint: disable=protected-access
+      ret._set_attr("_class", attr_value_pb2.AttrValue(
+          list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups)))
+      # pylint: enable=protected-access
 
     # Sets "container" attribute if
     # (1) self._container is not None
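
For reference, a rough sketch of the oneof dispatch that get_attr performs
once it holds the parsed AttrValue proto (reconstructed from the code above
for illustration; details such as the dtype conversion are approximate, not
the exact implementation):

    from tensorflow.core.framework import attr_value_pb2
    from tensorflow.python.framework import dtypes

    def decode_attr_value(x):
      """Rough sketch of get_attr's decoding of an AttrValue proto."""
      fields = ["s", "i", "f", "b", "type", "shape", "tensor", "func"]
      if not x.WhichOneof("value"):
        return []                     # empty oneof -> empty list
      if x.HasField("list"):
        for f in fields:              # list-valued attr: find the set field
          if getattr(x.list, f):
            if f == "type":
              return [dtypes.as_dtype(t) for t in x.list.type]
            return list(getattr(x.list, f))
        return []
      if x.HasField("type"):          # scalar attrs; DT_* enum -> DType
        return dtypes.as_dtype(x.type)
      for f in fields:
        if x.HasField(f):
          return getattr(x, f)

    print(decode_attr_value(attr_value_pb2.AttrValue(i=32)))  # 32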
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 3087d6060b9..4e931e00c59 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -31,9 +31,11 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
@@ -357,54 +359,55 @@ class OperationTest(test_util.TensorFlowTestCase):
     self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
 
   def testGetAttr(self):
-    # TODO(b/65162920): implement all tests for get_attr with C API
+    op = test_ops.default_attrs()
+    self.assertEqual(op.get_attr("string_val"), b"abc")
+    self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
+    self.assertEqual(op.get_attr("int_val"), 123)
+    self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
+    self.assertEqual(op.get_attr("float_val"), 10.0)
+    self.assertEqual(op.get_attr("float_list_val"), [10.0])
+    self.assertEqual(op.get_attr("bool_val"), True)
+    self.assertEqual(op.get_attr("bool_list_val"), [True, False])
+    self.assertEqual(op.get_attr("shape_val"),
+                     tensor_shape.as_shape([2, 1]).as_proto())
+    self.assertEqual(op.get_attr("shape_list_val"),
+                     [tensor_shape.as_shape([]).as_proto(),
+                      tensor_shape.as_shape([1]).as_proto()])
+    self.assertEqual(op.get_attr("tensor_val"),
+                     tensor_util.make_tensor_proto(1, dtypes.int32))
+    self.assertEqual(op.get_attr("tensor_list_val"),
+                     [tensor_util.make_tensor_proto(1, dtypes.int32)])
+
+    type_val = op.get_attr("type_val")
+    # First check that type_val is a DType, because assertEqual would pass
+    # regardless, since DType overrides __eq__.
+    self.assertIsInstance(type_val, dtypes.DType)
+    self.assertEqual(type_val, dtypes.int32)
+
+    type_list_val = op.get_attr("type_list_val")
+    self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
+    self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
+
+    @function.Defun(dtypes.float32, func_name="MyFunc")
+    def func(x):
+      return x
+
+    op = test_ops.func_attr(func)
+    self.assertEqual(op.get_attr("f"),
+                     attr_value_pb2.NameAttrList(name="MyFunc"))
+
+    # Try fetching missing attr
     if ops._USE_C_API:
-      op = test_ops.int_attr().op
-      self.assertEqual(op.get_attr("foo"), 1)
-
-      op_str = test_ops.string_list_attr(a=["z"], b="y")
-      self.assertEqual(op_str.get_attr("a"), [b"z"])
-      self.assertEqual(op_str.get_attr("b"), b"y")
-
+      error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
     else:
-      list_value = attr_value_pb2.AttrValue.ListValue()
+      error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\""
 
-      list_value.type.append(types_pb2.DT_STRING)
-      list_value.type.append(types_pb2.DT_DOUBLE)
-      op = ops.Operation(
-          ops._NodeDef(
-              "None",
-              "op1",
-              attrs={
-                  "value":
-                      attr_value_pb2.AttrValue(i=32),
-                  "dtype":
-                      attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
-                  "list":
-                      attr_value_pb2.AttrValue(list=list_value),
-                  "func":
-                      attr_value_pb2.AttrValue(
-                          func=attr_value_pb2.NameAttrList())
-              }), ops.Graph(), [], [dtypes.int32])
-      self.assertEqual(32, op.get_attr("value"))
-      self.assertEqual("", op.get_attr("func").name)
-
-      d = op.get_attr("dtype")
-      # First check that d is a DType, because the assertEquals will
-      # work no matter what since DType overrides __eq__
-      self.assertIsInstance(d, dtypes.DType)
-      self.assertEqual(dtypes.int32, d)
-
-      l = op.get_attr("list")
-      for x in l:
-        self.assertIsInstance(x, dtypes.DType)
-      self.assertEqual([dtypes.string, dtypes.double], l)
+    with self.assertRaisesRegexp(ValueError, error_msg):
+      op.get_attr("FakeAttr")
 
   # TODO(b/65162920): remove this test when users who are directly mutating the
   # node_def have been updated to proper usage.
   def testSetAttr(self):
-    if not ops._USE_C_API:
-      return
     op = test_ops.int_attr().op
     op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
     # TODO(skyewm): add node_def check
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index a8b7fc543f0..35e0167b260 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -341,4 +341,27 @@ REGISTER_OP("StringListAttr")
     .Attr("b: string")
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("DefaultAttrs")
+    .Attr("string_val: string = 'abc'")
+    .Attr("string_list_val: list(string) = ['abc', '']")
+    .Attr("int_val: int = 123")
+    .Attr("int_list_val: list(int) = [1, 2, 3]")
+    .Attr("float_val: float = 10.0")
+    .Attr("float_list_val: list(float) = [10.0]")
+    .Attr("bool_val: bool = true")
+    .Attr("bool_list_val: list(bool) = [true, false]")
+    .Attr("type_val: type = DT_INT32")
+    .Attr("type_list_val: list(type) = [DT_INT32, DT_FLOAT]")
+    .Attr("shape_val: shape = { dim { size: 2 } dim { size: 1 } }")
+    .Attr("shape_list_val: list(shape) = [{}, { dim { size: 1} }]")
+    .Attr("tensor_val: tensor = { dtype: DT_INT32 tensor_shape: {} int_val: 1}")
+    .Attr(
+        "tensor_list_val: list(tensor) = "
+        "[{ dtype: DT_INT32 tensor_shape: {} int_val: 1}]")
+    .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("FuncAttr")
+    .Attr("f: func")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 }  // end namespace tensorflow

From 904abae95b2c88f4379e8133e5b8dfd2e2526ed0 Mon Sep 17 00:00:00 2001
From: Igor Ganichev <iga@google.com>
Date: Wed, 8 Nov 2017 13:19:27 -0800
Subject: [PATCH 043/115] Give a better error message when placeholders are
 used with eager

PiperOrigin-RevId: 175053592
---
 tensorflow/python/ops/array_ops.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 6b4919b16f0..61bd41e7de6 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1653,6 +1653,8 @@ def placeholder(dtype, shape=None, name=None):
     print(sess.run(y, feed_dict={x: rand_array}))  # Will succeed.
   ```
 
+  @compatibility(eager)
+  Placeholders are not compatible with eager execution.
+  @end_compatibility
+
   Args:
     dtype: The type of elements in the tensor to be fed.
     shape: The shape of the tensor to be fed (optional). If the shape is not
@@ -1662,7 +1664,14 @@ def placeholder(dtype, shape=None, name=None):
   Returns:
     A `Tensor` that may be used as a handle for feeding a value, but not
     evaluated directly.
+
+  Raises:
+    RuntimeError: if eager execution is enabled
   """
+  if context.in_eager_mode():
+    raise RuntimeError("tf.placeholder() is not compatible with "
+                       "eager execution.")
+
   return gen_array_ops._placeholder(dtype=dtype, shape=shape, name=name)
 
 
@@ -1706,6 +1715,8 @@ def sparse_placeholder(dtype, shape=None, name=None):
     print(sess.run(y, feed_dict={x: sp_value}))  # Will succeed.
   ```
 
+  @compatibility(eager)
+  Placeholders are not compatible with eager execution.
+  @end_compatibility
+
   Args:
     dtype: The type of `values` elements in the tensor to be fed.
     shape: The shape of the tensor to be fed (optional). If the shape is not
@@ -1715,7 +1726,14 @@ def sparse_placeholder(dtype, shape=None, name=None):
   Returns:
     A `SparseTensor` that may be used as a handle for feeding a value, but not
     evaluated directly.
+
+  Raises:
+    RuntimeError: if eager execution is enabled
   """
+  if context.in_eager_mode():
+    raise RuntimeError("tf.placeholder() is not compatible with "
+                       "eager execution.")
+
   shape_name = (name + "/shape") if name is not None else None
   shape, rank = _normalize_sparse_shape(shape, shape_name)
   if shape is None:
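
A hedged sketch of the resulting behavior (eager execution in this era is
enabled through contrib; the exact entry points may differ):

    import tensorflow as tf
    import tensorflow.contrib.eager as tfe

    tfe.enable_eager_execution()
    try:
      tf.placeholder(tf.float32)
    except RuntimeError as e:
      print(e)  # tf.placeholder() is not compatible with eager execution.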

From fa318123adcf457f3ed92e617c6fa34a695d2279 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <bsteiner@google.com>
Date: Wed, 8 Nov 2017 13:34:58 -0800
Subject: [PATCH 044/115] Optimize gradient subgraphs by taking advantage of
 symbolic shapes whenever possible.

PiperOrigin-RevId: 175055770
---
 tensorflow/core/grappler/optimizers/BUILD     |   1 +
 .../grappler/optimizers/constant_folding.cc   | 100 ++++++++++++++++--
 .../grappler/optimizers/constant_folding.h    |   6 +-
 .../optimizers/constant_folding_test.cc       |  53 ++++++++++
 .../grappler/optimizers/meta_optimizer.cc     |   4 +-
 5 files changed, 153 insertions(+), 11 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 669d02815c7..54004a5e07f 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -112,6 +112,7 @@ tf_cc_test(
     deps = [
         ":constant_folding",
         "//tensorflow/cc:cc_ops",
+        "//tensorflow/cc:cc_ops_internal",
         "//tensorflow/core:all_kernels",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index cb023141833..a364ca487ea 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -95,11 +95,15 @@ class DeviceSimple : public DeviceBase {
 };
 
 }  // namespace
-ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
-    : cpu_device_(cpu_device) {
+ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
+                                 DeviceBase* cpu_device)
+    : opt_level_(opt_level), cpu_device_(cpu_device) {
   resource_mgr_.reset(new ResourceMgr());
 }
 
+ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
+    : ConstantFolding(RewriterConfig::ON, cpu_device) {}
+
 // static
 string ConstantFolding::AddControlDependency(const string& input_name,
                                              GraphDef* graph,
@@ -281,6 +285,84 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
   return Status::OK();
 }
 
+bool ShapesEqual(const TensorShapeProto& shape1,
+                 const TensorShapeProto& shape2) {
+  if (shape1.unknown_rank() || shape2.unknown_rank()) {
+    return false;
+  }
+  if (shape1.dim_size() != shape2.dim_size()) {
+    return false;
+  }
+  for (int i = 0; i < shape1.dim_size(); ++i) {
+    if (shape1.dim(i).size() != shape2.dim(i).size()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+Status ConstantFolding::MaterializeConstants(
+    const GrapplerItem& item, const GraphProperties& properties) {
+  const int node_count = graph_.node_size();
+  for (int i = 0; i < node_count; ++i) {
+    NodeDef& node = *graph_.mutable_node(i);
+    const string& op = node.op();
+    if (op != "BroadcastGradientArgs") {
+      continue;
+    }
+    const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
+    const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
+    if (shape_node1 == nullptr || shape_node1->op() != "Shape" ||
+        shape_node2 == nullptr || shape_node2->op() != "Shape") {
+      continue;
+    }
+    const std::vector<OpInfo::TensorProperties>& prop1 =
+        properties.GetInputProperties(shape_node1->name());
+    const std::vector<OpInfo::TensorProperties>& prop2 =
+        properties.GetInputProperties(shape_node2->name());
+    if (prop1.size() != 1 || prop2.size() != 1) {
+      continue;
+    }
+    const TensorShapeProto& shape1 = prop1[0].shape();
+    const TensorShapeProto& shape2 = prop2[0].shape();
+    if (ShapesEqual(shape1, shape2)) {
+      DataType type = node.attr().at("T").type();
+      Tensor empty(type, TensorShape());
+      NodeDef* out[2];
+      for (int i = 0; i < 2; ++i) {
+        string const_name = AddPrefixToNodeName(
+            strings::StrCat(node.name(), "-", i), kConstantFoldingConst);
+        out[i] = node_map_->GetNode(const_name);
+        if (!out[i]) {
+          out[i] = graph_.add_node();
+          *out[i] = CreateNodeDef(const_name, TensorValue(&empty));
+          out[i]->set_device(node.device());
+          node_map_->AddNode(const_name, out[i]);
+          string ctrl_dep =
+              AddControlDependency(node.name(), &graph_, node_map_.get());
+          *out[i]->add_input() = ctrl_dep;
+          node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+        }
+      }
+
+      auto outputs = node_map_->GetOutputs(node.name());
+      for (const auto& output : outputs) {
+        for (int k = 0; k < output->input_size(); ++k) {
+          int port;
+          string node_name = ParseNodeName(output->input(k), &port);
+          if (node_name == node.name() && port >= 0 && port < 2) {
+            *output->mutable_input(k) = out[port]->name();
+            node_map_->UpdateInput(output->name(), node_name,
+                                   out[port]->name());
+          }
+        }
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
 bool ConstantFolding::IsFoldable(const NodeDef& node) const {
   // Folding not applicable to ops with no inputs.
   if (node.input().empty()) {
@@ -921,23 +1003,25 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   }
 
   GraphProperties properties(item);
+  Status s = properties.InferStatically();
   bool has_feed = !item.feed.empty();
-  if (!has_feed) {
+  if (!has_feed && s.ok()) {
     // Only use static shape information when there is no feed in the
     // graph. That's because it's possible to feed a placeholder with a tensor
     // of any shape, which could make the static information inconsistent with
     // the shapes actually fed.
-    Status s = properties.InferStatically();
-    if (!s.ok()) {
-      VLOG(1) << "Failed to infer graph shapes: " << s;
-    } else {
+    if (s.ok()) {
       TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
     }
   }
+  if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
+    TF_RETURN_IF_ERROR(MaterializeConstants(item, properties));
+  }
 
   TF_RETURN_IF_ERROR(FoldGraph(output));
 
-  if (!has_feed) {
+  if (!has_feed && s.ok()) {
     TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
   }
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 30d778789a4..dd988f336cb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -37,6 +38,7 @@ class ConstantFolding : public GraphOptimizer {
                                      NodeMap* node_map);
 
   ConstantFolding(DeviceBase* cpu_device);
+  ConstantFolding(RewriterConfig::Toggle opt_level, DeviceBase* cpu_device);
 
   ~ConstantFolding() override {}
 
@@ -51,7 +53,8 @@ class ConstantFolding : public GraphOptimizer {
  private:
   Status MaterializeShapes(const GrapplerItem& item,
                            const GraphProperties& properties);
-
+  Status MaterializeConstants(const GrapplerItem& item,
+                              const GraphProperties& properties);
   bool IsFoldable(const NodeDef& node) const;
 
   Status EvaluateNode(const NodeDef& node,
@@ -74,6 +77,7 @@ class ConstantFolding : public GraphOptimizer {
                              GraphDef* output);
 
+  RewriterConfig::Toggle opt_level_;
   // Points to an externally provided device or to owned_device_;
   DeviceBase* cpu_device_;
   std::unique_ptr<DeviceBase> owned_device_;
 
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index a1dee6d2fb8..17f9854b599 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/cc/ops/array_ops_internal.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -838,6 +839,58 @@ TEST_F(ConstantFoldingTest, Packing) {
   // size needed to naively encode 1000 floats folded twice).
   EXPECT_GT(8000, output.ByteSizeLong());
 }
+
+TEST_F(ConstantFoldingTest, ConstantMaterialization) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a =
+      ops::Placeholder(s.WithOpName("a"), DT_FLOAT,
+                       ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+  Output b = ops::Square(s.WithOpName("b"), a);
+  Output c = ops::Mul(s.WithOpName("c"), a, b);
+  Output d = ops::Shape(s.WithOpName("d"), a);
+  Output e = ops::Shape(s.WithOpName("e"), b);
+  auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
+  Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
+  Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */);
+  GraphDef output;
+  Status status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  // Run a second time to make sure the optimization is idempotent.
+  item.graph.Swap(&output);
+  status = fold.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  int found = 0;
+  for (const auto& node : output.node()) {
+    if (node.name() == "o1") {
+      ++found;
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("ConstantFolding/f-0", node.input(0));
+    } else if (node.name() == "o2") {
+      ++found;
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("ConstantFolding/f-1", node.input(0));
+    } else if (node.name() == "ConstantFolding/f-0") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^f", node.input(0));
+    } else if (node.name() == "ConstantFolding/f-1") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^f", node.input(0));
+    }
+  }
+  EXPECT_EQ(4, found);
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow
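
The intuition behind the new test, as a hedged sketch (it relies on an
internal generated op wrapper, shown for illustration only): when both input
shapes are identical, no broadcasting occurs, so both reduction-index outputs
of BroadcastGradientArgs are empty, which is exactly what
MaterializeConstants substitutes without running the op.

    import tensorflow as tf
    from tensorflow.python.ops import gen_array_ops  # internal module

    rx, ry = gen_array_ops._broadcast_gradient_args([2, 3], [2, 3])
    with tf.Session() as sess:
      print(sess.run([rx, ry]))  # two empty int32 vectors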
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index a9875c06d8b..6204a81f805 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -64,8 +64,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
       optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
     }
     if (cfg_.constant_folding() != RewriterConfig::OFF) {
-      optimizers.push_back(
-          std::unique_ptr<GraphOptimizer>(new ConstantFolding(cpu_device_)));
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(
+          new ConstantFolding(cfg_.constant_folding(), cpu_device_)));
     }
     if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
       optimizers.push_back(std::unique_ptr<GraphOptimizer>(

From 8f7aa84efea39b71b45040d89ef01fc15faa519b Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Wed, 8 Nov 2017 13:44:26 -0800
Subject: [PATCH 045/115] Moves imperative_grad to C

This change is neutral-to-positive on all benchmarks and also reduces the
overhead of should_record.

PiperOrigin-RevId: 175057104
---
 tensorflow/c/eager/BUILD                   |   1 +
 tensorflow/c/eager/tape.cc                 | 312 +++++++++++++++++++-
 tensorflow/c/eager/tape.h                  |  58 +++-
 tensorflow/python/eager/BUILD              |   7 +-
 tensorflow/python/eager/backprop.py        |  14 +-
 tensorflow/python/eager/backprop_test.py   |  57 +---
 tensorflow/python/eager/imperative_grad.py | 194 +------------
 tensorflow/python/eager/pywrap_tensor.cc   |   8 +-
 tensorflow/python/eager/pywrap_tensor.h    |  25 ++
 tensorflow/python/eager/pywrap_tfe.h       |  13 +-
 tensorflow/python/eager/pywrap_tfe_src.cc  | 317 +++++++++++++++++----
 tensorflow/python/eager/tape.py            |  12 +-
 tensorflow/python/eager/tape_test.py       |  20 --
 tensorflow/python/pywrap_tfe.i             |   4 +-
 14 files changed, 704 insertions(+), 338 deletions(-)
 create mode 100644 tensorflow/python/eager/pywrap_tensor.h

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index c77896b80b4..74e94be8d68 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -39,6 +39,7 @@ tf_cuda_library(
 tf_cuda_library(
     name = "c_api_internal",
     hdrs = ["c_api_internal.h"],
+    visibility = ["//tensorflow:internal"],
     deps = [
         ":c_api",
         ":runtime",
diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc
index 464612a81eb..459499bb694 100644
--- a/tensorflow/c/eager/tape.cc
+++ b/tensorflow/c/eager/tape.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <unordered_set>
+
 #include "tensorflow/c/eager/tape.h"
 
 namespace tensorflow {
@@ -94,8 +96,314 @@ void GradientTape::DeleteTrace(int64 tensor_id) {
   op_tape_.erase(op_it);
 }
 
-std::pair<TensorTape, OpTape> GradientTape::Export() {
-  return {std::move(tensor_tape_), std::move(op_tape_)};
+// Terminology:
+//
+//  - op: a possibly composite operation, which has an entry in the tape
+//  - target: dy in dy/dx
+//  - source: dx in dy/dx
+//  - tensor: one of the many inputs or outputs of an operation
+//
+// Below here we do the gradient algorithm. It works as follows:
+//
+// First we filter the tape to just the subset of operations we want to
+// differentiate. In the process of doing so we count how many times each Tensor
+// is used as an input to an op (so we know when we're done computing gradients
+// for that Tensor). We also count, for each tape entry, how many of its output
+// Tensors need gradients to be computed (Tensors which are not used do not need
+// any gradients to be computed).
+//
+// Finally, we start a backprop stack with a set of tape entries for which we
+// have all gradients available. This set usually is a subset of the set of
+// targets (not all since targets which have outputs in the tape will not have
+// gradients available initially).
+//
+// Then we repeatedly pop an entry from the stack, run its backprop, and update
+// the gradients of its inputs. Once we have computed all gradients for a single
+// input we can mark this input as done, and this can trigger adding an entry to
+// the stack if all outputs of that entry are now done.
+//
+// When the stack is empty we have gradients for all tensors we're interested
+// in.
+
+struct BackpropInitialState {
+  OpTape op_tape;
+
+  // Map from tensor ID to how many references still exist for this tensor in
+  // the tape.
+  std::unordered_map<int64, int64> tensor_usage_counts;
+
+  // Maps from op ID to how many output tensors of this op still need to have
+  // their gradients computed.
+  std::unordered_map<int64, int64> op_missing_tensor;
+};
+
+BackpropInitialState PrepareBackprop(
+    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
+    OpTape op_tape, const std::unordered_set<int64>& sources_set) {
+  std::vector<int64> tensor_stack;
+  tensor_stack.reserve(target.size());
+  for (auto t : target) {
+    tensor_stack.push_back(t);
+  }
+  BackpropInitialState result;
+  while (!tensor_stack.empty()) {
+    int64 tensor_id = tensor_stack.back();
+    tensor_stack.pop_back();
+    auto op_id_it = tensor_tape.find(tensor_id);
+    if (op_id_it == tensor_tape.end()) {
+      continue;
+    }
+    int64 op_id = op_id_it->second;
+    auto op_it = op_tape.find(op_id);
+    auto result_op_it = result.op_tape.find(op_id);
+    if (op_id == -1 || op_it == op_tape.end() ||
+        result_op_it != result.op_tape.end()) {
+      continue;
+    }
+    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
+    for (auto it : op_it->second.input_tensor_id) {
+      auto count_it = result.tensor_usage_counts.find(it);
+      if (count_it != result.tensor_usage_counts.end()) {
+        count_it->second++;
+      } else {
+        result.tensor_usage_counts[it] = 1;
+        if (sources_set.find(it) == sources_set.end() &&
+            tensor_tape.find(it) != tensor_tape.end()) {
+          tensor_stack.push_back(it);
+        }
+      }
+    }
+    op_tape.erase(op_it);
+  }
+  for (auto& pair : result.tensor_usage_counts) {
+    auto it = tensor_tape.find(pair.first);
+    if (it != tensor_tape.end() && it->second != -1) {
+      result.op_missing_tensor[it->second] += 1;
+    }
+  }
+  // Call destructors for all unneeded gradient functions.
+  for (const auto& op_pair : op_tape) {
+    op_pair.second.backward_function_deleter();
+  }
+  return result;
+}
+
+std::vector<int64> InitialStack(
+    const OpTape& op_tape,
+    const std::unordered_map<int64, int64>& op_missing_tensor) {
+  std::vector<int64> result;
+  for (auto& op_entry : op_tape) {
+    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
+      result.push_back(op_entry.first);
+    }
+  }
+  return result;
+}
+
+Status InitialGradients(
+    const VSpace& vspace, gtl::ArraySlice<void*> target,
+    gtl::ArraySlice<void*> output_gradients,
+    const std::unordered_map<int64, int64>& tensor_usage_counts,
+    std::unordered_map<int64, std::vector<void*>>* result) {
+  for (int i = 0; i < target.size(); ++i) {
+    int64 id = vspace.TensorId(target[i]);
+    if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
+      if (!output_gradients.empty() && output_gradients[i] != nullptr) {
+        // TODO(apassos) figure out how to print debugging information here.
+        return errors::InvalidArgument(
+            "A gradient was provided for a tensor which is used as part of the "
+            "computation.");
+      }
+    } else {
+      if (output_gradients.empty() || output_gradients[i] == nullptr) {
+        (*result)[id].push_back(vspace.OnesLike(target[i]));
+      } else {
+        (*result)[id].push_back(output_gradients[i]);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// If over kMinAggregateCount gradients are accumulated and the total
+// memory consumption is over kMinAggregateBytes, do an early aggregation
+// so as to release the gradient tensor to save memory.
+static const int kMinAggregateCount = 4;
+static const int kMinAggregateBytes = 128 * 1024 * 1024;
+
+Status GradientTape::Gradient(const VSpace& vspace,
+                              gtl::ArraySlice<void*> target,
+                              gtl::ArraySlice<void*> sources,
+                              gtl::ArraySlice<void*> output_gradients,
+                              std::vector<void*>* result) {
+  std::vector<int64> id_sources;
+  id_sources.reserve(sources.size());
+  for (void* s : sources) {
+    id_sources.push_back(vspace.TensorId(s));
+  }
+  std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
+  std::vector<int64> id_targets;
+  id_targets.reserve(target.size());
+  for (void* t : target) {
+    id_targets.push_back(vspace.TensorId(t));
+  }
+  BackpropInitialState state = PrepareBackprop(
+      id_targets, tensor_tape_, std::move(op_tape_), sources_set);
+  std::vector<int64> op_stack =
+      InitialStack(state.op_tape, state.op_missing_tensor);
+  std::unordered_map<int64, std::vector<void*>> gradients;
+  Status s = InitialGradients(vspace, target, output_gradients,
+                              state.tensor_usage_counts, &gradients);
+  auto cleanup = [&state]() {
+    // Release all backprop functions
+    for (const auto& pair : state.op_tape) {
+      pair.second.backward_function_deleter();
+    }
+  };
+  if (!s.ok()) {
+    cleanup();
+    return s;
+  }
+  std::unordered_map<int64, int64> gradients_size;
+  // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
+  // time, for better CPU backprop performance.
+  VLOG(1) << "Initial stack:";
+  if (VLOG_IS_ON(1)) {
+    for (auto t : op_stack) {
+      VLOG(1) << "  " << t;
+    }
+  }
+  std::unordered_map<string, std::unordered_set<int>>
+      functions_accept_none_for_indices({
+          {"SoftmaxCrossEntropyWithLogits", {1}},
+          {"FusedBatchNorm", {1, 2, 3, 4}},
+      });
+  while (!op_stack.empty()) {
+    const int64 op = op_stack.back();
+    VLOG(1) << "Popped " << op;
+    op_stack.pop_back();
+    auto op_it = state.op_tape.find(op);
+    if (op_it == state.op_tape.end()) {
+      // It is possible for ops to end up on the stack if they are unrelated to
+      // the target; we should just skip them.
+      continue;
+    }
+    auto trace = std::move(op_it->second);
+    state.op_tape.erase(op_it);
+    std::vector<void*> out_gradients;
+    out_gradients.reserve(trace.output_tensor_info.size());
+    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
+      const int64 id = trace.output_tensor_info[i].id;
+      auto grad_it = gradients.find(id);
+      if (grad_it == gradients.end()) {
+        auto func_name_it =
+            functions_accept_none_for_indices.find(trace.op_type);
+        if (func_name_it != functions_accept_none_for_indices.end() &&
+            func_name_it->second.find(i) != func_name_it->second.end()) {
+          out_gradients.push_back(nullptr);
+        } else {
+          out_gradients.push_back(
+              vspace.Zeros(trace.output_tensor_info[i].shape,
+                           trace.output_tensor_info[i].dtype));
+        }
+      } else {
+        out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
+        if (sources_set.find(grad_it->first) == sources_set.end()) {
+          gradients.erase(grad_it);
+        }
+      }
+    }
+    std::vector<void*> in_gradients;
+    Status s = vspace.CallBackwardFunction(trace.backward_function,
+                                           out_gradients, &in_gradients);
+    if (!s.ok()) {
+      VLOG(1) << "Gradient function failed.";
+      cleanup();
+      return s;
+    }
+    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
+            << trace.input_tensor_id.size() << " sources";
+    for (int i = 0; i < in_gradients.size(); ++i) {
+      const int64 id = trace.input_tensor_id[i];
+      if (in_gradients[i] != nullptr) {
+        auto& unaggregated_grads = gradients[id];
+        unaggregated_grads.push_back(in_gradients[i]);
+        if (unaggregated_grads.size() > kMinAggregateCount) {
+          auto size_it = gradients_size.find(id);
+          int64 size;
+          if (size_it == gradients_size.end()) {
+            size = vspace.NumElements(unaggregated_grads[0]);
+            gradients_size.emplace(id, size);
+          } else {
+            size = size_it->second;
+          }
+          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
+            void* tensor = vspace.AggregateGradients(unaggregated_grads);
+            unaggregated_grads.clear();
+            unaggregated_grads.push_back(tensor);
+          }
+        }
+      }
+      auto usage_count_it = state.tensor_usage_counts.find(id);
+      if (usage_count_it == state.tensor_usage_counts.end()) {
+        VLOG(1) << "Tensor " << id << " not used";
+        continue;
+      }
+      usage_count_it->second--;
+      if (usage_count_it->second > 0) {
+        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
+        continue;
+      }
+      auto tape_it = tensor_tape_.find(id);
+      if (tape_it == tensor_tape_.end()) {
+        VLOG(1) << "Tensor " << id
+                << " has no associated op. Deleting gradient";
+        auto grad_it = gradients.find(id);
+        if (grad_it != gradients.end()) {
+          for (auto g : grad_it->second) {
+            vspace.DeleteTensor(g);
+          }
+          gradients.erase(grad_it);
+        }
+        continue;
+      }
+      const int64 op_id = tape_it->second;
+      if (op_id == -1) {
+        VLOG(1) << "Tensor " << id << " is source";
+        continue;
+      }
+      auto missing_it = state.op_missing_tensor.find(op_id);
+      if (missing_it != state.op_missing_tensor.end()) {
+        missing_it->second--;
+        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
+                << " output gradients";
+        if (missing_it->second == 0) {
+          op_stack.push_back(op_id);
+        }
+      }
+    }
+  }
+  CHECK(state.op_tape.empty());
+  result->reserve(sources.size());
+  for (auto is : id_sources) {
+    auto grad_it = gradients.find(is);
+    if (grad_it == gradients.end()) {
+      result->push_back(nullptr);
+    } else {
+      if (grad_it->second.size() == 1) {
+        result->push_back(grad_it->second[0]);
+      } else {
+        result->push_back(vspace.AggregateGradients(grad_it->second));
+      }
+      gradients.erase(grad_it);
+    }
+  }
+  VLOG(1) << "Final gradients size: " << gradients.size();
+  for (auto grad_pair : gradients) {
+    for (const auto& g : grad_pair.second) {
+      vspace.DeleteTensor(g);
+    }
+  }
+  return Status::OK();
 }
 
 }  // namespace eager
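
The worklist algorithm described in the comment block above, as a
self-contained Python sketch (toy scalar "tensors" keyed by name; the names
and structures are illustrative, not TensorFlow's actual API):

    import collections

    OpEntry = collections.namedtuple("OpEntry",
                                     ["inputs", "outputs", "backward"])

    def gradient(tensor_tape, op_tape, targets, sources):
      """tensor_tape: tensor -> producing op; op_tape: op -> OpEntry."""
      source_set = set(sources)
      # 1. Filter the tape to ops reachable from the targets, counting how
      #    often each tensor is consumed by a relevant op.
      usage, relevant, pending = collections.Counter(), {}, list(targets)
      while pending:
        op = tensor_tape.get(pending.pop())
        if op is None or op in relevant:
          continue
        relevant[op] = op_tape[op]
        for t in relevant[op].inputs:
          usage[t] += 1
          if usage[t] == 1 and t in tensor_tape and t not in source_set:
            pending.append(t)
      # 2. Per relevant op, count how many output gradients are missing.
      missing = collections.Counter()
      for t in usage:
        if t in tensor_tape:
          missing[tensor_tape[t]] += 1
      # 3. Seed the stack with ops whose output gradients are all known,
      #    then pop, run backward functions, and propagate readiness.
      stack = [op for op in relevant if missing[op] == 0]
      grads = {t: [1.0] for t in targets}
      while stack:
        entry = relevant.pop(stack.pop())
        out_g = [sum(grads.get(o, [0.0])) for o in entry.outputs]
        for t, g in zip(entry.inputs, entry.backward(out_g)):
          grads.setdefault(t, []).append(g)
          usage[t] -= 1
          if usage[t] == 0 and tensor_tape.get(t) in relevant:
            missing[tensor_tape[t]] -= 1
            if missing[tensor_tape[t]] == 0:
              stack.append(tensor_tape[t])
      return [sum(grads.get(s, [0.0])) for s in sources]

    # Example: y = x * x at x = 3; dy/dx = 2 * x = 6.
    mul = OpEntry(inputs=["x", "x"], outputs=["y"],
                  backward=lambda g: [3.0 * g[0], 3.0 * g[0]])
    print(gradient({"y": "mul"}, {"mul": mul}, ["y"], ["x"]))  # [6.0]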
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index df51f300eb6..2bb62a7ab37 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -57,11 +57,57 @@ using TensorTape = std::unordered_map<int64, int64>;
 // Map from operation-id to tape entry.
 using OpTape = std::unordered_map<int64, OpTapeEntry>;
 
+// Operations the tape needs to perform on tensors to do backpropagation. Named
+// "vspace" because a subset of these are related to a vector space, such as
+// adding gradients, getting zeroes, etc. Currently cannot be implemented
+// without using tensorflow python code, hence left unspecified here.
+//
+// We currently use void* for tensors, backward functions, and gradients (which
+// can be but are not required to be tensors). TODO(apassos) replace this first
+// with templates to allow for pyobject specialization in the client followed by
+// a TFE_TensorHandle specialization, which is blocked by quite a few things
+// still.
+class VSpace {
+ public:
+  virtual ~VSpace() {}
+
+  // Returns the number of elements in the tensor.
+  virtual int64 NumElements(void* tensor) const = 0;
+
+  // Consumes references to the tensors in the gradient_tensors list and returns
+  // a tensor with the result.
+  virtual void* AggregateGradients(
+      gtl::ArraySlice<void*> gradient_tensors) const = 0;
+
+  // Returns a tensor of the right shape and dtype filled with zeros.
+  virtual void* Zeros(TensorShape shape, DataType dtype) const = 0;
+
+  // Returns a Tensor which is filled with ones and like the input.
+  virtual void* OnesLike(void*) const = 0;
+
+  // Returns an integer which uniquely identifies this tensor within this
+  // program.
+  virtual int64 TensorId(void* tensor) const = 0;
+
+  // Calls the passed-in backward function.
+  virtual Status CallBackwardFunction(void* backward_function,
+                                      gtl::ArraySlice<void*> output_gradients,
+                                      std::vector<void*>* result) const = 0;
+
+  // Deletes the input tensor.
+  virtual void DeleteTensor(void* tensor) const = 0;
+};
+
 // Traces the execution of operations, doing eager garbage collection, and
 // exporting a full trace so other code can do backpropagation. Not thread-safe.
 class GradientTape {
  public:
   GradientTape() {}
+  ~GradientTape() {
+    for (const auto& pair : op_tape_) {
+      pair.second.backward_function_deleter();
+    }
+  }
 
   bool ShouldRecord(gtl::ArraySlice<int64> tensor_ids);
 
@@ -75,10 +121,14 @@ class GradientTape {
 
   void DeleteTrace(int64 tensor_id);
 
-  // Note: it is only valid to call Export once per tape, and after calling
-  // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch,
-  // Record, and Delete have undefined behavior).
-  std::pair<TensorTape, OpTape> Export();
+  // Consumes the internal state of the tape (so cannot be called more than
+  // once) and produces the gradient of the target tensors with respect to the
+  // source tensors. The output gradients are used if not empty and not
+  // null. The result is populated with one tensor per target element.
+  Status Gradient(const VSpace& vspace, gtl::ArraySlice<void*> target,
+                  gtl::ArraySlice<void*> sources,
+                  gtl::ArraySlice<void*> output_gradients,
+                  std::vector<void*>* result);
 
  private:
   TensorTape tensor_tape_;
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index bcd1e1d0dca..c36647b21c4 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -14,11 +14,16 @@ cc_library(
         "pywrap_tensor.cc",
         "pywrap_tfe_src.cc",
     ],
-    hdrs = ["pywrap_tfe.h"],
+    hdrs = [
+        "pywrap_tensor.h",
+        "pywrap_tfe.h",
+    ],
     visibility = ["//tensorflow:internal"],
     deps = [
         "//tensorflow/c:c_api",
+        "//tensorflow/c:c_api_internal",
         "//tensorflow/c/eager:c_api",
+        "//tensorflow/c/eager:c_api_internal",
         "//tensorflow/c/eager:tape",
         "//tensorflow/core:lib",
         "//tensorflow/python:ndarray_tensor",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 86b3776b8c5..111d7cef56a 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -727,11 +727,23 @@ def _num_elements(grad):
   raise ValueError("`grad` not a Tensor or IndexedSlices.")
 
 
+_last_shape_dtype = [None, None]
+_last_zero = [None]
+
+
+def _zeros(shape, dtype):
+  """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
+  if [shape, dtype] != _last_shape_dtype:
+    _last_shape_dtype[:] = [shape, dtype]
+    _last_zero[0] = array_ops.zeros(shape, dtype)
+  return _last_zero[0]
+
+
 _default_vspace = imperative_grad.VSpace(
     num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
     tensor_id=ops.tensor_id,
-    zeros=array_ops.zeros,
+    zeros=_zeros,
     ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))
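
A hedged illustration of the caching in _zeros above (safe because eager
tensors are immutable): repeated calls with the same shape and dtype return
the same cached tensor object instead of allocating a new one.

    z1 = _zeros([2, 2], dtypes.float32)
    z2 = _zeros([2, 2], dtypes.float32)
    assert z1 is z2  # second call hits the single-entry cache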
 
 
diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py
index ed54b8e12e7..ec9a185b736 100644
--- a/tensorflow/python/eager/backprop_test.py
+++ b/tensorflow/python/eager/backprop_test.py
@@ -24,11 +24,11 @@ from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import imperative_grad
 from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -41,7 +41,6 @@ from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.training import training
-from tensorflow.python.util import compat
 
 
 class BackpropTest(test.TestCase):
@@ -103,6 +102,18 @@ class BackpropTest(test.TestCase):
     grad_fn = backprop.gradients_function(f)
     self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
 
+  def testErrors(self):
+
+    @custom_gradient.custom_gradient
+    def f(x):
+      def grad(_):
+        raise RuntimeError('x')
+      return x, grad
+
+    # TODO(apassos) raise the right error here
+    with self.assertRaises(errors_impl.InternalError):
+      backprop.gradients_function(f)(constant_op.constant(1.0))
+
   def testImplicitGradOverEmbeddingLookup(self):
     batch_size = 8
     embedding_size = 512
@@ -483,48 +494,6 @@ class BackpropTest(test.TestCase):
         initial_value=1., name='testSameObjectForMultipleArguments.Variable')
     self.assertAllEqual([1., 1.], np_g(v, v))
 
-  def testEarlyGradAggregation(self):
-    # Needs to be a list so mutations by the callback affect this function.
-    add_n = []
-    def callback(op_type, unused_1, unused_2, unused_3, unused_4):
-      if compat.as_bytes(op_type) == compat.as_bytes('AddN'):
-        add_n.append(1)
-    context.context().add_post_execution_callback(callback)
-
-    v = resource_variable_ops.ResourceVariable(constant_op.constant(2.0),
-                                               name='v')
-    def fn():
-      outputs = []
-      for _ in range(20):
-        outputs.append(v * constant_op.constant(2.0))
-      return math_ops.add_n(outputs)
-
-    # By default the aggregation count is 2.
-    _ = backprop.implicit_grad(fn)()[0][1]
-    self.assertEqual(len(add_n), 2)
-    del add_n[:]
-
-    # Reduce the aggregation limit, cause the backprop to do some
-    # early aggregation.
-    # pylint: disable=protected-access
-    old_cnt = imperative_grad._MIN_AGGREGATE_COUNT
-    old_bytes = imperative_grad._MIN_AGGREGATE_BYTES
-    imperative_grad._MIN_AGGREGATE_COUNT = 10
-    imperative_grad._MIN_AGGREGATE_BYTES = 1
-    _ = backprop.implicit_grad(fn)()
-    self.assertEqual(len(add_n), 6)
-    del add_n[:]
-
-    # Aggregation is also limited by the memory.
-    imperative_grad._MIN_AGGREGATE_BYTES = 10000
-    _ = backprop.implicit_grad(fn)()
-    self.assertEqual(len(add_n), 2)
-
-    imperative_grad._MIN_AGGREGATE_COUNT = old_cnt
-    imperative_grad._MIN_AGGREGATE_BYTES = old_bytes
-    # pylint: enable=protected-access
-    context.context().clear_post_execution_callbacks()
-
   def testImplicitGradientsCustomGradientAndCachedVariableValue(self):
 
     @custom_gradient.custom_gradient
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index c87719f84ab..8932b7157b2 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -20,102 +20,8 @@ from __future__ import print_function
 
 import collections
 
-from tensorflow.python.eager import tape as tape_module
-
-
-# Terminology:
-#
-#  - op: a possibly composite operation, which has an entry in the tape
-#  - target: dy in dx/dy
-#  - source: dx in dx/dy
-#  - tensor: one of the many inputs or outputs of an operation
-#
-# Below here we do the gradient algorithm. It works as follows:
-#
-# First we filter the tape to just the subset of operations we want to
-# differentiate. In the process of doing so we count how many times each Tensor
-# is used as an input to an op (so we know when we're done computing gradients
-# for that Tensor). We also count, for each tape entry, how many of its output
-# Tensors need gradients to be computed (Tensors which are not used do not need
-# any gradients to be computed).
-#
-# Finally, we start a backprop stack with a set of tape entries for which we
-# have all gradients available. This set usually is a subset of the set of
-# targets (not all since targets which have outputs in the tape will not have
-# gradients available initially).
-#
-# Then we repeatedly pop an entry from the stack, run its backprop, and update
-# the gradients of its inputs. Once we have computed all gradients for a single
-# input we can mark this input as done, and this can trigger adding an entry to
-# the stack if all outputs of that entry are now done.
-#
-# When the stack is empty we have gradients for all tensors we're interested in.
-def _prepare_backprop(vspace, target, tensor_to_op, op_to_entry, id_sources):
-  """Filters the tape to only include relevant entries and counts tensor usages.
-
-  Args:
-    vspace: information about the space we're differentiating in.
-    target: the target to optimize.
-    tensor_to_op: Map from tensor id to key in op_to_entry that produced it.
-    op_to_entry: Map from op id to a tape.TapeEntry object
-    id_sources: the ids of the sources wrt the gradient is being taken.
-
-  Returns:
-    usage counts (how many entries downstream from a tensor use it)
-    op_to_entry_map: entry map (a filtered tape, with only the relevant
-     entries),
-    missing: map from tensor id to how many downstream gradients still need
-     to be computed before this tensor's gradient can be computed.
-  """
-  tensor_stack = [vspace.tensor_id(x) for x in target]
-  tensor_usage_counts = {}
-  o_to_e = {}  # Copy of just the bits we need from op_to_entry
-  while tensor_stack:
-    t = tensor_stack.pop()
-    op = tensor_to_op.get(t, None)
-    # op is None or -1 if the tensor is a source (i.e. was watched directly)
-    if op is None or op == -1 or op in o_to_e:
-      continue
-    op_trace = tape_module.TapeEntry(*op_to_entry[op])
-    o_to_e[op] = op_trace
-    for it in op_trace.input_ids:
-      if it in tensor_usage_counts:
-        tensor_usage_counts[it] += 1
-      else:
-        tensor_usage_counts[it] = 1
-        if it not in id_sources and it in tensor_to_op:
-          tensor_stack.append(it)
-  op_missing_tensor_counts = collections.defaultdict(int)
-  for t in tensor_usage_counts:
-    if t in tensor_to_op and tensor_to_op[t] is not None:
-      op_missing_tensor_counts[tensor_to_op[t]] += 1
-  return tensor_usage_counts, o_to_e, op_missing_tensor_counts
-
-
-def _initialize_backprop_stack(op_to_entry, op_missing_tensor):
-  """Returns the set of tape entries which are available for backprop."""
-  ready_ops = []
-  for op in op_to_entry:
-    if op not in op_missing_tensor:
-      ready_ops.append(op)
-  return ready_ops
-
-
-def _initial_gradients(vspace, target, output_gradients, tensor_usage_counts):
-  """Computes the initial gradients for each Tensor."""
-  # Initialize the backprop stack
-  gradients = collections.defaultdict(list)
-  for i, t in enumerate(target):
-    if vspace.tensor_id(t) in tensor_usage_counts:
-      # Can't provide a gradient of something we're trying to differentiate
-      assert output_gradients is None or output_gradients[i] is None
-    else:
-      if output_gradients is None or output_gradients[i] is None:
-        out_grad = vspace.ones_like(t)
-      else:
-        out_grad = output_gradients[i]
-      gradients[vspace.tensor_id(t)].append(out_grad)
-  return gradients
+from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.framework import errors
 
 
 VSpace = collections.namedtuple(
@@ -123,13 +29,6 @@ VSpace = collections.namedtuple(
     ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
 
 
-# If over MIN_AGGREGATE_COUNT gradients are accumulated and the total
-# memory consumption is over MIN_AGGREGATE_BYTES, do an early aggregation
-# so as to release the gradient tensor to save memory.
-_MIN_AGGREGATE_COUNT = 4
-_MIN_AGGREGATE_BYTES = 128 * 1024 * 1024
-
-
 def imperative_grad(
     vspace,
     tape,
@@ -161,89 +60,6 @@ def imperative_grad(
      or if only non-differentiable functions of the source were used in the
      computation of target.
   """
-  tensor_to_op, op_to_entry = tape.export()
-  # This overwrites the op_to_entry variable, which will release all memory used
-  # to keep traces that are irrelevant to the gradient computation we're doing
-  # here.
-  id_sources = [vspace.tensor_id(t) for t in sources]
-  tensor_usage_counts, op_to_entry, op_missing_tensor = _prepare_backprop(
-      vspace, target, tensor_to_op, op_to_entry, id_sources)
-  ready_ops = _initialize_backprop_stack(op_to_entry, op_missing_tensor)
-  gradients = _initial_gradients(vspace, target, output_gradients,
-                                 tensor_usage_counts)
-  gradients_size = dict()
-  # Now exhaust the backprop stack
-  while ready_ops:
-    op = ready_ops.pop()
-    op_trace = op_to_entry.pop(op)
-    out_gradients = [gradients.pop(t, None) for t in op_trace.output_ids]
-
-    # Cache the last used zero tensor. We reuse it if the next one
-    # we need is of the same shape and dtype. This is very helpful in
-    # large splits and should have negligible overhead in other cases.
-    last_shape_and_dtype = None
-    last_zeros = None
-    for i in range(len(out_gradients)):
-      if out_gradients[i] is None:
-        # TODO(apassos) this should be in the right device
-        none_indices = _grad_fn_accepts_none_for_indices.get(
-            op_trace.op_type, None)
-        if none_indices is None or i not in none_indices:
-          shape_and_dtype = op_trace.output_shape_and_dtype[i]
-          if shape_and_dtype != last_shape_and_dtype:
-            last_shape_and_dtype = shape_and_dtype
-            last_zeros = vspace.zeros(*shape_and_dtype)
-          out_gradients[i] = last_zeros
-      else:
-        out_gradients[i] = vspace.aggregate_fn(out_gradients[i])
-
-    in_gradients = op_trace.backward_function(*(out_gradients))
-    for i, t in enumerate(op_trace.input_ids):
-      if in_gradients[i] is not None:
-        t_grads = gradients.setdefault(t, [])
-        t_grads.append(in_gradients[i])
-        if len(t_grads) >= _MIN_AGGREGATE_COUNT:
-          if t not in gradients_size:
-            gradients_size[t] = vspace.num_elements_fn(t_grads[-1])
-          size = gradients_size[t]
-
-          if len(t_grads) * size * 4 > _MIN_AGGREGATE_BYTES:
-            t_grads[:] = [vspace.aggregate_fn(t_grads)]
-      if tensor_usage_counts.get(t, 0) > 0:
-        tensor_usage_counts[t] -= 1
-        if (t in tensor_to_op
-            and tensor_usage_counts[t] == 0
-            and t not in id_sources):
-          in_op = tensor_to_op[t]
-          if in_op is None or in_op == -1:
-            continue
-          if op_missing_tensor.get(in_op, 0) > 0:
-            op_missing_tensor[in_op] -= 1
-            if op_missing_tensor.get(in_op, 0) == 0:
-              ready_ops.append(in_op)
-  result = []
-  for i, s in enumerate(sources):
-    g = gradients.get(vspace.tensor_id(s), None)
-    if g is None:
-      result.append(None)
-    else:
-      result.append(vspace.aggregate_fn(g))
-  return result
-
-
-# TODO(agarwal): use an automatic mechanism for handling None arguments to
-# gradient functions.
-# Some gradient functions can accept None arguments for gradients. The following
-# maps the operation name to the indices at which the corresponding gradient
-# function can accept None values.
-# e.g. FusedBatchNorm outputs 5 values and hence receives 5 gradient values
-# during backprop. However the gradient function uses only the first of those
-# values and ignores the rest. The entry, "FusedBatchNorm": [1, 2, 3, 4],
-# indicates that only the gradient corresponding to index 0 is used, and the
-# gradient values at indices 1-4 are ignored (and hence can be None). The
-# backprop algorithm can then leverage this by not constructing zeros to
-# pass for those indices.
-_grad_fn_accepts_none_for_indices = {
-    "SoftmaxCrossEntropyWithLogits": [1],
-    "FusedBatchNorm": [1, 2, 3, 4]
-}
+  with errors.raise_exception_on_not_ok_status() as status:
+    return pywrap_tensorflow.TFE_Py_TapeGradient(
+        tape._tape, vspace, target, sources, output_gradients, status)  # pylint: disable=protected-access
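For context, a hedged end-to-end sketch of the path that now bottoms out in TFE_Py_TapeGradient, using the TF 1.4-era contrib eager API (the module path and exact calls are assumptions for this era):

```python
import tensorflow.contrib.eager as tfe  # TF 1.4-era eager API (assumed path)

tfe.enable_eager_execution()

def f(x):
  return x * x

# gradients_function records ops on a tape; differentiation goes through
# imperative_grad.imperative_grad, which now delegates to the C++
# TFE_Py_TapeGradient entry point declared below.
grad_f = tfe.gradients_function(f)
print(grad_f(3.0))  # expect a single gradient of ~6.0
```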
diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc
index ca283862f93..653f3ef84e3 100644
--- a/tensorflow/python/eager/pywrap_tensor.cc
+++ b/tensorflow/python/eager/pywrap_tensor.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/python/lib/core/py_seq_tensor.h"
 #include "tensorflow/python/lib/core/safe_ptr.h"
 
+#include "tensorflow/python/eager/pywrap_tensor.h"
 #include "tensorflow/python/eager/pywrap_tfe.h"
 
 #include "tensorflow/c/c_api.h"
@@ -573,7 +574,7 @@ bool EagerTensor_CheckExact(const PyObject* o) {
   return Py_TYPE(o) == EagerTensorType;
 }
 
-TFE_TensorHandle* EagerTensorHandle(const PyObject* o) {
+TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
   return reinterpret_cast<const EagerTensor*>(o)->handle;
 }
 
@@ -594,6 +595,11 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
   return reinterpret_cast<PyObject*>(t);
 }
 
+tensorflow::int64 EagerTensor_id(const PyObject* tensor) {
+  CHECK(EagerTensor_CheckExact(tensor));
+  return reinterpret_cast<const EagerTensor*>(tensor)->id;
+}
+
 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
   if (!PyType_Check(base_class)) {
     PyErr_SetString(
diff --git a/tensorflow/python/eager/pywrap_tensor.h b/tensorflow/python/eager/pywrap_tensor.h
new file mode 100644
index 00000000000..aa1efdd1b81
--- /dev/null
+++ b/tensorflow/python/eager/pywrap_tensor.h
@@ -0,0 +1,25 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
+#define TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
+
+#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/lib/core/numpy.h"
+
+bool EagerTensor_CheckExact(const PyObject* o);
+tensorflow::int64 EagerTensor_id(const PyObject* tensor);
+
+#endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_TENSOR_H_
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 1d03df29336..6705483f3b3 100644
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -81,7 +81,7 @@ bool EagerTensor_CheckExact(const PyObject* o);
 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle);
 
 // Extracts the handle inside EagerTensor object `o`. Returns nullptr on error.
-TFE_TensorHandle* EagerTensorHandle(const PyObject* o);
+TFE_TensorHandle* EagerTensor_Handle(const PyObject* o);
 
 // Creates the `EagerTensor` class by subclassing `base_class` and returns the
 // newly created type, or nullptr on error.
@@ -103,7 +103,16 @@ void TFE_Py_TapeRecordOperation(PyObject* tape, PyObject* op_type,
                                 PyObject* output_tensors,
                                 PyObject* input_tensor_ids,
                                 PyObject* backward_function);
-PyObject* TFE_Py_TapeExport(PyObject* tape);
+
+// Computes a gradient based on information recorded on the tape. `tape` must
+// have been produced by TFE_Py_NewTape. `vspace` must be an
+// imperative_grad.py:VSpace named tuple. `target` and `sources` must be
+// Python lists of Tensor objects. `output_gradients` is either None or a
+// Python list whose entries are Tensors or Nones; if not None, it must have
+// the same length as `target`.
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
+                              PyObject* target, PyObject* sources,
+                              PyObject* output_gradients, TF_Status* status);
 
 // Returns an EagerTensor of dimension [len(`tensor_list`)] containing
 // the `slice_dim`'th dimension of each tensor in `tensor_list`. In other words,
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 7456eb10f86..a00a7615d7d 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -16,10 +16,13 @@ limitations under the License.
 #include "tensorflow/python/eager/pywrap_tfe.h"
 
 #include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/eager/tape.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/python/eager/pywrap_tensor.h"
 
 using tensorflow::string;
 
@@ -515,18 +518,51 @@ static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
 }
 
 PyObject* TFE_Py_TapeShouldRecord(PyObject* py_tape, PyObject* tensors) {
+  if (tensors == Py_None) {
+    Py_RETURN_FALSE;
+  }
+  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+  if (seq == nullptr) {
+    return nullptr;
+  }
+  int len = PySequence_Fast_GET_SIZE(seq);
+  // TODO(apassos) consider not building a list and changing the API to check
+  // each tensor individually.
+  std::vector<tensorflow::int64> tensor_ids;
+  tensor_ids.reserve(len);
+  for (int i = 0; i < len; ++i) {
+    PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+    if (EagerTensor_CheckExact(item)) {
+      tensor_ids.push_back(EagerTensor_id(item));
+    } else {
+      PyObject* id_field = PyObject_GetAttrString(item, "_id");
+      if (id_field == nullptr) {
+        Py_DECREF(seq);  // release the sequence before the early return
+        return nullptr;
+      }
+      tensor_ids.push_back(MakeInt(id_field));
+      Py_DECREF(id_field);
+    }
+  }
+  Py_DECREF(seq);
   TFE_Py_Tape* tape = reinterpret_cast<TFE_Py_Tape*>(py_tape);
-  return PyBool_FromLong(tape->tape->ShouldRecord(MakeIntList(tensors)));
+  if (tape->tape->ShouldRecord(tensor_ids)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
 }
 
 void TFE_Py_TapeWatch(PyObject* tape, tensorflow::int64 tensor_id) {
   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
 }
 
-// TODO(apassos) have a fast path for eager tensors here which gets information
-// from the handle instead of from the python object, and use this only for the
-// case of graph tensors.
 static tensorflow::eager::TapeTensor TapeTensorFromTensor(PyObject* tensor) {
+  if (EagerTensor_CheckExact(tensor)) {
+    TFE_TensorHandle* t = EagerTensor_Handle(tensor);
+    tensorflow::int64 id = EagerTensor_id(tensor);
+    return tensorflow::eager::TapeTensor{id, t->t.dtype(), t->t.shape()};
+  }
   PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
   tensorflow::int64 id = MakeInt(id_field);
   Py_DECREF(id_field);
@@ -592,64 +627,224 @@ void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
 }
 
-// TODO(apassos) when backprop.py moves to C most of this exporting logic can
-// disappear.
-PyObject* TFE_Py_TapeExport(PyObject* tape) {
-  std::pair<tensorflow::eager::TensorTape, tensorflow::eager::OpTape> exported =
-      reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Export();
-  PyObject* tensor_tape = PyDict_New();
-  for (const auto& pair : exported.first) {
-    PyObject* tid = PyLong_FromLong(pair.first);
-    PyObject* opid = PyLong_FromLong(pair.second);
-    PyDict_SetItem(tensor_tape, tid, opid);
-    Py_DECREF(tid);
-    Py_DECREF(opid);
+// TODO(apassos): cache the attribute lookups as member variables and decref
+// them in the destructor.
+class PyVSpace : public tensorflow::eager::VSpace {
+ public:
+  explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
+
+  tensorflow::Status Initialize() {
+    num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
+    if (num_elements_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
+    if (aggregate_fn_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    zeros_ = PyObject_GetAttrString(py_vspace_, "zeros");
+    if (zeros_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    ones_like_ = PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_),
+                                        "ones_like");
+    if (ones_like_ == nullptr) {
+      return tensorflow::errors::InvalidArgument("invalid vspace");
+    }
+    return tensorflow::Status::OK();
   }
 
-  PyObject* op_tape = PyDict_New();
-  for (const auto& pair : exported.second) {
-    PyObject* opid = PyLong_FromLong(pair.first);
-    const auto& entry = pair.second;
-    PyObject* op_type = PyBytes_FromString(entry.op_type.c_str());
-    PyObject* output_ids = PyList_New(entry.output_tensor_info.size());
-    for (int i = 0; i < entry.output_tensor_info.size(); ++i) {
-      PyObject* tid = PyLong_FromLong(entry.output_tensor_info[i].id);
-      PyList_SET_ITEM(output_ids, i, tid);
-    }
-    PyObject* input_ids = PyList_New(entry.input_tensor_id.size());
-    for (int i = 0; i < entry.input_tensor_id.size(); ++i) {
-      PyObject* tid = PyLong_FromLong(entry.input_tensor_id[i]);
-      PyList_SET_ITEM(input_ids, i, tid);
-    }
-    PyObject* backward_function =
-        reinterpret_cast<PyObject*>(entry.backward_function);
-    PyObject* output_shape_and_dtype =
-        PyList_New(entry.output_tensor_info.size());
-    for (int i = 0; i < entry.output_tensor_info.size(); ++i) {
-      const tensorflow::TensorShape& shape = entry.output_tensor_info[i].shape;
-      PyObject* shape_list = PyList_New(shape.dims());
-      for (int j = 0; j < shape.dims(); ++j) {
-        PyList_SET_ITEM(shape_list, j, PyLong_FromLong(shape.dim_size(j)));
-      }
-      PyObject* type_enum = PyLong_FromLong(entry.output_tensor_info[i].dtype);
-      PyObject* tuple = PyTuple_Pack(2, shape_list, type_enum);
-      Py_DECREF(shape_list);
-      Py_DECREF(type_enum);
-      PyList_SET_ITEM(output_shape_and_dtype, i, tuple);
-    }
-    PyObject* opinfo = PyTuple_Pack(5, op_type, output_ids, input_ids,
-                                    backward_function, output_shape_and_dtype);
-    Py_DECREF(op_type);
-    Py_DECREF(output_ids);
-    Py_DECREF(input_ids);
-    Py_DECREF(backward_function);
-    Py_DECREF(output_shape_and_dtype);
-    PyDict_SetItem(op_tape, opid, opinfo);
-    Py_DECREF(opid);
-    Py_DECREF(opinfo);
+  ~PyVSpace() override {
+    Py_XDECREF(num_elements_);
+    Py_XDECREF(aggregate_fn_);
+    Py_XDECREF(zeros_);
+    Py_XDECREF(ones_like_);
   }
-  PyObject* retval = PyTuple_Pack(2, tensor_tape, op_tape);
-  Py_DECREF(tensor_tape);
-  Py_DECREF(op_tape);
-  return retval;
+
+  tensorflow::int64 NumElements(void* tensor) const final {
+    PyObject* arglist =
+        Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
+    PyObject* result = PyEval_CallObject(num_elements_, arglist);
+    tensorflow::int64 r = MakeInt(result);
+    Py_DECREF(result);
+    Py_DECREF(arglist);
+    return r;
+  }
+
+  void* AggregateGradients(
+      tensorflow::gtl::ArraySlice<void*> gradient_tensors) const final {
+    PyObject* list = PyList_New(gradient_tensors.size());
+    for (int i = 0; i < gradient_tensors.size(); ++i) {
+      // Note: stealing a reference to the gradient tensors.
+      CHECK(gradient_tensors[i] != nullptr);
+      CHECK(gradient_tensors[i] != Py_None);
+      PyList_SET_ITEM(list, i,
+                      reinterpret_cast<PyObject*>(gradient_tensors[i]));
+    }
+    PyObject* arglist = Py_BuildValue("(O)", list);
+    CHECK(arglist != nullptr);
+    PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
+    Py_DECREF(arglist);
+    Py_DECREF(list);
+    return result;
+  }
+
+  void* Zeros(tensorflow::TensorShape shape,
+              tensorflow::DataType dtype) const final {
+    PyObject* py_shape = PyTuple_New(shape.dims());
+    for (int i = 0; i < shape.dims(); ++i) {
+      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
+    }
+    PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
+    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+    PyObject* result = PyEval_CallObject(zeros_, arg_list);
+    Py_DECREF(arg_list);
+    Py_DECREF(py_dtype);
+    Py_DECREF(py_shape);
+    return reinterpret_cast<void*>(result);
+  }
+
+  void* OnesLike(void* tensor) const final {
+    PyObject* arg_list = Py_BuildValue("(O)", tensor);
+    PyObject* result = PyEval_CallObject(ones_like_, arg_list);
+    if (result == nullptr) {
+      VLOG(1) << "Call to ones_like failed";
+    }
+    Py_DECREF(arg_list);
+    return reinterpret_cast<void*>(result);
+  }
+
+  tensorflow::int64 TensorId(void* tensor) const final {
+    PyObject* py_tensor = reinterpret_cast<PyObject*>(tensor);
+    PyObject* id_field = PyObject_GetAttrString(py_tensor, "_id");
+    tensorflow::int64 id = MakeInt(id_field);
+    Py_DECREF(id_field);
+    return id;
+  }
+
+  tensorflow::Status CallBackwardFunction(
+      void* backward_function,
+      tensorflow::gtl::ArraySlice<void*> output_gradients,
+      std::vector<void*>* result) const final {
+    PyObject* grads = PyTuple_New(output_gradients.size());
+    for (int i = 0; i < output_gradients.size(); ++i) {
+      if (output_gradients[i] == nullptr) {
+        Py_INCREF(Py_None);
+        PyTuple_SET_ITEM(grads, i, Py_None);
+      } else {
+        PyTuple_SET_ITEM(grads, i,
+                         reinterpret_cast<PyObject*>(output_gradients[i]));
+      }
+    }
+    PyObject* py_result = PyEval_CallObject(
+        reinterpret_cast<PyObject*>(backward_function), grads);
+    Py_DECREF(grads);
+    Py_DECREF(backward_function);
+    if (py_result == nullptr) {
+      VLOG(1) << "Gradient function threw exceptions";
+      if (VLOG_IS_ON(1)) {
+        PyErr_Print();
+      }
+      return tensorflow::errors::Internal("gradient function threw exceptions");
+    }
+    result->clear();
+    PyObject* seq =
+        PySequence_Fast(py_result, "expected a sequence of gradients");
+    if (seq == nullptr) {
+      return tensorflow::errors::InvalidArgument(
+          "gradient function did not return a list");
+    }
+    int len = PySequence_Fast_GET_SIZE(seq);
+    VLOG(1) << "Gradient length is " << len;
+    result->reserve(len);
+    for (int i = 0; i < len; ++i) {
+      PyObject* item = PySequence_Fast_GET_ITEM(seq, i);
+      if (item == Py_None) {
+        result->push_back(nullptr);
+      } else {
+        Py_INCREF(item);
+        result->push_back(item);
+      }
+    }
+    Py_DECREF(seq);
+    Py_DECREF(py_result);
+    return tensorflow::Status::OK();
+  }
+
+  void DeleteTensor(void* tensor) const final {
+    Py_XDECREF(reinterpret_cast<PyObject*>(tensor));
+  }
+
+ private:
+  PyObject* py_vspace_;
+
+  PyObject* num_elements_;
+  PyObject* aggregate_fn_;
+  PyObject* zeros_;
+  PyObject* ones_like_;
+};
+
+std::vector<void*> MakeTensorList(PyObject* tensors) {
+  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+  if (seq == nullptr) {
+    return {};
+  }
+  int len = PySequence_Fast_GET_SIZE(seq);
+  std::vector<void*> list;
+  list.reserve(len);
+  for (int i = 0; i < len; ++i) {
+    list.push_back(PySequence_Fast_GET_ITEM(seq, i));
+  }
+  Py_DECREF(seq);
+  return list;
+}
+
+PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
+                              PyObject* target, PyObject* sources,
+                              PyObject* output_gradients, TF_Status* status) {
+  PyVSpace c_vspace(vspace);
+  if (!c_vspace.Initialize().ok()) {
+    return nullptr;
+  }
+
+  std::vector<void*> target_vec = MakeTensorList(target);
+  if (PyErr_Occurred()) {
+    return nullptr;
+  }
+  std::vector<void*> sources_vec = MakeTensorList(sources);
+  if (PyErr_Occurred()) {
+    return nullptr;
+  }
+  std::vector<void*> outgrad_vec;
+  if (output_gradients != Py_None) {
+    outgrad_vec = MakeTensorList(output_gradients);
+    if (PyErr_Occurred()) {
+      return nullptr;
+    }
+    for (void* tensor : outgrad_vec) {
+      // Calling the backward function will eat a reference to the tensors in
+      // outgrad_vec, so we need to increase their reference count.
+      Py_INCREF(reinterpret_cast<PyObject*>(tensor));
+    }
+  }
+  TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
+  std::vector<void*> result;
+  status->status = tape_obj->tape->Gradient(c_vspace, target_vec, sources_vec,
+                                            outgrad_vec, &result);
+  if (!status->status.ok()) {
+    return nullptr;
+  }
+  if (!result.empty()) {
+    PyObject* py_result = PyList_New(result.size());
+    for (int i = 0; i < result.size(); ++i) {
+      if (result[i] == nullptr) {
+        Py_INCREF(Py_None);
+        result[i] = Py_None;
+      }
+      PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
+    }
+    return py_result;
+  }
+  Py_INCREF(Py_None);
+  return Py_None;
 }
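For orientation, a minimal sketch (not part of the patch) of the Python-side contract that PyVSpace::Initialize() binds to by attribute name; the stub callables here are hypothetical placeholders:

```python
import collections

# Field names must match what PyVSpace looks up with PyObject_GetAttrString:
# num_elements_fn, aggregate_fn, zeros and ones_like. (tensor_id is not
# fetched there; PyVSpace::TensorId reads each tensor's `_id` attribute.)
VSpace = collections.namedtuple(
    "VSpace",
    ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])

toy_vspace = VSpace(
    aggregate_fn=sum,                # reduces a list of gradients to one
    num_elements_fn=lambda t: 1,     # size estimate used by the machinery
    tensor_id=id,                    # stable id per tensor-like object
    zeros=lambda shape, dtype: 0.0,  # zero gradient for an unused output
    ones_like=lambda t: 1.0)         # seed gradient for a target
```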
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index c16aa8c2f7e..a06f5e1a670 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -72,7 +72,7 @@ class Tape(object):
       True if any of the tensors is in the tape.
     """
     return pywrap_tensorflow.TFE_Py_TapeShouldRecord(
-        self._tape, [x._id  for x in tensors])  # pylint: disable=protected-access
+        self._tape, tensors)
 
   def watch(self, tensor):
     """Adds a tensor to the tape."""
@@ -99,16 +99,6 @@ class Tape(object):
     """Deletes any trace we have for this tensor."""
     self._delete_tensor_id(tensor_id)
 
-  def export(self):
-    """Exports the internal state of this tape.
-
-    Returns:
-      tensor_tape: a map from tensor_id(tensor) to <identifier for op>
-       responsible for generating that tensor.
-      op_tape: a map from <identifier for op> to TapeEntry for that op.
-    """
-    return pywrap_tensorflow.TFE_Py_TapeExport(self._tape)
-
 
 class _TapeStack(threading.local):
 
diff --git a/tensorflow/python/eager/tape_test.py b/tensorflow/python/eager/tape_test.py
index c97cb621257..b490bac66db 100644
--- a/tensorflow/python/eager/tape_test.py
+++ b/tensorflow/python/eager/tape_test.py
@@ -22,7 +22,6 @@ from __future__ import print_function
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import custom_gradient
-from tensorflow.python.eager import tape
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -166,25 +165,6 @@ class TapeTest(test.TestCase):
     g, = backprop.gradients_function(fn, [0])(t)
     self.assertAllEqual(g, 1.0)
 
-  def testTapeGC(self):
-    # TODO(apassos) figure out how to test this without using tape internal
-    # APIs.
-    tape.push_new_tape()
-
-    def f():
-      x = constant_op.constant(1.0)
-      tape.watch(x)
-      x = gradient_is_constant(x)
-      x = gradient_is_constant(x)
-      x = gradient_is_constant(x)
-
-    f()
-    t = tape.pop_tape()
-    tensor_tape, op_tape = t.export()
-    self.assertEqual(len(tensor_tape), 1)  # The watched tensor will remain on
-                                           # the tape
-    self.assertEqual(len(op_tape), 0)  # No operations should remain on the tape
-
   def testCustomGradientGraphMode(self):
     with context.graph_mode(), self.test_session():
 
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 637f738fede..cbacf458a03 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -29,7 +29,7 @@ limitations under the License.
 %rename("%s") TFE_Py_TapeWatch;
 %rename("%s") TFE_Py_TapeDeleteTrace;
 %rename("%s") TFE_Py_TapeRecordOperation;
-%rename("%s") TFE_Py_TapeExport;
+%rename("%s") TFE_Py_TapeGradient;
 %rename("%s") TFE_NewContextOptions;
 %rename("%s") TFE_ContextOptionsSetConfig;
 %rename("%s") TFE_ContextOptionsSetDevicePlacementPolicy;
@@ -125,7 +125,7 @@ limitations under the License.
         SWIG_fail;
       }
       if (EagerTensor_CheckExact(elem)) {
-        (*$1)[i] = EagerTensorHandle(elem);
+        (*$1)[i] = EagerTensor_Handle(elem);
       } else {
         SWIG_exception_fail(SWIG_TypeError,
                             "provided list of inputs contains objects other "

From 2f796016426ada5346089111995a0bd64ee870e8 Mon Sep 17 00:00:00 2001
From: Justin Lebar <jlebar@google.com>
Date: Wed, 8 Nov 2017 13:49:36 -0800
Subject: [PATCH 046/115] [XLA:GPU] Add more logging to convolution autotuning.

PiperOrigin-RevId: 175057863
---
 .../xla/service/gpu/convolution_thunk.cc      | 22 +++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
index 536b96dcf62..e79d0a4c795 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -279,6 +280,13 @@ std::vector<AlgorithmDesc> ConvolutionThunk::GetAlgorithms(
   return algorithms;
 }
 
+static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) {
+  if (algo.tensor_ops_enabled()) {
+    return tensorflow::strings::StrCat(algo.algo_id(), "+TC");
+  }
+  return tensorflow::strings::StrCat(algo.algo_id());
+}
+
 tensorflow::Status ConvolutionThunk::ConvolveWithTune(
     const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
     const FilterDescriptor& filter_descriptor,
@@ -303,6 +311,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
           buffer_allocations.device_ordinal(),
           buffer_allocations.memory_allocator());
       se::dnn::ProfileResult profile_result;
+      VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm)
+              << " for ConvolutionThunk: " << this;
       bool launch_ok =
           Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
                    output_descriptor, output_data, convolution_descriptor,
@@ -310,6 +320,11 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
                    &scratch_allocator, &profile_result)
               .ok();
       if (launch_ok && profile_result.is_valid()) {
+        VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
+                << " for ConvolutionThunk " << this << " succeeded, taking "
+                << profile_result.elapsed_time_in_ms()
+                << "ms. (Best result: " << best_result.elapsed_time_in_ms()
+                << "ms)";
         if (profile_result.elapsed_time_in_ms() <
             best_result.elapsed_time_in_ms()) {
           best_result = profile_result;
@@ -319,6 +334,9 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
                 best_result_without_scratch.elapsed_time_in_ms()) {
           best_result_without_scratch = profile_result;
         }
+      } else {
+        VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm)
+                << " for ConvolutionThunk " << this << " failed.";
       }
     }
 
@@ -343,8 +361,8 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune(
 
   {
     VLOG(2) << "Using convolution algorithm ("
-            << best_algorithm_.algorithm().algo_id() << ", "
-            << best_algorithm_.algorithm_no_scratch().algo_id()
+            << AlgorithmToString(best_algorithm_.algorithm()) << ", "
+            << AlgorithmToString(best_algorithm_.algorithm_no_scratch())
             << ") for ConvolutionThunk: " << this;
     ConvolveScratchAllocator scratch_allocator(
         buffer_allocations.device_ordinal(),

From e0046de7afa46199e11bb3aef823a55dfa6a0355 Mon Sep 17 00:00:00 2001
From: Yifei Feng <yifeif@google.com>
Date: Wed, 8 Nov 2017 14:10:50 -0800
Subject: [PATCH 047/115] Fix typo in
 tensorflow/python/client/session_clusterspec_prop_test.py

PiperOrigin-RevId: 175061854
---
 tensorflow/python/client/session_clusterspec_prop_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/client/session_clusterspec_prop_test.py b/tensorflow/python/client/session_clusterspec_prop_test.py
index b77912b4f74..28a4dd27a76 100644
--- a/tensorflow/python/client/session_clusterspec_prop_test.py
+++ b/tensorflow/python/client/session_clusterspec_prop_test.py
@@ -169,7 +169,7 @@ class SessionClusterSpecPropagationTest(test_util.TensorFlowTestCase):
     # BaseRemoteRendezvous::SameWorkerRecvDone that means the test doesn't
     # actually capture the motivating bug unless run on a GPU machine.
     #
-    # Example error message (before bugfix -- linebreaks added because  lint):
+    # Example error message (before bugfix -- line breaks added because  lint):
     #
     # W0718 17:14:41.521534  190121 device_mgr.cc:107] Unknown device:
     #     /job:worker/replica:0/task:0/device:CPU:0 all devices:

From 6c382bb5e8860eb786dd51f5af639549a468bfdf Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Wed, 8 Nov 2017 14:20:09 -0800
Subject: [PATCH 048/115] More idiomatic tests for defuns using variables.

PiperOrigin-RevId: 175063558
---
 tensorflow/python/eager/function_test.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 243efccac44..209715894ee 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -62,13 +62,21 @@ class FunctionTest(test.TestCase):
     @function.defun
     def step():
       def inner():
-        tape.watch_variable(v)
         return v * v
 
       return backprop.implicit_grad(inner)()[0][0]
 
     self.assertAllEqual(step(), 2.0)
 
+  def testDefunDifferentiable(self):
+    v = resource_variable_ops.ResourceVariable(1.0)
+
+    @function.defun
+    def f():
+      return v * v
+
+    self.assertAllEqual(backprop.implicit_grad(f)()[0][0], 2.0)
+
   def testGraphModeCaptureVariable(self):
     with context.graph_mode(), self.test_session() as sess:
 

From ffe3636221ff8ecf93f9f78e19edf1419e20c67d Mon Sep 17 00:00:00 2001
From: Igor Saprykin <isaprykin@google.com>
Date: Wed, 8 Nov 2017 14:21:28 -0800
Subject: [PATCH 049/115] Run Estimator.export_savedmodel with the user's
 TFSession config.

Estimator assumes a particular config_pb2.ConfigProto that configures the underlying session.  The config is either the default one or a user-supplied one.  The default config has allow_soft_placement=True, the option that allows an operation to be silently placed on a device that has a kernel for it when the requested device doesn't.

Estimator's train(), eval() and predict() calls run with the underlying session configured according to that ConfigProto.  However, export_savedmodel runs without such a configuration.  This becomes a problem when the ModeKeys.PREDICT graph has an op that was placed on a GPU but has no GPU kernel: the graph works for predict(), but when export_savedmodel() tries to restore the corresponding variable, the code fails with a "no kernel for the op" error.  I attempted to show that in a test.

To fix this issue, I am passing the ConfigProto to the session inside export_savedmodel.  An alternative, more conservative (but ugly) fix would be to pass a new ConfigProto instance with only allow_soft_placement=Estimator._session_config.allow_soft_placement set.  Passing the whole ConfigProto feels like the right thing to do.  Here's what else is in ConfigProto: https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/core/protobuf/config.proto#L280.

I verified by running an internal pipeline.  Here's the allow_soft_placement logic: https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/core/common_runtime/placer.cc#L322.
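For illustration, a minimal sketch (not part of the change) of the soft-placement behavior this relies on; tf.as_string is used only because string ops have no GPU kernels:

```python
import tensorflow as tf

# allow_soft_placement=True lets an op pinned to an unavailable (or
# kernel-less) device fall back to a device that can actually run it.
config = tf.ConfigProto(allow_soft_placement=True)

with tf.Graph().as_default():
  with tf.device('/gpu:0'):
    # as_string has no GPU kernel; with soft placement it silently runs
    # on the CPU instead of failing at session run time.
    s = tf.as_string(tf.constant(7))
  with tf.Session(config=config) as sess:
    print(sess.run(s))  # b'7'
```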

PiperOrigin-RevId: 175063803
---
 .../estimator/replicate_model_fn_test.py      |  5 +-
 tensorflow/python/estimator/estimator.py      |  2 +-
 tensorflow/python/estimator/estimator_test.py | 66 +++++++++++++++++++
 3 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 10b47fba5af..ce286c33b01 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -90,14 +90,11 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
     def optimizer_fn():
       return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05)
 
-    # TODO(isaprykin):  Switch Estimator to use allow_soft_placement=True
-    # during export_savedmodel and then switch this test to replicate over
-    # GPUs instead of CPUs.
     estimator = estimator_lib.Estimator(
         model_fn=replicate_model_fn.replicate_model_fn(
             estimator.model_fn,
             optimizer_fn,
-            devices=['/cpu:0', '/cpu:0', '/cpu:0']),
+            devices=['/gpu:0', '/gpu:1', '/gpu:2']),
         model_dir=estimator.model_dir,
         config=estimator.config,
         params=estimator.params)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index a730e107bae..2d036e2cfba 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -537,7 +537,7 @@ class Estimator(object):
       temp_export_dir = get_temp_export_dir(export_dir)
 
       # TODO(soergel): Consider whether MonitoredSession makes sense here
-      with tf_session.Session() as session:
+      with tf_session.Session(config=self._session_config) as session:
 
         saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
             sharded=True)
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 2b9b44523bb..c1b773b8c40 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -50,6 +50,7 @@ from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import metrics as metrics_lib
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import gfile
@@ -1910,6 +1911,71 @@ class EstimatorExportTest(test.TestCase):
     est.train(dummy_input_fn, steps=1)
     est.export_savedmodel(tempfile.mkdtemp(), serving_input_receiver_fn)
 
+  def test_export_savedmodel_respects_soft_placement(self):
+    def model_fn_with_a_gpu_op_but_no_kernel(features, labels, mode):
+      _, _ = features, labels
+      table = saver_test_utils.CheckpointedOp(name='v2')
+
+      update_global_step = state_ops.assign_add(training.get_global_step(), 1)
+      with ops.control_dependencies([update_global_step]):
+        train_op = table.insert('k1', 30.0)
+
+      #  In this test, there are no GPUs available.  The goal is to verify that
+      #  export_savedmodel executes nevertheless.
+      with ops.device('/gpu:0'):
+        string_op = string_ops.as_string(update_global_step)
+
+      with ops.control_dependencies([string_op]):
+        prediction = table.lookup('k1', 0.0)
+
+      return model_fn_lib.EstimatorSpec(
+          mode,
+          predictions=prediction,
+          loss=constant_op.constant(1.),
+          train_op=train_op,
+          export_outputs={
+              'test': export_output.PredictOutput({
+                  'prediction': prediction
+              })
+          })
+
+    tmpdir = tempfile.mkdtemp()
+    est = estimator.Estimator(
+        model_fn=model_fn_with_a_gpu_op_but_no_kernel)
+    est.train(input_fn=dummy_input_fn, steps=1)
+    feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+                    'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir_base = os.path.join(
+        compat.as_bytes(tmpdir), compat.as_bytes('export'))
+
+    export_dir = est.export_savedmodel(
+        export_dir_base, serving_input_receiver_fn)
+
+    # At this point, if export_savedmodel executed with
+    # allow_soft_placement=True, then the GPU-assigned operation was silently
+    # placed on the CPU.  Otherwise, an exception would have been raised
+    # related to the fact that the requested GPU device isn't available.
+
+    # Expectations below assume that export_savedmodel has completed normally.
+    self.assertTrue(gfile.Exists(export_dir_base))
+    self.assertTrue(gfile.Exists(export_dir))
+    self.assertTrue(gfile.Exists(os.path.join(
+        compat.as_bytes(export_dir),
+        compat.as_bytes('saved_model.pb'))))
+    self.assertTrue(gfile.Exists(os.path.join(
+        compat.as_bytes(export_dir),
+        compat.as_bytes('variables'))))
+    self.assertTrue(gfile.Exists(os.path.join(
+        compat.as_bytes(export_dir),
+        compat.as_bytes('variables/variables.index'))))
+    self.assertTrue(gfile.Exists(os.path.join(
+        compat.as_bytes(export_dir),
+        compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+    gfile.DeleteRecursively(tmpdir)
+
 
 class EstimatorHookOrderingTest(test.TestCase):
 

From 544b47d5702787083445d64af4d4683141c0ffc9 Mon Sep 17 00:00:00 2001
From: Igor Saprykin <isaprykin@google.com>
Date: Wed, 8 Nov 2017 14:39:42 -0800
Subject: [PATCH 050/115] Fix tensorflow.org rendering of the example code for
 run_step_fn.

Python code isn't indented correctly.

PiperOrigin-RevId: 175067065
---
 tensorflow/python/training/monitored_session.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index af9f11bb077..1f6016a91b6 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -536,6 +536,7 @@ class _MonitoredSession(object):
         will return True.
 
         Example usage:
+
         ```python
            with tf.Graph().as_default():
              c = tf.placeholder(dtypes.float32)
@@ -552,6 +553,7 @@ class _MonitoredSession(object):
                while not session.should_stop():
                  a = session.run_step_fn(step_fn)
         ```
+
         Hooks interact with the `run_with_hooks()` call inside the `step_fn`
         as they do with a `MonitoredSession.run` call.
 

From 6488286b2678dddd7c8ed534d92f228bd4c532c9 Mon Sep 17 00:00:00 2001
From: Francois Chollet <fchollet@google.com>
Date: Wed, 8 Nov 2017 14:57:30 -0800
Subject: [PATCH 051/115] Update tf.keras RNNs to the Keras 2.0.9 API. Does not
 include cuDNN layers. Additionally, fix a bug with handling of
 activity_regularizer in tf.layers base Layer (and add test).

PiperOrigin-RevId: 175070161
---
 tensorflow/python/keras/BUILD                 |   12 +
 .../keras/_impl/keras/engine/topology.py      |    9 +-
 .../keras/_impl/keras/integration_test.py     |    2 +-
 .../keras/_impl/keras/layers/gru_test.py      |   12 +-
 .../keras/_impl/keras/layers/lstm_test.py     |   11 +-
 .../keras/_impl/keras/layers/recurrent.py     | 3045 ++++++++++++-----
 .../_impl/keras/layers/recurrent_test.py      |  378 ++
 .../_impl/keras/layers/simplernn_test.py      |   12 +-
 tensorflow/python/keras/layers/__init__.py    |    5 +
 tensorflow/python/layers/base.py              |    2 +-
 tensorflow/python/layers/base_test.py         |    7 +
 .../tensorflow.keras.layers.-g-r-u-cell.pbtxt |  179 +
 .../tensorflow.keras.layers.-g-r-u.pbtxt      |   86 +-
 ...ensorflow.keras.layers.-l-s-t-m-cell.pbtxt |  179 +
 .../tensorflow.keras.layers.-l-s-t-m.pbtxt    |   90 +-
 .../tensorflow.keras.layers.-r-n-n.pbtxt      |  191 ++
 ...flow.keras.layers.-simple-r-n-n-cell.pbtxt |  179 +
 ...ensorflow.keras.layers.-simple-r-n-n.pbtxt |   78 +-
 ...ow.keras.layers.-stacked-r-n-n-cells.pbtxt |  183 +
 .../api/golden/tensorflow.keras.layers.pbtxt  |   20 +
 tensorflow/tools/ci_build/ci_sanity.sh        |    3 +-
 21 files changed, 3712 insertions(+), 971 deletions(-)
 create mode 100644 tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
 create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
 create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
 create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
 create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
 create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4db48b45edd..6a762ee5d25 100644
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -498,6 +498,18 @@ py_test(
     ],
 )
 
+py_test(
+    name = "recurrent_test",
+    size = "small",
+    srcs = ["_impl/keras/layers/recurrent_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":keras",
+        "//tensorflow/python:client_testlib",
+        "//third_party/py/numpy",
+    ],
+)
+
 py_test(
     name = "serialization_test",
     size = "small",
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py
index f9be782f85e..2bcbabf19ce 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology.py
@@ -29,6 +29,9 @@ from six.moves import zip  # pylint: disable=redefined-builtin
 from tensorflow.python.eager import context
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import constraints
+from tensorflow.python.keras._impl.keras import initializers
+from tensorflow.python.keras._impl.keras import regularizers
 from tensorflow.python.keras._impl.keras.utils import conv_utils
 from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary
@@ -209,9 +212,9 @@ class Layer(tf_base_layers.Layer):
       dtype = K.floatx()
     weight = self.add_variable(name, shape,
                                dtype=dtype,
-                               initializer=initializer,
-                               regularizer=regularizer,
-                               constraint=constraint,
+                               initializer=initializers.get(initializer),
+                               regularizer=regularizers.get(regularizer),
+                               constraint=constraints.get(constraint),
                                trainable=trainable)
     return weight
 
diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py
index 71100368480..871a8c73298 100644
--- a/tensorflow/python/keras/_impl/keras/integration_test.py
+++ b/tensorflow/python/keras/_impl/keras/integration_test.py
@@ -93,7 +93,7 @@ class KerasIntegrationTest(test.TestCase):
       y_test = keras.utils.to_categorical(y_test)
 
       model = keras.models.Sequential()
-      model.add(keras.layers.LSTM(3, return_sequences=True,
+      model.add(keras.layers.LSTM(5, return_sequences=True,
                                   input_shape=x_train.shape[1:]))
       model.add(keras.layers.GRU(y_train.shape[-1], activation='softmax'))
       model.compile(loss='categorical_crossentropy',
diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
index 03f0736161e..c57fbac41cc 100644
--- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py
@@ -156,8 +156,10 @@ class GRULayerTest(test.TestCase):
           activity_regularizer='l1')
       layer.build((None, None, 2))
       self.assertEqual(len(layer.losses), 3)
-      layer(keras.backend.variable(np.ones((2, 3, 2))))
-      self.assertEqual(len(layer.losses), 4)
+
+      x = keras.backend.variable(np.ones((2, 3, 2)))
+      layer(x)
+      self.assertEqual(len(layer.get_losses_for(x)), 1)
 
   def test_constraints_GRU(self):
     embedding_dim = 4
@@ -175,9 +177,9 @@ class GRULayerTest(test.TestCase):
           recurrent_constraint=r_constraint,
           bias_constraint=b_constraint)
       layer.build((None, None, embedding_dim))
-      self.assertEqual(layer.kernel.constraint, k_constraint)
-      self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
-      self.assertEqual(layer.bias.constraint, b_constraint)
+      self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+      self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+      self.assertEqual(layer.cell.bias.constraint, b_constraint)
 
   def test_with_masking_layer_GRU(self):
     layer_class = keras.layers.GRU
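For context, a hedged sketch of the losses bookkeeping the updated tests rely on (mirroring the GRU test above; the exact constructor arguments are assumptions):

```python
import numpy as np
from tensorflow.python.keras._impl import keras

layer = keras.layers.GRU(
    4, kernel_regularizer='l1', recurrent_regularizer='l1',
    bias_regularizer='l2', activity_regularizer='l1')
layer.build((None, None, 2))
# Weight regularizers create unconditional losses at build time...
print(len(layer.losses))  # 3

x = keras.backend.variable(np.ones((2, 3, 2)))
layer(x)
# ...while the activity loss is conditional on the input that produced it.
print(len(layer.get_losses_for(x)))  # 1
```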
diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
index f43d90fec8f..8d359bf17cd 100644
--- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py
@@ -156,8 +156,9 @@ class LSTMLayerTest(test.TestCase):
           activity_regularizer='l1')
       layer.build((None, None, 2))
       self.assertEqual(len(layer.losses), 3)
-      layer(keras.backend.variable(np.ones((2, 3, 2))))
-      self.assertEqual(len(layer.losses), 4)
+      x = keras.backend.variable(np.ones((2, 3, 2)))
+      layer(x)
+      self.assertEqual(len(layer.get_losses_for(x)), 1)
 
   def test_constraints_LSTM(self):
     embedding_dim = 4
@@ -175,9 +176,9 @@ class LSTMLayerTest(test.TestCase):
           recurrent_constraint=r_constraint,
           bias_constraint=b_constraint)
       layer.build((None, None, embedding_dim))
-      self.assertEqual(layer.kernel.constraint, k_constraint)
-      self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
-      self.assertEqual(layer.bias.constraint, b_constraint)
+      self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+      self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+      self.assertEqual(layer.cell.bias.constraint, b_constraint)
 
   def test_with_masking_layer_LSTM(self):
     layer_class = keras.layers.LSTM
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 139523403c1..2bc74d5f807 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,93 +29,2157 @@ from tensorflow.python.keras._impl.keras import initializers
 from tensorflow.python.keras._impl.keras import regularizers
 from tensorflow.python.keras._impl.keras.engine import InputSpec
 from tensorflow.python.keras._impl.keras.engine import Layer
+from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.platform import tf_logging as logging
 
 
-# pylint: disable=access-member-before-definition
+class StackedRNNCells(Layer):
+  """Wrapper allowing a stack of RNN cells to behave as a single cell.
 
-
-def _time_distributed_dense(x,
-                            w,
-                            b=None,
-                            dropout=None,
-                            input_dim=None,
-                            output_dim=None,
-                            timesteps=None,
-                            training=None):
-  """Apply `y . w + b` for every temporal slice y of x.
+  Used to implement efficient stacked RNNs.
 
   Arguments:
-      x: input tensor.
-      w: weight matrix.
-      b: optional bias vector.
-      dropout: whether to apply dropout (same dropout mask
-          for every temporal slice of the input).
-      input_dim: integer; optional dimensionality of the input.
-      output_dim: integer; optional dimensionality of the output.
-      timesteps: integer; optional number of timesteps.
-      training: training phase tensor or boolean.
+      cells: List of RNN cell instances.
 
-  Returns:
-      Output tensor.
+  Examples:
+
+  ```python
+      cells = [
+          keras.layers.LSTMCell(output_dim),
+          keras.layers.LSTMCell(output_dim),
+          keras.layers.LSTMCell(output_dim),
+      ]
+
+      inputs = keras.Input((timesteps, input_dim))
+      x = keras.layers.RNN(cells)(inputs)
+  ```
   """
-  if not input_dim:
-    input_dim = K.shape(x)[2]
-  if not timesteps:
-    timesteps = K.shape(x)[1]
-  if not output_dim:
-    output_dim = K.shape(w)[1]
 
-  if dropout is not None and 0. < dropout < 1.:
-    # apply the same dropout pattern at every timestep
-    ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim)))
-    dropout_matrix = K.dropout(ones, dropout)
-    expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps)
-    x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training)
+  def __init__(self, cells, **kwargs):
+    for cell in cells:
+      if not hasattr(cell, 'call'):
+        raise ValueError('All cells must have a `call` method. '
+                         'Received cells:', cells)
+      if not hasattr(cell, 'state_size'):
+        raise ValueError('All cells must have a '
+                         '`state_size` attribute. '
+                         'Received cells:', cells)
+    self.cells = cells
+    super(StackedRNNCells, self).__init__(**kwargs)
 
-  # collapse time dimension and batch dimension together
-  x = K.reshape(x, (-1, input_dim))
-  x = K.dot(x, w)
-  if b is not None:
-    x = K.bias_add(x, b)
-  # reshape to 3D tensor
-  if K.backend() == 'tensorflow':
-    x = K.reshape(x, K.stack([-1, timesteps, output_dim]))
-    x.set_shape([None, None, output_dim])
-  else:
-    x = K.reshape(x, (-1, timesteps, output_dim))
-  return x
+  @property
+  def state_size(self):
+    # States are a flat list
+    # in reverse order of the cell stack.
+    # This preserves the requirement
+    # `stack.state_size[0] == output_dim`.
+    # e.g. states of a 2-layer LSTM would be
+    # `[h2, c2, h1, c1]`
+    # (assuming one LSTM has states [h, c])
+    state_size = []
+    for cell in self.cells[::-1]:
+      if hasattr(cell.state_size, '__len__'):
+        state_size += list(cell.state_size)
+      else:
+        state_size.append(cell.state_size)
+    return tuple(state_size)
+
+  def call(self, inputs, states, **kwargs):
+    # Recover per-cell states.
+    nested_states = []
+    for cell in self.cells[::-1]:
+      if hasattr(cell.state_size, '__len__'):
+        nested_states.append(states[:len(cell.state_size)])
+        states = states[len(cell.state_size):]
+      else:
+        nested_states.append([states[0]])
+        states = states[1:]
+    nested_states = nested_states[::-1]
+
+    # Call the cells in order and store the returned states.
+    new_nested_states = []
+    for cell, states in zip(self.cells, nested_states):
+      inputs, states = cell.call(inputs, states, **kwargs)
+      new_nested_states.append(states)
+
+    # Format the new states as a flat list
+    # in reverse cell order.
+    states = []
+    for cell_states in new_nested_states[::-1]:
+      states += cell_states
+    return inputs, states
+
+  def build(self, input_shape):
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        cell.build(input_shape)
+      if hasattr(cell.state_size, '__len__'):
+        output_dim = cell.state_size[0]
+      else:
+        output_dim = cell.state_size
+      input_shape = (input_shape[0], input_shape[1], output_dim)
+    self.built = True
+
+  def get_config(self):
+    cells = []
+    for cell in self.cells:
+      cells.append({
+          'class_name': cell.__class__.__name__,
+          'config': cell.get_config()
+      })
+    config = {'cells': cells}
+    base_config = super(StackedRNNCells, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+    cells = []
+    for cell_config in config.pop('cells'):
+      cells.append(
+          deserialize_layer(cell_config, custom_objects=custom_objects))
+    return cls(cells, **config)
+
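`get_config`/`from_config` make the stack serializable like any other layer: each cell is stored as a `class_name`/`config` pair and revived through the Keras layer deserializer. A round-trip sketch (assuming `SimpleRNNCell`, defined below, is registered with the deserializer):

```python
stack = StackedRNNCells([SimpleRNNCell(8), SimpleRNNCell(8)])
clone = StackedRNNCells.from_config(stack.get_config())
# `clone` has the same architecture as `stack`, with freshly created cells.
```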
+  @property
+  def trainable_weights(self):
+    if not self.trainable:
+      return []
+    weights = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        weights += cell.trainable_weights
+    return weights
+
+  @property
+  def non_trainable_weights(self):
+    weights = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        weights += cell.non_trainable_weights
+    if not self.trainable:
+      trainable_weights = []
+      for cell in self.cells:
+        if isinstance(cell, Layer):
+          trainable_weights += cell.trainable_weights
+      return trainable_weights + weights
+    return weights
+
+  def get_weights(self):
+    """Retrieves the weights of the model.
+
+    Returns:
+        A flat list of Numpy arrays.
+    """
+    weights = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        weights += cell.weights
+    return K.batch_get_value(weights)
+
+  def set_weights(self, weights):
+    """Sets the weights of the model.
+
+    Arguments:
+        weights: A list of Numpy arrays with shapes and types matching
+            the output of `model.get_weights()`.
+    """
+    tuples = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        num_param = len(cell.weights)
+        cell_weights = weights[:num_param]
+        for sw, w in zip(cell.weights, cell_weights):
+          tuples.append((sw, w))
+        weights = weights[num_param:]
+    K.batch_set_value(tuples)
+
+  @property
+  def losses(self):
+    losses = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        cell_losses = cell.losses
+        losses += cell_losses
+    return losses
+
+  def get_losses_for(self, inputs=None):
+    losses = []
+    for cell in self.cells:
+      if isinstance(cell, Layer):
+        cell_losses = cell.get_losses_for(inputs)
+        losses += cell_losses
+    return losses
+
+
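To make the reversed flat state layout concrete: stacking two LSTM cells (each with states `[h, c]`) yields the flat list `[h2, c2, h1, c1]`, so `state_size[0]` is the output dimension of the whole stack. A sketch (assuming `LSTMCell` is exported under `keras.layers`, as in the docstring above):

```python
cells = [keras.layers.LSTMCell(3), keras.layers.LSTMCell(5)]
stack = keras.layers.StackedRNNCells(cells)
print(stack.state_size)  # (5, 5, 3, 3), i.e. [h2, c2, h1, c1]
```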
+class RNN(Layer):
+  """Base class for recurrent layers.
+
+  Arguments:
+      cell: An RNN cell instance. An RNN cell is a class that has:
+          - a `call(input_at_t, states_at_t)` method, returning
+              `(output_at_t, states_at_t_plus_1)`. The call method of the
+              cell can also take the optional argument `constants`, see
+              section "Note on passing external constants" below.
+          - a `state_size` attribute. This can be a single integer
+              (single state) in which case it is
+              the size of the recurrent state
+              (which should be the same as the size of the cell output).
+              This can also be a list/tuple of integers
+              (one size per state). In this case, the first entry
+              (`state_size[0]`) should be the same as
+              the size of the cell output.
+          It is also possible for `cell` to be a list of RNN cell instances,
+          in which case the cells get stacked one after the other in the RNN,
+          implementing an efficient stacked RNN.
+      return_sequences: Boolean. Whether to return the last output
+          in the output sequence, or the full sequence.
+      return_state: Boolean. Whether to return the last state
+          in addition to the output.
+      go_backwards: Boolean (default False).
+          If True, process the input sequence backwards and return the
+          reversed sequence.
+      stateful: Boolean (default False). If True, the last state
+          for each sample at index i in a batch will be used as initial
+          state for the sample of index i in the following batch.
+      unroll: Boolean (default False).
+          If True, the network will be unrolled,
+          else a symbolic loop will be used.
+          Unrolling can speed up an RNN,
+          although it tends to be more memory-intensive.
+          Unrolling is only suitable for short sequences.
+      input_dim: dimensionality of the input (integer).
+          This argument (or alternatively,
+          the keyword argument `input_shape`)
+          is required when using this layer as the first layer in a model.
+      input_length: Length of input sequences, to be specified
+          when it is constant.
+          This argument is required if you are going to connect
+          `Flatten` then `Dense` layers downstream
+          (without it, the shape of the dense outputs cannot be computed).
+          Note that if the recurrent layer is not the first layer
+          in your model, you would need to specify the input length
+          at the level of the first layer
+          (e.g. via the `input_shape` argument)
+
+  Input shape:
+      3D tensor with shape `(batch_size, timesteps, input_dim)`,
+      optionally followed by 2D initial-state tensors, each with shape
+      `(batch_size, output_dim)`.
+
+  Output shape:
+      - if `return_state`: a list of tensors. The first tensor is
+          the output. The remaining tensors are the last states,
+          each with shape `(batch_size, units)`.
+      - if `return_sequences`: 3D tensor with shape
+          `(batch_size, timesteps, units)`.
+      - else, 2D tensor with shape `(batch_size, units)`.
+
+  # Masking
+      This layer supports masking for input data with a variable number
+      of timesteps. To introduce masks to your data,
+      use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
+      set to `True`.
+
+  # Note on using statefulness in RNNs
+      You can set RNN layers to be 'stateful', which means that the states
+      computed for the samples in one batch will be reused as initial states
+      for the samples in the next batch. This assumes a one-to-one mapping
+      between samples in different successive batches.
+
+      To enable statefulness:
+          - specify `stateful=True` in the layer constructor.
+          - specify a fixed batch size for your model, by passing
+              `batch_input_shape=(...)` to the first layer in your model
+              (for a Sequential model), or `batch_shape=(...)` to all the
+              first (`Input`) layers in your model (for a functional model
+              with one or more Input layers).
+              This is the expected shape of your inputs
+              *including the batch size*.
+              It should be a tuple of integers, e.g. `(32, 10, 100)`.
+          - specify `shuffle=False` when calling fit().
+
+      To reset the states of your model, call `.reset_states()` on either
+      a specific layer, or on your entire model.
+
+  # Note on specifying the initial state of RNNs
+      You can specify the initial state of RNN layers symbolically by
+      calling them with the keyword argument `initial_state`. The value of
+      `initial_state` should be a tensor or list of tensors representing
+      the initial state of the RNN layer.
+
+      You can specify the initial state of RNN layers numerically by
+      calling `reset_states` with the keyword argument `states`. The value of
+      `states` should be a numpy array or list of numpy arrays representing
+      the initial state of the RNN layer.
+
+  # Note on passing external constants to RNNs
+      You can pass "external" constants to the cell using the `constants`
+      keyword argument of the `RNN.__call__` (and `RNN.call`) methods. This
+      requires that the `cell.call` method accepts the same keyword argument
+      `constants`. Such constants can be used to condition the cell
+      transformation on additional static inputs (not changing over time),
+      as in an attention mechanism.
+
+  Examples:
+
+  ```python
+      # First, let's define a RNN Cell, as a layer subclass.
+
+      class MinimalRNNCell(keras.layers.Layer):
+
+          def __init__(self, units, **kwargs):
+              self.units = units
+              self.state_size = units
+              super(MinimalRNNCell, self).__init__(**kwargs)
+
+          def build(self, input_shape):
+              self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
+                                            initializer='uniform',
+                                            name='kernel')
+              self.recurrent_kernel = self.add_weight(
+                  shape=(self.units, self.units),
+                  initializer='uniform',
+                  name='recurrent_kernel')
+              self.built = True
+
+          def call(self, inputs, states):
+              prev_output = states[0]
+              h = K.dot(inputs, self.kernel)
+              output = h + K.dot(prev_output, self.recurrent_kernel)
+              return output, [output]
+
+      # Let's use this cell in a RNN layer:
+
+      cell = MinimalRNNCell(32)
+      x = keras.Input((None, 5))
+      layer = RNN(cell)
+      y = layer(x)
+
+      # Here's how to use the cell to build a stacked RNN:
+
+      cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
+      x = keras.Input((None, 5))
+      layer = RNN(cells)
+      y = layer(x)
+  ```
+  """
+
+  def __init__(self,
+               cell,
+               return_sequences=False,
+               return_state=False,
+               go_backwards=False,
+               stateful=False,
+               unroll=False,
+               activity_regularizer=None,
+               **kwargs):
+    if isinstance(cell, (list, tuple)):
+      cell = StackedRNNCells(cell)
+    if not hasattr(cell, 'call'):
+      raise ValueError('`cell` should have a `call` method. '
+                       'The RNN was passed:', cell)
+    if not hasattr(cell, 'state_size'):
+      raise ValueError('The RNN cell should have '
+                       'an attribute `state_size` '
+                       '(tuple of integers, '
+                       'one integer per RNN state).')
+    super(RNN, self).__init__(
+        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
+    self.cell = cell
+    self.return_sequences = return_sequences
+    self.return_state = return_state
+    self.go_backwards = go_backwards
+    self.stateful = stateful
+    self.unroll = unroll
+
+    self.supports_masking = True
+    self.input_spec = [InputSpec(ndim=3)]
+    self.state_spec = None
+    self._states = None
+    self.constants_spec = None
+    self._num_constants = None
+
+  @property
+  def states(self):
+    if self._states is None:
+      if isinstance(self.cell.state_size, int):
+        num_states = 1
+      else:
+        num_states = len(self.cell.state_size)
+      return [None for _ in range(num_states)]
+    return self._states
+
+  @states.setter
+  def states(self, states):
+    self._states = states
+
+  def _compute_output_shape(self, input_shape):
+    if isinstance(input_shape, list):
+      input_shape = input_shape[0]
+    input_shape = tensor_shape.TensorShape(input_shape).as_list()
+
+    if hasattr(self.cell.state_size, '__len__'):
+      output_dim = self.cell.state_size[0]
+    else:
+      output_dim = self.cell.state_size
+
+    if self.return_sequences:
+      output_shape = (input_shape[0], input_shape[1], output_dim)
+    else:
+      output_shape = (input_shape[0], output_dim)
+
+    if self.return_state:
+      state_shape = [(input_shape[0], output_dim) for _ in self.states]
+      output_shape = [output_shape] + state_shape
+    return tensor_shape.TensorShape(output_shape)
+
+  def compute_mask(self, inputs, mask):
+    if isinstance(mask, list):
+      mask = mask[0]
+    output_mask = mask if self.return_sequences else None
+    if self.return_state:
+      state_mask = [None for _ in self.states]
+      return [output_mask] + state_mask
+    else:
+      return output_mask
+
+  def build(self, input_shape):
+    # Note input_shape will be list of shapes of initial states and
+    # constants if these are passed in __call__.
+    if self._num_constants is not None:
+      constants_shape = input_shape[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
+    else:
+      constants_shape = None
+
+    if isinstance(input_shape, list):
+      input_shape = input_shape[0]
+    input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+
+    batch_size = input_shape[0] if self.stateful else None
+    input_dim = input_shape[-1]
+    self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
+
+    # allow cell (if layer) to build before we set or validate state_spec
+    if isinstance(self.cell, Layer):
+      step_input_shape = (input_shape[0],) + input_shape[2:]
+      if constants_shape is not None:
+        self.cell.build([step_input_shape] + constants_shape)
+      else:
+        self.cell.build(step_input_shape)
+
+    # set or validate state_spec
+    if hasattr(self.cell.state_size, '__len__'):
+      state_size = list(self.cell.state_size)
+    else:
+      state_size = [self.cell.state_size]
+
+    if self.state_spec is not None:
+      # initial_state was passed in call, check compatibility
+      if [spec.shape[-1] for spec in self.state_spec] != state_size:
+        raise ValueError(
+            'An initial_state was passed that is not compatible with '
+            '`cell.state_size`. Received `state_spec`={}; '
+            'However `cell.state_size` is '
+            '{}'.format(self.state_spec, self.cell.state_size))
+    else:
+      self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
+    if self.stateful:
+      self.reset_states()
+
+  def get_initial_state(self, inputs):
+    # build an all-zero tensor of shape (samples, output_dim)
+    initial_state = K.zeros_like(inputs)  # (samples, timesteps, input_dim)
+    initial_state = K.sum(initial_state, axis=(1, 2))  # (samples,)
+    initial_state = K.expand_dims(initial_state)  # (samples, 1)
+    if hasattr(self.cell.state_size, '__len__'):
+      return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
+    else:
+      return [K.tile(initial_state, [1, self.cell.state_size])]
+
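The zeros-like/sum/tile sequence above builds an all-zero state whose batch dimension is inherited from `inputs` at graph-construction time, without requiring a static batch size. A numpy sketch of the shape logic (illustrative only):

```python
import numpy as np

inputs = np.zeros((4, 7, 6))       # (samples, timesteps, input_dim)
state = np.zeros_like(inputs)      # (samples, timesteps, input_dim)
state = state.sum(axis=(1, 2))     # (samples,)
state = np.expand_dims(state, -1)  # (samples, 1)
state = np.tile(state, (1, 5))     # (samples, state_size) -> (4, 5)
```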
+  def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+    inputs, initial_state, constants = self._standardize_args(
+        inputs, initial_state, constants)
+
+    if initial_state is None and constants is None:
+      return super(RNN, self).__call__(inputs, **kwargs)
+
+    # If any of `initial_state` or `constants` are specified and are Keras
+    # tensors, then add them to the inputs and temporarily modify the
+    # input_spec to include them.
+
+    additional_inputs = []
+    additional_specs = []
+    if initial_state is not None:
+      kwargs['initial_state'] = initial_state
+      additional_inputs += initial_state
+      self.state_spec = [
+          InputSpec(shape=K.int_shape(state)) for state in initial_state
+      ]
+      additional_specs += self.state_spec
+    if constants is not None:
+      kwargs['constants'] = constants
+      additional_inputs += constants
+      self.constants_spec = [
+          InputSpec(shape=K.int_shape(constant)) for constant in constants
+      ]
+      self._num_constants = len(constants)
+      additional_specs += self.constants_spec
+    # at this point additional_inputs cannot be empty
+    is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
+    for tensor in additional_inputs:
+      if hasattr(tensor, '_keras_history') != is_keras_tensor:
+        raise ValueError('The initial state or constants of an RNN'
+                         ' layer cannot be specified with a mix of'
+                         ' Keras tensors and non-Keras tensors')
+
+    if is_keras_tensor:
+      # Compute the full input spec, including state and constants
+      full_input = [inputs] + additional_inputs
+      full_input_spec = self.input_spec + additional_specs
+      # Perform the call with temporarily replaced input_spec
+      original_input_spec = self.input_spec
+      self.input_spec = full_input_spec
+      output = super(RNN, self).__call__(full_input, **kwargs)
+      self.input_spec = original_input_spec
+      return output
+    else:
+      return super(RNN, self).__call__(inputs, **kwargs)
+
+  def call(self,
+           inputs,
+           mask=None,
+           training=None,
+           initial_state=None,
+           constants=None):
+    # input shape: `(samples, time (padded with zeros), input_dim)`
+    # note that the .build() method of subclasses MUST define
+    # self.input_spec and self.state_spec with complete input shapes.
+    if isinstance(inputs, list):
+      inputs = inputs[0]
+    if initial_state is not None:
+      pass
+    elif self.stateful:
+      initial_state = self.states
+    else:
+      initial_state = self.get_initial_state(inputs)
+
+    if isinstance(mask, list):
+      mask = mask[0]
+
+    if len(initial_state) != len(self.states):
+      raise ValueError(
+          'Layer has ' + str(len(self.states)) + ' states but was passed ' +
+          str(len(initial_state)) + ' initial states.')
+    input_shape = K.int_shape(inputs)
+    timesteps = input_shape[1]
+    if self.unroll and timesteps in [None, 1]:
+      raise ValueError('Cannot unroll a RNN if the '
+                       'time dimension is undefined or equal to 1. \n'
+                       '- If using a Sequential model, '
+                       'specify the time dimension by passing '
+                       'an `input_shape` or `batch_input_shape` '
+                       'argument to your first layer. If your '
+                       'first layer is an Embedding, you can '
+                       'also use the `input_length` argument.\n'
+                       '- If using the functional API, specify '
+                       'the time dimension by passing a `shape` '
+                       'or `batch_shape` argument to your Input layer.')
+
+    kwargs = {}
+    if has_arg(self.cell.call, 'training'):
+      kwargs['training'] = training
+
+    if constants:
+      if not has_arg(self.cell.call, 'constants'):
+        raise ValueError('RNN cell does not support constants')
+
+      def step(inputs, states):
+        constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
+        states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
+        return self.cell.call(inputs, states, constants=constants, **kwargs)
+    else:
+
+      def step(inputs, states):
+        return self.cell.call(inputs, states, **kwargs)
+
+    last_output, outputs, states = K.rnn(
+        step,
+        inputs,
+        initial_state,
+        constants=constants,
+        go_backwards=self.go_backwards,
+        mask=mask,
+        unroll=self.unroll)
+    if self.stateful:
+      updates = []
+      for i in range(len(states)):
+        updates.append((self.states[i], states[i]))
+      self.add_update(updates, inputs)
+
+    if self.return_sequences:
+      output = outputs
+    else:
+      output = last_output
+
+    # Properly set learning phase
+    if getattr(last_output, '_uses_learning_phase', False):
+      output._uses_learning_phase = True
+
+    if self.return_state:
+      if not isinstance(states, (list, tuple)):
+        states = [states]
+      else:
+        states = list(states)
+      return [output] + states
+    else:
+      return output
+
+  def _standardize_args(self, inputs, initial_state, constants):
+    """Standardize `__call__` arguments to a single list of tensor inputs.
+
+    When running a model loaded from file, the input tensors
+    `initial_state` and `constants` can be passed to `RNN.__call__` as part
+    of `inputs` instead of by the dedicated keyword arguments. This method
+    makes sure the arguments are separated and that `initial_state` and
+    `constants` are lists of tensors (or None).
+
+    Arguments:
+        inputs: tensor or list/tuple of tensors
+        initial_state: tensor or list of tensors or None
+        constants: tensor or list of tensors or None
+
+    Returns:
+        inputs: tensor
+        initial_state: list of tensors or None
+        constants: list of tensors or None
+    """
+    if isinstance(inputs, list):
+      assert initial_state is None and constants is None
+      if self._num_constants is not None:
+        constants = inputs[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
+        inputs = inputs[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type
+      if len(inputs) > 1:
+        initial_state = inputs[1:]
+      inputs = inputs[0]
+
+    def to_list_or_none(x):
+      if x is None or isinstance(x, list):
+        return x
+      if isinstance(x, tuple):
+        return list(x)
+      return [x]
+
+    initial_state = to_list_or_none(initial_state)
+    constants = to_list_or_none(constants)
+
+    return inputs, initial_state, constants
+
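The packing convention handled above: when a saved model is reloaded, initial states and constants arrive appended to `inputs`, constants last. A schematic example with two states and one constant (hypothetical tensor names):

```python
full_input = [x, h0, c0, attn]  # as packed by a reloaded model
constants = full_input[-1:]     # self._num_constants == 1 -> [attn]
rest = full_input[:-1]          # [x, h0, c0]
initial_state = rest[1:]        # [h0, c0]
inputs = rest[0]                # x
```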
+  def reset_states(self, states=None):
+    if not self.stateful:
+      raise AttributeError('Layer must be stateful.')
+    batch_size = self.input_spec[0].shape[0]
+    if not batch_size:
+      raise ValueError('If a RNN is stateful, it needs to know '
+                       'its batch size. Specify the batch size '
+                       'of your input tensors: \n'
+                       '- If using a Sequential model, '
+                       'specify the batch size by passing '
+                       'a `batch_input_shape` '
+                       'argument to your first layer.\n'
+                       '- If using the functional API, specify '
+                       'the batch size by passing a '
+                       '`batch_shape` argument to your Input layer.')
+    # initialize state if None
+    if self.states[0] is None:
+      if hasattr(self.cell.state_size, '__len__'):
+        self.states = [
+            K.zeros((batch_size, dim)) for dim in self.cell.state_size
+        ]
+      else:
+        self.states = [K.zeros((batch_size, self.cell.state_size))]
+    elif states is None:
+      if hasattr(self.cell.state_size, '__len__'):
+        for state, dim in zip(self.states, self.cell.state_size):
+          K.set_value(state, np.zeros((batch_size, dim)))
+      else:
+        K.set_value(self.states[0], np.zeros((batch_size,
+                                              self.cell.state_size)))
+    else:
+      if not isinstance(states, (list, tuple)):
+        states = [states]
+      if len(states) != len(self.states):
+        raise ValueError('Layer ' + self.name + ' expects ' +
+                         str(len(self.states)) + ' states, '
+                         'but it received ' + str(len(states)) +
+                         ' state values. Input received: ' + str(states))
+      for index, (value, state) in enumerate(zip(states, self.states)):
+        if hasattr(self.cell.state_size, '__len__'):
+          dim = self.cell.state_size[index]
+        else:
+          dim = self.cell.state_size
+        if value.shape != (batch_size, dim):
+          raise ValueError(
+              'State ' + str(index) + ' is incompatible with layer ' +
+              self.name + ': expected shape=' + str(
+                  (batch_size, dim)) + ', found shape=' + str(value.shape))
+        # TODO(fchollet): consider batch calls to `set_value`.
+        K.set_value(state, value)
+
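A usage sketch for the three branches above (a hedged example; assumes a built stateful `SimpleRNN` with `batch_input_shape=(32, 10, 4)` and 8 units):

```python
import numpy as np
from tensorflow.python.keras._impl import keras

model = keras.models.Sequential()
model.add(keras.layers.SimpleRNN(8, stateful=True,
                                 batch_input_shape=(32, 10, 4)))
layer = model.layers[0]
layer.reset_states()                         # zero out the current states
layer.reset_states(states=np.ones((32, 8)))  # or install explicit values
```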
+  def get_config(self):
+    config = {
+        'return_sequences': self.return_sequences,
+        'return_state': self.return_state,
+        'go_backwards': self.go_backwards,
+        'stateful': self.stateful,
+        'unroll': self.unroll
+    }
+    if self._num_constants is not None:
+      config['num_constants'] = self._num_constants
+
+    cell_config = self.cell.get_config()
+    config['cell'] = {
+        'class_name': self.cell.__class__.__name__,
+        'config': cell_config
+    }
+    base_config = super(RNN, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config, custom_objects=None):
+    from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
+    cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
+    num_constants = config.pop('num_constants', None)
+    layer = cls(cell, **config)
+    layer._num_constants = num_constants
+    return layer
+
+  @property
+  def trainable_weights(self):
+    if isinstance(self.cell, Layer):
+      return self.cell.trainable_weights
+    return []
+
+  @property
+  def non_trainable_weights(self):
+    if isinstance(self.cell, Layer):
+      return self.cell.non_trainable_weights
+    return []
+
+  @property
+  def losses(self):
+    if isinstance(self.cell, Layer):
+      return self.cell.losses
+    return []
+
+  def get_losses_for(self, inputs=None):
+    if isinstance(self.cell, Layer):
+      cell_losses = self.cell.get_losses_for(inputs)
+      return cell_losses + super(RNN, self).get_losses_for(inputs)
+    return super(RNN, self).get_losses_for(inputs)
+
+
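The docstring's note on external constants has no code example; here is a minimal sketch of a cell whose `call` accepts `constants` (a hypothetical cell, conditioning each step on a static tensor):

```python
class ConstantConditionedCell(Layer):
  """Toy cell: each step mixes the input, previous state and a constant."""

  def __init__(self, units, **kwargs):
    self.units = units
    self.state_size = units
    super(ConstantConditionedCell, self).__init__(**kwargs)

  def build(self, input_shape):
    # When constants are used, `input_shape` is a list:
    # [step_input_shape] + constants_shape (see `RNN.build` above).
    self.kernel = self.add_weight(shape=(input_shape[0][-1], self.units),
                                  initializer='uniform', name='kernel')
    self.built = True

  def call(self, inputs, states, constants):
    [cond] = constants  # static tensor, identical at every timestep
    output = K.dot(inputs, self.kernel) + states[0] + cond
    return output, [output]

# x = keras.Input((None, 5)); c = keras.Input((32,))
# y = RNN(ConstantConditionedCell(32))(x, constants=c)
```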
+class SimpleRNNCell(Layer):
+  """Cell class for SimpleRNN.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs.
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state.
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+  """
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               **kwargs):
+    super(SimpleRNNCell, self).__init__(**kwargs)
+    self.units = units
+    self.activation = activations.get(activation)
+    self.use_bias = use_bias
+
+    self.kernel_initializer = initializers.get(kernel_initializer)
+    self.recurrent_initializer = initializers.get(recurrent_initializer)
+    self.bias_initializer = initializers.get(bias_initializer)
+
+    self.kernel_regularizer = regularizers.get(kernel_regularizer)
+    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
+    self.bias_regularizer = regularizers.get(bias_regularizer)
+
+    self.kernel_constraint = constraints.get(kernel_constraint)
+    self.recurrent_constraint = constraints.get(recurrent_constraint)
+    self.bias_constraint = constraints.get(bias_constraint)
+
+    self.dropout = min(1., max(0., dropout))
+    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+    self.state_size = self.units
+    self._dropout_mask = None
+    self._recurrent_dropout_mask = None
+
+  def build(self, input_shape):
+    self.kernel = self.add_weight(
+        shape=(input_shape[-1], self.units),
+        name='kernel',
+        initializer=self.kernel_initializer,
+        regularizer=self.kernel_regularizer,
+        constraint=self.kernel_constraint)
+    self.recurrent_kernel = self.add_weight(
+        shape=(self.units, self.units),
+        name='recurrent_kernel',
+        initializer=self.recurrent_initializer,
+        regularizer=self.recurrent_regularizer,
+        constraint=self.recurrent_constraint)
+    if self.use_bias:
+      self.bias = self.add_weight(
+          shape=(self.units,),
+          name='bias',
+          initializer=self.bias_initializer,
+          regularizer=self.bias_regularizer,
+          constraint=self.bias_constraint)
+    else:
+      self.bias = None
+    self.built = True
+
+  def _generate_dropout_mask(self, inputs, training=None):
+    if 0 < self.dropout < 1:
+      ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.dropout)
+
+      self._dropout_mask = K.in_train_phase(
+          dropped_inputs, ones, training=training)
+    else:
+      self._dropout_mask = None
+
+  def _generate_recurrent_dropout_mask(self, inputs, training=None):
+    if 0 < self.recurrent_dropout < 1:
+      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+      ones = K.tile(ones, (1, self.units))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.recurrent_dropout)
+
+      self._recurrent_dropout_mask = K.in_train_phase(
+          dropped_inputs, ones, training=training)
+    else:
+      self._recurrent_dropout_mask = None
+
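Both helpers above sample a mask once per forward pass and reuse it at every timestep, i.e. the fixed-mask dropout of Gal & Ghahramani (see the reference list under `GRU` below). A numpy sketch of the reuse pattern (illustrative only):

```python
import numpy as np

batch, timesteps, input_dim, rate = 2, 4, 3, 0.5
inputs = np.random.rand(batch, timesteps, input_dim)
# One mask per forward pass, scaled as in inverted dropout.
mask = np.random.binomial(1, 1 - rate, (batch, input_dim)) / (1 - rate)
for t in range(timesteps):
  x_t = inputs[:, t, :] * mask  # the same mask at every timestep
```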
+  def call(self, inputs, states, training=None):
+    prev_output = states[0]
+    dp_mask = self._dropout_mask
+    rec_dp_mask = self._recurrent_dropout_mask
+
+    if dp_mask is not None:
+      h = K.dot(inputs * dp_mask, self.kernel)
+    else:
+      h = K.dot(inputs, self.kernel)
+    if self.bias is not None:
+      h = K.bias_add(h, self.bias)
+
+    if rec_dp_mask is not None:
+      prev_output *= rec_dp_mask
+    output = h + K.dot(prev_output, self.recurrent_kernel)
+    if self.activation is not None:
+      output = self.activation(output)
+
+    # Properly set learning phase on output tensor.
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        output._uses_learning_phase = True
+    return output, [output]
+
+
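In equation form, `SimpleRNNCell.call` above computes (φ the activation, `m_in`/`m_rec` the optional dropout masks, row-vector convention as in `K.dot`):

```latex
y_t = \phi\big( (x_t \odot m_{\mathrm{in}})\,W
              + (y_{t-1} \odot m_{\mathrm{rec}})\,U + b \big)
```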
+class SimpleRNN(RNN):
+  """Fully-connected RNN where the output is to be fed back to input.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs.
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state.
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      activity_regularizer: Regularizer function applied to
+          the output of the layer (its "activation").
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+      return_sequences: Boolean. Whether to return the last output
+          in the output sequence, or the full sequence.
+      return_state: Boolean. Whether to return the last state
+          in addition to the output.
+      go_backwards: Boolean (default False).
+          If True, process the input sequence backwards and return the
+          reversed sequence.
+      stateful: Boolean (default False). If True, the last state
+          for each sample at index i in a batch will be used as initial
+          state for the sample of index i in the following batch.
+      unroll: Boolean (default False).
+          If True, the network will be unrolled,
+          else a symbolic loop will be used.
+          Unrolling can speed up an RNN,
+          although it tends to be more memory-intensive.
+          Unrolling is only suitable for short sequences.
+  """
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               return_sequences=False,
+               return_state=False,
+               go_backwards=False,
+               stateful=False,
+               unroll=False,
+               **kwargs):
+    if 'implementation' in kwargs:
+      kwargs.pop('implementation')
+      logging.warning('The `implementation` argument '
+                      'in `SimpleRNN` has been deprecated. '
+                      'Please remove it from your layer call.')
+    cell = SimpleRNNCell(
+        units,
+        activation=activation,
+        use_bias=use_bias,
+        kernel_initializer=kernel_initializer,
+        recurrent_initializer=recurrent_initializer,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+        recurrent_regularizer=recurrent_regularizer,
+        bias_regularizer=bias_regularizer,
+        kernel_constraint=kernel_constraint,
+        recurrent_constraint=recurrent_constraint,
+        bias_constraint=bias_constraint,
+        dropout=dropout,
+        recurrent_dropout=recurrent_dropout)
+    super(SimpleRNN, self).__init__(
+        cell,
+        return_sequences=return_sequences,
+        return_state=return_state,
+        go_backwards=go_backwards,
+        stateful=stateful,
+        unroll=unroll,
+        activity_regularizer=regularizers.get(activity_regularizer),
+        **kwargs)
+
+  def call(self, inputs, mask=None, training=None, initial_state=None):
+    self.cell._generate_dropout_mask(inputs, training=training)
+    self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+    return super(SimpleRNN, self).call(
+        inputs, mask=mask, training=training, initial_state=initial_state)
+
+  @property
+  def units(self):
+    return self.cell.units
+
+  @property
+  def activation(self):
+    return self.cell.activation
+
+  @property
+  def use_bias(self):
+    return self.cell.use_bias
+
+  @property
+  def kernel_initializer(self):
+    return self.cell.kernel_initializer
+
+  @property
+  def recurrent_initializer(self):
+    return self.cell.recurrent_initializer
+
+  @property
+  def bias_initializer(self):
+    return self.cell.bias_initializer
+
+  @property
+  def kernel_regularizer(self):
+    return self.cell.kernel_regularizer
+
+  @property
+  def recurrent_regularizer(self):
+    return self.cell.recurrent_regularizer
+
+  @property
+  def bias_regularizer(self):
+    return self.cell.bias_regularizer
+
+  @property
+  def kernel_constraint(self):
+    return self.cell.kernel_constraint
+
+  @property
+  def recurrent_constraint(self):
+    return self.cell.recurrent_constraint
+
+  @property
+  def bias_constraint(self):
+    return self.cell.bias_constraint
+
+  @property
+  def dropout(self):
+    return self.cell.dropout
+
+  @property
+  def recurrent_dropout(self):
+    return self.cell.recurrent_dropout
+
+  def get_config(self):
+    config = {
+        'units':
+            self.units,
+        'activation':
+            activations.serialize(self.activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'recurrent_initializer':
+            initializers.serialize(self.recurrent_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'recurrent_regularizer':
+            regularizers.serialize(self.recurrent_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
+        'activity_regularizer':
+            regularizers.serialize(self.activity_regularizer),
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'recurrent_constraint':
+            constraints.serialize(self.recurrent_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint),
+        'dropout':
+            self.dropout,
+        'recurrent_dropout':
+            self.recurrent_dropout
+    }
+    base_config = super(SimpleRNN, self).get_config()
+    del base_config['cell']
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config):
+    if 'implementation' in config:
+      config.pop('implementation')
+    return cls(**config)
+
+
+class GRUCell(Layer):
+  """Cell class for the GRU layer.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      recurrent_activation: Activation function to use
+          for the recurrent step
+          (see [activations](../activations.md)).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs.
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state.
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+      implementation: Implementation mode, either 1 or 2.
+          Mode 1 will structure its operations as a larger number of
+          smaller dot products and additions, whereas mode 2 will
+          batch them into fewer, larger operations. These modes will
+          have different performance profiles on different hardware and
+          for different applications.
+  """
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               recurrent_activation='hard_sigmoid',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               implementation=1,
+               **kwargs):
+    super(GRUCell, self).__init__(**kwargs)
+    self.units = units
+    self.activation = activations.get(activation)
+    self.recurrent_activation = activations.get(recurrent_activation)
+    self.use_bias = use_bias
+
+    self.kernel_initializer = initializers.get(kernel_initializer)
+    self.recurrent_initializer = initializers.get(recurrent_initializer)
+    self.bias_initializer = initializers.get(bias_initializer)
+
+    self.kernel_regularizer = regularizers.get(kernel_regularizer)
+    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
+    self.bias_regularizer = regularizers.get(bias_regularizer)
+
+    self.kernel_constraint = constraints.get(kernel_constraint)
+    self.recurrent_constraint = constraints.get(recurrent_constraint)
+    self.bias_constraint = constraints.get(bias_constraint)
+
+    self.dropout = min(1., max(0., dropout))
+    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+    self.implementation = implementation
+    self.state_size = self.units
+    self._dropout_mask = None
+    self._recurrent_dropout_mask = None
+
+  def build(self, input_shape):
+    input_dim = input_shape[-1]
+    self.kernel = self.add_weight(
+        shape=(input_dim, self.units * 3),
+        name='kernel',
+        initializer=self.kernel_initializer,
+        regularizer=self.kernel_regularizer,
+        constraint=self.kernel_constraint)
+    self.recurrent_kernel = self.add_weight(
+        shape=(self.units, self.units * 3),
+        name='recurrent_kernel',
+        initializer=self.recurrent_initializer,
+        regularizer=self.recurrent_regularizer,
+        constraint=self.recurrent_constraint)
+
+    if self.use_bias:
+      self.bias = self.add_weight(
+          shape=(self.units * 3,),
+          name='bias',
+          initializer=self.bias_initializer,
+          regularizer=self.bias_regularizer,
+          constraint=self.bias_constraint)
+    else:
+      self.bias = None
+
+    self.kernel_z = self.kernel[:, :self.units]
+    self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
+    self.kernel_r = self.kernel[:, self.units:self.units * 2]
+    self.recurrent_kernel_r = self.recurrent_kernel[:, self.units:
+                                                    self.units * 2]
+    self.kernel_h = self.kernel[:, self.units * 2:]
+    self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]
+
+    if self.use_bias:
+      self.bias_z = self.bias[:self.units]
+      self.bias_r = self.bias[self.units:self.units * 2]
+      self.bias_h = self.bias[self.units * 2:]
+    else:
+      self.bias_z = None
+      self.bias_r = None
+      self.bias_h = None
+    self.built = True
+
+  def _generate_dropout_mask(self, inputs, training=None):
+    if 0 < self.dropout < 1:
+      ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.dropout)
+
+      self._dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(3)
+      ]
+    else:
+      self._dropout_mask = None
+
+  def _generate_recurrent_dropout_mask(self, inputs, training=None):
+    if 0 < self.recurrent_dropout < 1:
+      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+      ones = K.tile(ones, (1, self.units))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.recurrent_dropout)
+
+      self._recurrent_dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(3)
+      ]
+    else:
+      self._recurrent_dropout_mask = None
+
+  def call(self, inputs, states, training=None):
+    h_tm1 = states[0]  # previous memory
+
+    # dropout matrices for input units
+    dp_mask = self._dropout_mask
+    # dropout matrices for recurrent units
+    rec_dp_mask = self._recurrent_dropout_mask
+
+    if self.implementation == 1:
+      if 0. < self.dropout < 1.:
+        inputs_z = inputs * dp_mask[0]
+        inputs_r = inputs * dp_mask[1]
+        inputs_h = inputs * dp_mask[2]
+      else:
+        inputs_z = inputs
+        inputs_r = inputs
+        inputs_h = inputs
+      x_z = K.dot(inputs_z, self.kernel_z)
+      x_r = K.dot(inputs_r, self.kernel_r)
+      x_h = K.dot(inputs_h, self.kernel_h)
+      if self.use_bias:
+        x_z = K.bias_add(x_z, self.bias_z)
+        x_r = K.bias_add(x_r, self.bias_r)
+        x_h = K.bias_add(x_h, self.bias_h)
+
+      if 0. < self.recurrent_dropout < 1.:
+        h_tm1_z = h_tm1 * rec_dp_mask[0]
+        h_tm1_r = h_tm1 * rec_dp_mask[1]
+        h_tm1_h = h_tm1 * rec_dp_mask[2]
+      else:
+        h_tm1_z = h_tm1
+        h_tm1_r = h_tm1
+        h_tm1_h = h_tm1
+      z = self.recurrent_activation(
+          x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
+      r = self.recurrent_activation(
+          x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
+
+      hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
+    else:
+      if 0. < self.dropout < 1.:
+        inputs *= dp_mask[0]
+      matrix_x = K.dot(inputs, self.kernel)
+      if self.use_bias:
+        matrix_x = K.bias_add(matrix_x, self.bias)
+      if 0. < self.recurrent_dropout < 1.:
+        h_tm1 *= rec_dp_mask[0]
+      matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
+
+      x_z = matrix_x[:, :self.units]
+      x_r = matrix_x[:, self.units:2 * self.units]
+      recurrent_z = matrix_inner[:, :self.units]
+      recurrent_r = matrix_inner[:, self.units:2 * self.units]
+
+      z = self.recurrent_activation(x_z + recurrent_z)
+      r = self.recurrent_activation(x_r + recurrent_r)
+
+      x_h = matrix_x[:, 2 * self.units:]
+      recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
+      hh = self.activation(x_h + recurrent_h)
+    h = z * h_tm1 + (1 - z) * hh
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        h._uses_learning_phase = True
+    return h, [h]
+
+
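Written out, both implementation modes of `GRUCell.call` above compute the standard GRU update (σ the `recurrent_activation`, φ the `activation`; dropout masks omitted, row-vector convention):

```latex
\begin{aligned}
z_t &= \sigma(x_t W_z + h_{t-1} U_z + b_z) \\
r_t &= \sigma(x_t W_r + h_{t-1} U_r + b_r) \\
\tilde{h}_t &= \phi(x_t W_h + (r_t \odot h_{t-1}) U_h + b_h) \\
h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
\end{aligned}
```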
+class GRU(RNN):
+  # pylint: disable=line-too-long
+  """Gated Recurrent Unit - Cho et al.
+
+  2014.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      recurrent_activation: Activation function to use
+          for the recurrent step
+          (see [activations](../activations.md)).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs.
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state.
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      activity_regularizer: Regularizer function applied to
+          the output of the layer (its "activation").
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+      implementation: Implementation mode, either 1 or 2.
+          Mode 1 will structure its operations as a larger number of
+          smaller dot products and additions, whereas mode 2 will
+          batch them into fewer, larger operations. These modes will
+          have different performance profiles on different hardware and
+          for different applications.
+      return_sequences: Boolean. Whether to return the last output
+          in the output sequence, or the full sequence.
+      return_state: Boolean. Whether to return the last state
+          in addition to the output.
+      go_backwards: Boolean (default False).
+          If True, process the input sequence backwards and return the
+          reversed sequence.
+      stateful: Boolean (default False). If True, the last state
+          for each sample at index i in a batch will be used as initial
+          state for the sample of index i in the following batch.
+      unroll: Boolean (default False).
+          If True, the network will be unrolled,
+          else a symbolic loop will be used.
+          Unrolling can speed up an RNN,
+          although it tends to be more memory-intensive.
+          Unrolling is only suitable for short sequences.
+
+  References:
+      - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259)
+      - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1)
+      - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
+  """
+  # pylint: enable=line-too-long
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               recurrent_activation='hard_sigmoid',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               implementation=1,
+               return_sequences=False,
+               return_state=False,
+               go_backwards=False,
+               stateful=False,
+               unroll=False,
+               **kwargs):
+    if implementation == 0:
+      logging.warning('`implementation=0` has been deprecated, '
+                      'and now defaults to `implementation=1`. '
+                      'Please update your layer call.')
+    cell = GRUCell(
+        units,
+        activation=activation,
+        recurrent_activation=recurrent_activation,
+        use_bias=use_bias,
+        kernel_initializer=kernel_initializer,
+        recurrent_initializer=recurrent_initializer,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+        recurrent_regularizer=recurrent_regularizer,
+        bias_regularizer=bias_regularizer,
+        kernel_constraint=kernel_constraint,
+        recurrent_constraint=recurrent_constraint,
+        bias_constraint=bias_constraint,
+        dropout=dropout,
+        recurrent_dropout=recurrent_dropout,
+        implementation=implementation)
+    super(GRU, self).__init__(
+        cell,
+        return_sequences=return_sequences,
+        return_state=return_state,
+        go_backwards=go_backwards,
+        stateful=stateful,
+        unroll=unroll,
+        **kwargs)
+    self.activity_regularizer = regularizers.get(activity_regularizer)
+
+  def call(self, inputs, mask=None, training=None, initial_state=None):
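+    # Regenerate the dropout masks once per call so the same masks are
+    # reused at every timestep (cf. Gal & Ghahramani, arXiv:1512.05287).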
+    self.cell._generate_dropout_mask(inputs, training=training)
+    self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+    return super(GRU, self).call(
+        inputs, mask=mask, training=training, initial_state=initial_state)
+
+  @property
+  def units(self):
+    return self.cell.units
+
+  @property
+  def activation(self):
+    return self.cell.activation
+
+  @property
+  def recurrent_activation(self):
+    return self.cell.recurrent_activation
+
+  @property
+  def use_bias(self):
+    return self.cell.use_bias
+
+  @property
+  def kernel_initializer(self):
+    return self.cell.kernel_initializer
+
+  @property
+  def recurrent_initializer(self):
+    return self.cell.recurrent_initializer
+
+  @property
+  def bias_initializer(self):
+    return self.cell.bias_initializer
+
+  @property
+  def kernel_regularizer(self):
+    return self.cell.kernel_regularizer
+
+  @property
+  def recurrent_regularizer(self):
+    return self.cell.recurrent_regularizer
+
+  @property
+  def bias_regularizer(self):
+    return self.cell.bias_regularizer
+
+  @property
+  def kernel_constraint(self):
+    return self.cell.kernel_constraint
+
+  @property
+  def recurrent_constraint(self):
+    return self.cell.recurrent_constraint
+
+  @property
+  def bias_constraint(self):
+    return self.cell.bias_constraint
+
+  @property
+  def dropout(self):
+    return self.cell.dropout
+
+  @property
+  def recurrent_dropout(self):
+    return self.cell.recurrent_dropout
+
+  @property
+  def implementation(self):
+    return self.cell.implementation
+
+  def get_config(self):
+    config = {
+        'units':
+            self.units,
+        'activation':
+            activations.serialize(self.activation),
+        'recurrent_activation':
+            activations.serialize(self.recurrent_activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'recurrent_initializer':
+            initializers.serialize(self.recurrent_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'recurrent_regularizer':
+            regularizers.serialize(self.recurrent_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
+        'activity_regularizer':
+            regularizers.serialize(self.activity_regularizer),
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'recurrent_constraint':
+            constraints.serialize(self.recurrent_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint),
+        'dropout':
+            self.dropout,
+        'recurrent_dropout':
+            self.recurrent_dropout,
+        'implementation':
+            self.implementation
+    }
+    base_config = super(GRU, self).get_config()
+    del base_config['cell']
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config):
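+    # Configs saved with the removed `implementation=0` mode are mapped
+    # to mode 1 for backward compatibility.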
+    if 'implementation' in config and config['implementation'] == 0:
+      config['implementation'] = 1
+    return cls(**config)
+
+
+class LSTMCell(Layer):
+  """Cell class for the LSTM layer.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      recurrent_activation: Activation function to use
+          for the recurrent step
+          (see [activations](../activations.md)).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      unit_forget_bias: Boolean.
+          If True, add 1 to the bias of the forget gate at initialization.
+          Setting it to True will also force `bias_initializer="zeros"`.
+          This is recommended in [Jozefowicz et
+            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+      implementation: Implementation mode, either 1 or 2.
+          Mode 1 will structure its operations as a larger number of
+          smaller dot products and additions, whereas mode 2 will
+          batch them into fewer, larger operations. These modes will
+          have different performance profiles on different hardware and
+          for different applications.
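+
+  Example:
+
+  ```python
+      # Illustrative usage sketch: an `LSTMCell` computes a single
+      # timestep; wrap it in the generic `RNN` layer to iterate over
+      # a full sequence.
+      cell = LSTMCell(32)
+      layer = RNN(cell)
+  ```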
+  """
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               recurrent_activation='hard_sigmoid',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               unit_forget_bias=True,
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               implementation=1,
+               **kwargs):
+    super(LSTMCell, self).__init__(**kwargs)
+    self.units = units
+    self.activation = activations.get(activation)
+    self.recurrent_activation = activations.get(recurrent_activation)
+    self.use_bias = use_bias
+
+    self.kernel_initializer = initializers.get(kernel_initializer)
+    self.recurrent_initializer = initializers.get(recurrent_initializer)
+    self.bias_initializer = initializers.get(bias_initializer)
+    self.unit_forget_bias = unit_forget_bias
+
+    self.kernel_regularizer = regularizers.get(kernel_regularizer)
+    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
+    self.bias_regularizer = regularizers.get(bias_regularizer)
+
+    self.kernel_constraint = constraints.get(kernel_constraint)
+    self.recurrent_constraint = constraints.get(recurrent_constraint)
+    self.bias_constraint = constraints.get(bias_constraint)
+
+    self.dropout = min(1., max(0., dropout))
+    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
+    self.implementation = implementation
+    self.state_size = (self.units, self.units)
+    self._dropout_mask = None
+    self._recurrent_dropout_mask = None
+
+  def build(self, input_shape):
+    input_dim = input_shape[-1]
+    self.kernel = self.add_weight(
+        shape=(input_dim, self.units * 4),
+        name='kernel',
+        initializer=self.kernel_initializer,
+        regularizer=self.kernel_regularizer,
+        constraint=self.kernel_constraint)
+    self.recurrent_kernel = self.add_weight(
+        shape=(self.units, self.units * 4),
+        name='recurrent_kernel',
+        initializer=self.recurrent_initializer,
+        regularizer=self.recurrent_regularizer,
+        constraint=self.recurrent_constraint)
+
+    if self.use_bias:
+      if self.unit_forget_bias:
+
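+        # Build the bias so the forget-gate slice starts at ones
+        # (Jozefowicz et al., 2015) while the other gates use
+        # `bias_initializer`.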
+        def bias_initializer(_, *args, **kwargs):
+          return K.concatenate([
+              self.bias_initializer((self.units,), *args, **kwargs),
+              initializers.Ones()((self.units,), *args, **kwargs),
+              self.bias_initializer((self.units * 2,), *args, **kwargs),
+          ])
+      else:
+        bias_initializer = self.bias_initializer
+      self.bias = self.add_weight(
+          shape=(self.units * 4,),
+          name='bias',
+          initializer=bias_initializer,
+          regularizer=self.bias_regularizer,
+          constraint=self.bias_constraint)
+    else:
+      self.bias = None
+
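+    # Slice the fused kernels into per-gate views, in (input, forget,
+    # cell, output) order.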
+    self.kernel_i = self.kernel[:, :self.units]
+    self.kernel_f = self.kernel[:, self.units:self.units * 2]
+    self.kernel_c = self.kernel[:, self.units * 2:self.units * 3]
+    self.kernel_o = self.kernel[:, self.units * 3:]
+
+    self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
+    self.recurrent_kernel_f = self.recurrent_kernel[:, self.units:
+                                                    self.units * 2]
+    self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2:
+                                                    self.units * 3]
+    self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
+
+    if self.use_bias:
+      self.bias_i = self.bias[:self.units]
+      self.bias_f = self.bias[self.units:self.units * 2]
+      self.bias_c = self.bias[self.units * 2:self.units * 3]
+      self.bias_o = self.bias[self.units * 3:]
+    else:
+      self.bias_i = None
+      self.bias_f = None
+      self.bias_c = None
+      self.bias_o = None
+    self.built = True
+
+  def _generate_dropout_mask(self, inputs, training=None):
+    if 0 < self.dropout < 1:
+      ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.dropout)
+
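+      # One independent input-dropout mask per gate (i, f, c, o).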
+      self._dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(4)
+      ]
+    else:
+      self._dropout_mask = None
+
+  def _generate_recurrent_dropout_mask(self, inputs, training=None):
+    if 0 < self.recurrent_dropout < 1:
+      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+      ones = K.tile(ones, (1, self.units))
+
+      def dropped_inputs():
+        return K.dropout(ones, self.recurrent_dropout)
+
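+      # One independent recurrent-dropout mask per gate, of shape
+      # (batch_size, units).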
+      self._recurrent_dropout_mask = [
+          K.in_train_phase(dropped_inputs, ones, training=training)
+          for _ in range(4)
+      ]
+    else:
+      self._recurrent_dropout_mask = None
+
+  def call(self, inputs, states, training=None):
+    # dropout matrices for input units
+    dp_mask = self._dropout_mask
+    # dropout matrices for recurrent units
+    rec_dp_mask = self._recurrent_dropout_mask
+
+    h_tm1 = states[0]  # previous memory state
+    c_tm1 = states[1]  # previous carry state
+
+    if self.implementation == 1:
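+      # Mode 1: a separate (smaller) matmul per gate.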
+      if 0 < self.dropout < 1.:
+        inputs_i = inputs * dp_mask[0]
+        inputs_f = inputs * dp_mask[1]
+        inputs_c = inputs * dp_mask[2]
+        inputs_o = inputs * dp_mask[3]
+      else:
+        inputs_i = inputs
+        inputs_f = inputs
+        inputs_c = inputs
+        inputs_o = inputs
+      x_i = K.dot(inputs_i, self.kernel_i)
+      x_f = K.dot(inputs_f, self.kernel_f)
+      x_c = K.dot(inputs_c, self.kernel_c)
+      x_o = K.dot(inputs_o, self.kernel_o)
+      if self.use_bias:
+        x_i = K.bias_add(x_i, self.bias_i)
+        x_f = K.bias_add(x_f, self.bias_f)
+        x_c = K.bias_add(x_c, self.bias_c)
+        x_o = K.bias_add(x_o, self.bias_o)
+
+      if 0 < self.recurrent_dropout < 1.:
+        h_tm1_i = h_tm1 * rec_dp_mask[0]
+        h_tm1_f = h_tm1 * rec_dp_mask[1]
+        h_tm1_c = h_tm1 * rec_dp_mask[2]
+        h_tm1_o = h_tm1 * rec_dp_mask[3]
+      else:
+        h_tm1_i = h_tm1
+        h_tm1_f = h_tm1
+        h_tm1_c = h_tm1
+        h_tm1_o = h_tm1
+      i = self.recurrent_activation(
+          x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
+      f = self.recurrent_activation(
+          x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
+      c = f * c_tm1 + i * self.activation(
+          x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
+      o = self.recurrent_activation(
+          x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
+    else:
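+      # Mode 2: fuse all four gate transforms into a single matmul
+      # against the full (input_dim, 4 * units) kernel.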
+      if 0. < self.dropout < 1.:
+        inputs *= dp_mask[0]
+      z = K.dot(inputs, self.kernel)
+      if 0. < self.recurrent_dropout < 1.:
+        h_tm1 *= rec_dp_mask[0]
+      z += K.dot(h_tm1, self.recurrent_kernel)
+      if self.use_bias:
+        z = K.bias_add(z, self.bias)
+
+      z0 = z[:, :self.units]
+      z1 = z[:, self.units:2 * self.units]
+      z2 = z[:, 2 * self.units:3 * self.units]
+      z3 = z[:, 3 * self.units:]
+
+      i = self.recurrent_activation(z0)
+      f = self.recurrent_activation(z1)
+      c = f * c_tm1 + i * self.activation(z2)
+      o = self.recurrent_activation(z3)
+
+    h = o * self.activation(c)
+    if 0 < self.dropout + self.recurrent_dropout:
+      if training is None:
+        h._uses_learning_phase = True
+    return h, [h, c]
+
+
+class LSTM(RNN):
+  # pylint: disable=line-too-long
+  """Long-Short Term Memory layer - Hochreiter 1997.
+
+  Arguments:
+      units: Positive integer, dimensionality of the output space.
+      activation: Activation function to use
+          (see [activations](../activations.md)).
+          If you pass None, no activation is applied
+          (ie. "linear" activation: `a(x) = x`).
+      recurrent_activation: Activation function to use
+          for the recurrent step
+          (see [activations](../activations.md)).
+      use_bias: Boolean, whether the layer uses a bias vector.
+      kernel_initializer: Initializer for the `kernel` weights matrix,
+          used for the linear transformation of the inputs
+          (see [initializers](../initializers.md)).
+      recurrent_initializer: Initializer for the `recurrent_kernel`
+          weights matrix,
+          used for the linear transformation of the recurrent state
+          (see [initializers](../initializers.md)).
+      bias_initializer: Initializer for the bias vector
+          (see [initializers](../initializers.md)).
+      unit_forget_bias: Boolean.
+          If True, add 1 to the bias of the forget gate at initialization.
+          Setting it to True will also force `bias_initializer="zeros"`.
+          This is recommended in [Jozefowicz et
+            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
+      kernel_regularizer: Regularizer function applied to
+          the `kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      recurrent_regularizer: Regularizer function applied to
+          the `recurrent_kernel` weights matrix
+          (see [regularizer](../regularizers.md)).
+      bias_regularizer: Regularizer function applied to the bias vector
+          (see [regularizer](../regularizers.md)).
+      activity_regularizer: Regularizer function applied to
+          the output of the layer (its "activation")
+          (see [regularizer](../regularizers.md)).
+      kernel_constraint: Constraint function applied to
+          the `kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      recurrent_constraint: Constraint function applied to
+          the `recurrent_kernel` weights matrix
+          (see [constraints](../constraints.md)).
+      bias_constraint: Constraint function applied to the bias vector
+          (see [constraints](../constraints.md)).
+      dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the inputs.
+      recurrent_dropout: Float between 0 and 1.
+          Fraction of the units to drop for
+          the linear transformation of the recurrent state.
+      implementation: Implementation mode, either 1 or 2.
+          Mode 1 will structure its operations as a larger number of
+          smaller dot products and additions, whereas mode 2 will
+          batch them into fewer, larger operations. These modes will
+          have different performance profiles on different hardware and
+          for different applications.
+      return_sequences: Boolean. Whether to return the last output
+          in the output sequence, or the full sequence.
+      return_state: Boolean. Whether to return the last state
+          in addition to the output.
+      go_backwards: Boolean (default False).
+          If True, process the input sequence backwards and return the
+          reversed sequence.
+      stateful: Boolean (default False). If True, the last state
+          for each sample at index i in a batch will be used as initial
+          state for the sample of index i in the following batch.
+      unroll: Boolean (default False).
+          If True, the network will be unrolled,
+          else a symbolic loop will be used.
+          Unrolling can speed up an RNN,
+          although it tends to be more memory-intensive.
+          Unrolling is only suitable for short sequences.
+
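+  Example:
+
+  ```python
+      # Illustrative usage sketch (layer sizes and shapes here are
+      # arbitrary): stacking LSTMs requires return_sequences=True on
+      # every layer that feeds into another recurrent layer.
+      model = Sequential()
+      model.add(LSTM(64, input_shape=(10, 64), return_sequences=True))
+      model.add(LSTM(32))
+      # now model.output_shape == (None, 32)
+  ```
+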
+  References:
+      - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
+      - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
+      - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
+      - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
+  """
+  # pylint: enable=line-too-long
+
+  def __init__(self,
+               units,
+               activation='tanh',
+               recurrent_activation='hard_sigmoid',
+               use_bias=True,
+               kernel_initializer='glorot_uniform',
+               recurrent_initializer='orthogonal',
+               bias_initializer='zeros',
+               unit_forget_bias=True,
+               kernel_regularizer=None,
+               recurrent_regularizer=None,
+               bias_regularizer=None,
+               activity_regularizer=None,
+               kernel_constraint=None,
+               recurrent_constraint=None,
+               bias_constraint=None,
+               dropout=0.,
+               recurrent_dropout=0.,
+               implementation=1,
+               return_sequences=False,
+               return_state=False,
+               go_backwards=False,
+               stateful=False,
+               unroll=False,
+               **kwargs):
+    if implementation == 0:
+      logging.warning('`implementation=0` has been deprecated, '
+                      'and now defaults to `implementation=1`. '
+                      'Please update your layer call.')
+    cell = LSTMCell(
+        units,
+        activation=activation,
+        recurrent_activation=recurrent_activation,
+        use_bias=use_bias,
+        kernel_initializer=kernel_initializer,
+        recurrent_initializer=recurrent_initializer,
+        unit_forget_bias=unit_forget_bias,
+        bias_initializer=bias_initializer,
+        kernel_regularizer=kernel_regularizer,
+        recurrent_regularizer=recurrent_regularizer,
+        bias_regularizer=bias_regularizer,
+        kernel_constraint=kernel_constraint,
+        recurrent_constraint=recurrent_constraint,
+        bias_constraint=bias_constraint,
+        dropout=dropout,
+        recurrent_dropout=recurrent_dropout,
+        implementation=implementation)
+    super(LSTM, self).__init__(
+        cell,
+        return_sequences=return_sequences,
+        return_state=return_state,
+        go_backwards=go_backwards,
+        stateful=stateful,
+        unroll=unroll,
+        **kwargs)
+    self.activity_regularizer = regularizers.get(activity_regularizer)
+
+  def call(self, inputs, mask=None, training=None, initial_state=None):
+    self.cell._generate_dropout_mask(inputs, training=training)
+    self.cell._generate_recurrent_dropout_mask(inputs, training=training)
+    return super(LSTM, self).call(
+        inputs, mask=mask, training=training, initial_state=initial_state)
+
+  @property
+  def units(self):
+    return self.cell.units
+
+  @property
+  def activation(self):
+    return self.cell.activation
+
+  @property
+  def recurrent_activation(self):
+    return self.cell.recurrent_activation
+
+  @property
+  def use_bias(self):
+    return self.cell.use_bias
+
+  @property
+  def kernel_initializer(self):
+    return self.cell.kernel_initializer
+
+  @property
+  def recurrent_initializer(self):
+    return self.cell.recurrent_initializer
+
+  @property
+  def bias_initializer(self):
+    return self.cell.bias_initializer
+
+  @property
+  def unit_forget_bias(self):
+    return self.cell.unit_forget_bias
+
+  @property
+  def kernel_regularizer(self):
+    return self.cell.kernel_regularizer
+
+  @property
+  def recurrent_regularizer(self):
+    return self.cell.recurrent_regularizer
+
+  @property
+  def bias_regularizer(self):
+    return self.cell.bias_regularizer
+
+  @property
+  def kernel_constraint(self):
+    return self.cell.kernel_constraint
+
+  @property
+  def recurrent_constraint(self):
+    return self.cell.recurrent_constraint
+
+  @property
+  def bias_constraint(self):
+    return self.cell.bias_constraint
+
+  @property
+  def dropout(self):
+    return self.cell.dropout
+
+  @property
+  def recurrent_dropout(self):
+    return self.cell.recurrent_dropout
+
+  @property
+  def implementation(self):
+    return self.cell.implementation
+
+  def get_config(self):
+    config = {
+        'units':
+            self.units,
+        'activation':
+            activations.serialize(self.activation),
+        'recurrent_activation':
+            activations.serialize(self.recurrent_activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'recurrent_initializer':
+            initializers.serialize(self.recurrent_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'unit_forget_bias':
+            self.unit_forget_bias,
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'recurrent_regularizer':
+            regularizers.serialize(self.recurrent_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
+        'activity_regularizer':
+            regularizers.serialize(self.activity_regularizer),
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'recurrent_constraint':
+            constraints.serialize(self.recurrent_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint),
+        'dropout':
+            self.dropout,
+        'recurrent_dropout':
+            self.recurrent_dropout,
+        'implementation':
+            self.implementation
+    }
+    base_config = super(LSTM, self).get_config()
+    del base_config['cell']
+    return dict(list(base_config.items()) + list(config.items()))
+
+  @classmethod
+  def from_config(cls, config):
+    if 'implementation' in config and config['implementation'] == 0:
+      config['implementation'] = 1
+    return cls(**config)
 
 
 class Recurrent(Layer):
-  """Abstract base class for recurrent layers.
+  """Deprecated abstract base class for recurrent layers.
 
-  Do not use in a model -- it's not a valid layer!
-  Use its children classes `LSTM`, `GRU` and `SimpleRNN` instead.
-
-  All recurrent layers (`LSTM`, `GRU`, `SimpleRNN`) also
-  follow the specifications of this class and accept
-  the keyword arguments listed below.
-
-  Example:
-
-  ```python
-      # as the first layer in a Sequential model
-      model = Sequential()
-      model.add(LSTM(32, input_shape=(10, 64)))
-      # now model.output_shape == (None, 32)
-      # note: `None` is the batch dimension.
-
-      # for subsequent layers, no need to specify the input size:
-      model.add(LSTM(16))
-
-      # to stack recurrent layers, you must use return_sequences=True
-      # on any recurrent layer that feeds into another recurrent layer.
-      # note that you only need to specify the input size on the first layer.
-      model = Sequential()
-      model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
-      model.add(LSTM(32, return_sequences=True))
-      model.add(LSTM(10))
-  ```
+  It still exists only because it is leveraged by the
+  convolutional-recurrent layers; it will be removed entirely in the
+  future. It was never part of the public API. Do not use.
 
   Arguments:
       weights: list of Numpy arrays to set as initial weights.
@@ -163,7 +2227,7 @@ class Recurrent(Layer):
           at the level of the first layer
           (e.g. via the `input_shape` argument)
 
-  Input shape:s
+  Input shape:
       3D tensor with shape `(batch_size, timesteps, input_dim)`,
       (Optional) 2D tensors with shape `(batch_size, output_dim)`.
 
@@ -439,832 +2503,3 @@ class Recurrent(Layer):
     }
     base_config = super(Recurrent, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
-
-
-class SimpleRNN(Recurrent):
-  """Fully-connected RNN where the output is to be fed back to input.
-
-  Arguments:
-      units: Positive integer, dimensionality of the output space.
-      activation: Activation function to use.
-          If you don't specify anything, no activation is applied
-          If you pass None, no activation is applied
-          (ie. "linear" activation: `a(x) = x`).
-      use_bias: Boolean, whether the layer uses a bias vector.
-      kernel_initializer: Initializer for the `kernel` weights matrix,
-          used for the linear transformation of the inputs..
-      recurrent_initializer: Initializer for the `recurrent_kernel`
-          weights matrix,
-          used for the linear transformation of the recurrent state..
-      bias_initializer: Initializer for the bias vector.
-      kernel_regularizer: Regularizer function applied to
-          the `kernel` weights matrix.
-      recurrent_regularizer: Regularizer function applied to
-          the `recurrent_kernel` weights matrix.
-      bias_regularizer: Regularizer function applied to the bias vector.
-      activity_regularizer: Regularizer function applied to
-          the output of the layer (its "activation")..
-      kernel_constraint: Constraint function applied to
-          the `kernel` weights matrix.
-      recurrent_constraint: Constraint function applied to
-          the `recurrent_kernel` weights matrix.
-      bias_constraint: Constraint function applied to the bias vector.
-      dropout: Float between 0 and 1.
-          Fraction of the units to drop for
-          the linear transformation of the inputs.
-      recurrent_dropout: Float between 0 and 1.
-          Fraction of the units to drop for
-          the linear transformation of the recurrent state.
-
-  References:
-      - [A Theoretically Grounded Application of Dropout in Recurrent Neural
-        Networks](http://arxiv.org/abs/1512.05287)
-  """
-
-  def __init__(self,
-               units,
-               activation='tanh',
-               use_bias=True,
-               kernel_initializer='glorot_uniform',
-               recurrent_initializer='orthogonal',
-               bias_initializer='zeros',
-               kernel_regularizer=None,
-               recurrent_regularizer=None,
-               bias_regularizer=None,
-               activity_regularizer=None,
-               kernel_constraint=None,
-               recurrent_constraint=None,
-               bias_constraint=None,
-               dropout=0.,
-               recurrent_dropout=0.,
-               **kwargs):
-    super(SimpleRNN, self).__init__(
-        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
-    self.units = units
-    self.activation = activations.get(activation)
-    self.use_bias = use_bias
-
-    self.kernel_initializer = initializers.get(kernel_initializer)
-    self.recurrent_initializer = initializers.get(recurrent_initializer)
-    self.bias_initializer = initializers.get(bias_initializer)
-
-    self.kernel_regularizer = regularizers.get(kernel_regularizer)
-    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
-    self.bias_regularizer = regularizers.get(bias_regularizer)
-
-    self.kernel_constraint = constraints.get(kernel_constraint)
-    self.recurrent_constraint = constraints.get(recurrent_constraint)
-    self.bias_constraint = constraints.get(bias_constraint)
-
-    self.dropout = min(1., max(0., dropout))
-    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
-    self.state_spec = InputSpec(shape=(None, self.units))
-
-  def build(self, input_shape):
-    if isinstance(input_shape, list):
-      input_shape = input_shape[0]
-    input_shape = tensor_shape.TensorShape(input_shape).as_list()
-
-    batch_size = input_shape[0] if self.stateful else None
-    self.input_dim = input_shape[2]
-    self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
-
-    self.states = [None]
-    if self.stateful:
-      self.reset_states()
-
-    self.kernel = self.add_weight(
-        shape=(self.input_dim, self.units),
-        name='kernel',
-        initializer=self.kernel_initializer,
-        regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
-    self.recurrent_kernel = self.add_weight(
-        shape=(self.units, self.units),
-        name='recurrent_kernel',
-        initializer=self.recurrent_initializer,
-        regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
-    if self.use_bias:
-      self.bias = self.add_weight(
-          shape=(self.units,),
-          name='bias',
-          initializer=self.bias_initializer,
-          regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
-    else:
-      self.bias = None
-    self.built = True
-
-  def preprocess_input(self, inputs, training=None):
-    if self.implementation > 0:
-      return inputs
-    else:
-      input_shape = inputs.get_shape().as_list()
-      input_dim = input_shape[2]
-      timesteps = input_shape[1]
-      return _time_distributed_dense(
-          inputs,
-          self.kernel,
-          self.bias,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-
-  def step(self, inputs, states):
-    if self.implementation == 0:
-      h = inputs
-    else:
-      if 0 < self.dropout < 1:
-        h = K.dot(inputs * states[1], self.kernel)
-      else:
-        h = K.dot(inputs, self.kernel)
-      if self.bias is not None:
-        h = K.bias_add(h, self.bias)
-
-    prev_output = states[0]
-    if 0 < self.recurrent_dropout < 1:
-      prev_output *= states[2]
-    output = h + K.dot(prev_output, self.recurrent_kernel)
-    if self.activation is not None:
-      output = self.activation(output)
-
-    # Properly set learning phase on output tensor.
-    if 0 < self.dropout + self.recurrent_dropout:
-      output._uses_learning_phase = True
-    return output, [output]
-
-  def get_constants(self, inputs, training=None):
-    constants = []
-    if self.implementation != 0 and 0 < self.dropout < 1:
-      input_shape = K.int_shape(inputs)
-      input_dim = input_shape[-1]
-      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
-      ones = K.tile(ones, (1, int(input_dim)))
-
-      def dropped_inputs():
-        return K.dropout(ones, self.dropout)
-
-      dp_mask = K.in_train_phase(dropped_inputs, ones, training=training)
-      constants.append(dp_mask)
-    else:
-      constants.append(K.cast_to_floatx(1.))
-
-    if 0 < self.recurrent_dropout < 1:
-      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
-      ones = K.tile(ones, (1, self.units))
-
-      def dropped_inputs():  # pylint: disable=function-redefined
-        return K.dropout(ones, self.recurrent_dropout)
-
-      rec_dp_mask = K.in_train_phase(dropped_inputs, ones, training=training)
-      constants.append(rec_dp_mask)
-    else:
-      constants.append(K.cast_to_floatx(1.))
-    return constants
-
-  def get_config(self):
-    config = {
-        'units': self.units,
-        'activation': activations.serialize(self.activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'recurrent_initializer':
-            initializers.serialize(self.recurrent_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'recurrent_regularizer':
-            regularizers.serialize(self.recurrent_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
-        'activity_regularizer':
-            regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'recurrent_constraint':
-            constraints.serialize(self.recurrent_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint),
-        'dropout': self.dropout,
-        'recurrent_dropout': self.recurrent_dropout
-    }
-    base_config = super(SimpleRNN, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-
-class GRU(Recurrent):
-  """Gated Recurrent Unit - Cho et al.
-
-  2014.
-
-  Arguments:
-      units: Positive integer, dimensionality of the output space.
-      activation: Activation function to use.
-          If you pass None, no activation is applied
-          (ie. "linear" activation: `a(x) = x`).
-      recurrent_activation: Activation function to use
-          for the recurrent step.
-      use_bias: Boolean, whether the layer uses a bias vector.
-      kernel_initializer: Initializer for the `kernel` weights matrix,
-          used for the linear transformation of the inputs..
-      recurrent_initializer: Initializer for the `recurrent_kernel`
-          weights matrix,
-          used for the linear transformation of the recurrent state..
-      bias_initializer: Initializer for the bias vector.
-      kernel_regularizer: Regularizer function applied to
-          the `kernel` weights matrix.
-      recurrent_regularizer: Regularizer function applied to
-          the `recurrent_kernel` weights matrix.
-      bias_regularizer: Regularizer function applied to the bias vector.
-      activity_regularizer: Regularizer function applied to
-          the output of the layer (its "activation")..
-      kernel_constraint: Constraint function applied to
-          the `kernel` weights matrix.
-      recurrent_constraint: Constraint function applied to
-          the `recurrent_kernel` weights matrix.
-      bias_constraint: Constraint function applied to the bias vector.
-      dropout: Float between 0 and 1.
-          Fraction of the units to drop for
-          the linear transformation of the inputs.
-      recurrent_dropout: Float between 0 and 1.
-          Fraction of the units to drop for
-          the linear transformation of the recurrent state.
-
-  References:
-      - [On the Properties of Neural Machine Translation: Encoder-Decoder
-        Approaches](https://arxiv.org/abs/1409.1259)
-      - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
-        Modeling](http://arxiv.org/abs/1412.3555v1)
-      - [A Theoretically Grounded Application of Dropout in Recurrent Neural
-        Networks](http://arxiv.org/abs/1512.05287)
-  """
-
-  def __init__(self,
-               units,
-               activation='tanh',
-               recurrent_activation='hard_sigmoid',
-               use_bias=True,
-               kernel_initializer='glorot_uniform',
-               recurrent_initializer='orthogonal',
-               bias_initializer='zeros',
-               kernel_regularizer=None,
-               recurrent_regularizer=None,
-               bias_regularizer=None,
-               activity_regularizer=None,
-               kernel_constraint=None,
-               recurrent_constraint=None,
-               bias_constraint=None,
-               dropout=0.,
-               recurrent_dropout=0.,
-               **kwargs):
-    super(GRU, self).__init__(
-        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
-    self.units = units
-    self.activation = activations.get(activation)
-    self.recurrent_activation = activations.get(recurrent_activation)
-    self.use_bias = use_bias
-
-    self.kernel_initializer = initializers.get(kernel_initializer)
-    self.recurrent_initializer = initializers.get(recurrent_initializer)
-    self.bias_initializer = initializers.get(bias_initializer)
-
-    self.kernel_regularizer = regularizers.get(kernel_regularizer)
-    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
-    self.bias_regularizer = regularizers.get(bias_regularizer)
-
-    self.kernel_constraint = constraints.get(kernel_constraint)
-    self.recurrent_constraint = constraints.get(recurrent_constraint)
-    self.bias_constraint = constraints.get(bias_constraint)
-
-    self.dropout = min(1., max(0., dropout))
-    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
-    self.state_spec = InputSpec(shape=(None, self.units))
-
-  def build(self, input_shape):
-    if isinstance(input_shape, list):
-      input_shape = input_shape[0]
-    input_shape = tensor_shape.TensorShape(input_shape).as_list()
-    batch_size = input_shape[0] if self.stateful else None
-    self.input_dim = input_shape[2]
-    self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
-
-    self.states = [None]
-    if self.stateful:
-      self.reset_states()
-
-    self.kernel = self.add_weight(
-        shape=(self.input_dim, self.units * 3),
-        name='kernel',
-        initializer=self.kernel_initializer,
-        regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
-    self.recurrent_kernel = self.add_weight(
-        shape=(self.units, self.units * 3),
-        name='recurrent_kernel',
-        initializer=self.recurrent_initializer,
-        regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
-
-    if self.use_bias:
-      self.bias = self.add_weight(
-          shape=(self.units * 3,),
-          name='bias',
-          initializer=self.bias_initializer,
-          regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
-    else:
-      self.bias = None
-
-    self.kernel_z = self.kernel[:, :self.units]
-    self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
-    self.kernel_r = self.kernel[:, self.units:self.units * 2]
-    self.recurrent_kernel_r = self.recurrent_kernel[:, self.units:
-                                                    self.units * 2]
-    self.kernel_h = self.kernel[:, self.units * 2:]
-    self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]
-
-    if self.use_bias:
-      self.bias_z = self.bias[:self.units]
-      self.bias_r = self.bias[self.units:self.units * 2]
-      self.bias_h = self.bias[self.units * 2:]
-    else:
-      self.bias_z = None
-      self.bias_r = None
-      self.bias_h = None
-    self.built = True
-
-  def preprocess_input(self, inputs, training=None):
-    if self.implementation == 0:
-      input_shape = inputs.get_shape().as_list()
-      input_dim = input_shape[2]
-      timesteps = input_shape[1]
-
-      x_z = _time_distributed_dense(
-          inputs,
-          self.kernel_z,
-          self.bias_z,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      x_r = _time_distributed_dense(
-          inputs,
-          self.kernel_r,
-          self.bias_r,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      x_h = _time_distributed_dense(
-          inputs,
-          self.kernel_h,
-          self.bias_h,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      return K.concatenate([x_z, x_r, x_h], axis=2)
-    else:
-      return inputs
-
-  def get_constants(self, inputs, training=None):
-    constants = []
-    if self.implementation != 0 and 0 < self.dropout < 1:
-      input_shape = K.int_shape(inputs)
-      input_dim = input_shape[-1]
-      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
-      ones = K.tile(ones, (1, int(input_dim)))
-
-      def dropped_inputs():
-        return K.dropout(ones, self.dropout)
-
-      dp_mask = [
-          K.in_train_phase(dropped_inputs, ones, training=training)
-          for _ in range(3)
-      ]
-      constants.append(dp_mask)
-    else:
-      constants.append([K.cast_to_floatx(1.) for _ in range(3)])
-
-    if 0 < self.recurrent_dropout < 1:
-      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
-      ones = K.tile(ones, (1, self.units))
-
-      def dropped_inputs():  # pylint: disable=function-redefined
-        return K.dropout(ones, self.recurrent_dropout)
-
-      rec_dp_mask = [
-          K.in_train_phase(dropped_inputs, ones, training=training)
-          for _ in range(3)
-      ]
-      constants.append(rec_dp_mask)
-    else:
-      constants.append([K.cast_to_floatx(1.) for _ in range(3)])
-    return constants
-
-  def step(self, inputs, states):
-    h_tm1 = states[0]  # previous memory
-    dp_mask = states[1]  # dropout matrices for recurrent units
-    rec_dp_mask = states[2]
-
-    if self.implementation == 2:
-      matrix_x = K.dot(inputs * dp_mask[0], self.kernel)
-      if self.use_bias:
-        matrix_x = K.bias_add(matrix_x, self.bias)
-      matrix_inner = K.dot(h_tm1 * rec_dp_mask[0],
-                           self.recurrent_kernel[:, :2 * self.units])
-
-      x_z = matrix_x[:, :self.units]
-      x_r = matrix_x[:, self.units:2 * self.units]
-      recurrent_z = matrix_inner[:, :self.units]
-      recurrent_r = matrix_inner[:, self.units:2 * self.units]
-
-      z = self.recurrent_activation(x_z + recurrent_z)
-      r = self.recurrent_activation(x_r + recurrent_r)
-
-      x_h = matrix_x[:, 2 * self.units:]
-      recurrent_h = K.dot(r * h_tm1 * rec_dp_mask[0],
-                          self.recurrent_kernel[:, 2 * self.units:])
-      hh = self.activation(x_h + recurrent_h)
-    else:
-      if self.implementation == 0:
-        x_z = inputs[:, :self.units]
-        x_r = inputs[:, self.units:2 * self.units]
-        x_h = inputs[:, 2 * self.units:]
-      elif self.implementation == 1:
-        x_z = K.dot(inputs * dp_mask[0], self.kernel_z)
-        x_r = K.dot(inputs * dp_mask[1], self.kernel_r)
-        x_h = K.dot(inputs * dp_mask[2], self.kernel_h)
-        if self.use_bias:
-          x_z = K.bias_add(x_z, self.bias_z)
-          x_r = K.bias_add(x_r, self.bias_r)
-          x_h = K.bias_add(x_h, self.bias_h)
-      else:
-        raise ValueError('Unknown `implementation` mode.')
-      z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
-                                                self.recurrent_kernel_z))
-      r = self.recurrent_activation(x_r + K.dot(h_tm1 * rec_dp_mask[1],
-                                                self.recurrent_kernel_r))
-
-      hh = self.activation(x_h + K.dot(r * h_tm1 * rec_dp_mask[2],
-                                       self.recurrent_kernel_h))
-    h = z * h_tm1 + (1 - z) * hh
-    if 0 < self.dropout + self.recurrent_dropout:
-      h._uses_learning_phase = True
-    return h, [h]
-
-  def get_config(self):
-    config = {
-        'units': self.units,
-        'activation': activations.serialize(self.activation),
-        'recurrent_activation':
-            activations.serialize(self.recurrent_activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'recurrent_initializer':
-            initializers.serialize(self.recurrent_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'recurrent_regularizer':
-            regularizers.serialize(self.recurrent_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
-        'activity_regularizer':
-            regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'recurrent_constraint':
-            constraints.serialize(self.recurrent_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint),
-        'dropout': self.dropout,
-        'recurrent_dropout': self.recurrent_dropout
-    }
-    base_config = super(GRU, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-
-class LSTM(Recurrent):
-  """Long-Short Term Memory unit - Hochreiter 1997.
-
-  For a step-by-step description of the algorithm, see
-  [this tutorial](http://deeplearning.net/tutorial/lstm.html).
-
-  Arguments:
-      units: Positive integer, dimensionality of the output space.
-      activation: Activation function to use.
-          If you pass None, no activation is applied
-          (ie. "linear" activation: `a(x) = x`).
-      recurrent_activation: Activation function to use
-          for the recurrent step.
-      use_bias: Boolean, whether the layer uses a bias vector.
-      kernel_initializer: Initializer for the `kernel` weights matrix,
-          used for the linear transformation of the inputs..
-      recurrent_initializer: Initializer for the `recurrent_kernel`
-          weights matrix,
-          used for the linear transformation of the recurrent state..
-      bias_initializer: Initializer for the bias vector.
-      unit_forget_bias: Boolean.
-          If True, add 1 to the bias of the forget gate at initialization.
-          Setting it to true will also force `bias_initializer="zeros"`.
-          This is recommended in [Jozefowicz et
-            al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
-      kernel_regularizer: Regularizer function applied to
-          the `kernel` weights matrix.
-      recurrent_regularizer: Regularizer function applied to
-          the `recurrent_kernel` weights matrix.
-      bias_regularizer: Regularizer function applied to the bias vector.
-      activity_regularizer: Regularizer function applied to
-          the output of the layer (its "activation")..
-      kernel_constraint: Constraint function applied to
-          the `kernel` weights matrix.
-      recurrent_constraint: Constraint function applied to
-          the `recurrent_kernel` weights matrix.
-      bias_constraint: Constraint function applied to the bias vector.
-      dropout: Float between 0 and 1.
-          Fraction of the units to drop for
-          the linear transformation of the inputs.
-      recurrent_dropout: Float between 0 and 1.
-          Fraction of the units to drop for
-          the linear transformation of the recurrent state.
-
-  References:
-      - [Long short-term
-        memory]((http://www.bioinf.jku.at/publications/older/2604.pdf)
-        (original 1997 paper)
-      - [Supervised sequence labeling with recurrent neural
-        networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
-      - [A Theoretically Grounded Application of Dropout in Recurrent Neural
-        Networks](http://arxiv.org/abs/1512.05287)
-  """
-
-  def __init__(self,
-               units,
-               activation='tanh',
-               recurrent_activation='hard_sigmoid',
-               use_bias=True,
-               kernel_initializer='glorot_uniform',
-               recurrent_initializer='orthogonal',
-               bias_initializer='zeros',
-               unit_forget_bias=True,
-               kernel_regularizer=None,
-               recurrent_regularizer=None,
-               bias_regularizer=None,
-               activity_regularizer=None,
-               kernel_constraint=None,
-               recurrent_constraint=None,
-               bias_constraint=None,
-               dropout=0.,
-               recurrent_dropout=0.,
-               **kwargs):
-    super(LSTM, self).__init__(
-        activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
-    self.units = units
-    self.activation = activations.get(activation)
-    self.recurrent_activation = activations.get(recurrent_activation)
-    self.use_bias = use_bias
-
-    self.kernel_initializer = initializers.get(kernel_initializer)
-    self.recurrent_initializer = initializers.get(recurrent_initializer)
-    self.bias_initializer = initializers.get(bias_initializer)
-    self.unit_forget_bias = unit_forget_bias
-
-    self.kernel_regularizer = regularizers.get(kernel_regularizer)
-    self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
-    self.bias_regularizer = regularizers.get(bias_regularizer)
-
-    self.kernel_constraint = constraints.get(kernel_constraint)
-    self.recurrent_constraint = constraints.get(recurrent_constraint)
-    self.bias_constraint = constraints.get(bias_constraint)
-
-    self.dropout = min(1., max(0., dropout))
-    self.recurrent_dropout = min(1., max(0., recurrent_dropout))
-    self.state_spec = [
-        InputSpec(shape=(None, self.units)),
-        InputSpec(shape=(None, self.units))
-    ]
-
-  def build(self, input_shape):
-    if isinstance(input_shape, list):
-      input_shape = input_shape[0]
-    input_shape = tensor_shape.TensorShape(input_shape).as_list()
-    batch_size = input_shape[0] if self.stateful else None
-    self.input_dim = input_shape[2]
-    self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
-
-    self.states = [None, None]
-    if self.stateful:
-      self.reset_states()
-
-    self.kernel = self.add_weight(
-        shape=(self.input_dim, self.units * 4),
-        name='kernel',
-        initializer=self.kernel_initializer,
-        regularizer=self.kernel_regularizer,
-        constraint=self.kernel_constraint)
-    self.recurrent_kernel = self.add_weight(
-        shape=(self.units, self.units * 4),
-        name='recurrent_kernel',
-        initializer=self.recurrent_initializer,
-        regularizer=self.recurrent_regularizer,
-        constraint=self.recurrent_constraint)
-
-    if self.use_bias:
-      if self.unit_forget_bias:
-
-        def bias_initializer(_, *args, **kwargs):
-          return K.concatenate([
-              self.bias_initializer((self.units,), *args, **kwargs),
-              initializers.Ones()((self.units,), *args, **kwargs),
-              self.bias_initializer((self.units * 2,), *args, **kwargs),
-          ])
-      else:
-        bias_initializer = self.bias_initializer
-      self.bias = self.add_weight(
-          shape=(self.units * 4,),
-          name='bias',
-          initializer=bias_initializer,
-          regularizer=self.bias_regularizer,
-          constraint=self.bias_constraint)
-    else:
-      self.bias = None
-
-    self.kernel_i = self.kernel[:, :self.units]
-    self.kernel_f = self.kernel[:, self.units:self.units * 2]
-    self.kernel_c = self.kernel[:, self.units * 2:self.units * 3]
-    self.kernel_o = self.kernel[:, self.units * 3:]
-
-    self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units]
-    self.recurrent_kernel_f = self.recurrent_kernel[:, self.units:
-                                                    self.units * 2]
-    self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2:
-                                                    self.units * 3]
-    self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:]
-
-    if self.use_bias:
-      self.bias_i = self.bias[:self.units]
-      self.bias_f = self.bias[self.units:self.units * 2]
-      self.bias_c = self.bias[self.units * 2:self.units * 3]
-      self.bias_o = self.bias[self.units * 3:]
-    else:
-      self.bias_i = None
-      self.bias_f = None
-      self.bias_c = None
-      self.bias_o = None
-    self.built = True
-
-  def preprocess_input(self, inputs, training=None):
-    if self.implementation == 0:
-      input_shape = inputs.get_shape().as_list()
-      input_dim = input_shape[2]
-      timesteps = input_shape[1]
-
-      x_i = _time_distributed_dense(
-          inputs,
-          self.kernel_i,
-          self.bias_i,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      x_f = _time_distributed_dense(
-          inputs,
-          self.kernel_f,
-          self.bias_f,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      x_c = _time_distributed_dense(
-          inputs,
-          self.kernel_c,
-          self.bias_c,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      x_o = _time_distributed_dense(
-          inputs,
-          self.kernel_o,
-          self.bias_o,
-          self.dropout,
-          input_dim,
-          self.units,
-          timesteps,
-          training=training)
-      return K.concatenate([x_i, x_f, x_c, x_o], axis=2)
-    else:
-      return inputs
-
-  def get_constants(self, inputs, training=None):
-    constants = []
-    if self.implementation != 0 and 0 < self.dropout < 1:
-      input_shape = K.int_shape(inputs)
-      input_dim = input_shape[-1]
-      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
-      ones = K.tile(ones, (1, int(input_dim)))
-
-      def dropped_inputs():
-        return K.dropout(ones, self.dropout)
-
-      dp_mask = [
-          K.in_train_phase(dropped_inputs, ones, training=training)
-          for _ in range(4)
-      ]
-      constants.append(dp_mask)
-    else:
-      constants.append([K.cast_to_floatx(1.) for _ in range(4)])
-
-    if 0 < self.recurrent_dropout < 1:
-      ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
-      ones = K.tile(ones, (1, self.units))
-
-      def dropped_inputs():  # pylint: disable=function-redefined
-        return K.dropout(ones, self.recurrent_dropout)
-
-      rec_dp_mask = [
-          K.in_train_phase(dropped_inputs, ones, training=training)
-          for _ in range(4)
-      ]
-      constants.append(rec_dp_mask)
-    else:
-      constants.append([K.cast_to_floatx(1.) for _ in range(4)])
-    return constants
-
-  def step(self, inputs, states):
-    h_tm1 = states[0]
-    c_tm1 = states[1]
-    dp_mask = states[2]
-    rec_dp_mask = states[3]
-
-    if self.implementation == 2:
-      z = K.dot(inputs * dp_mask[0], self.kernel)
-      z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
-      if self.use_bias:
-        z = K.bias_add(z, self.bias)
-
-      z0 = z[:, :self.units]
-      z1 = z[:, self.units:2 * self.units]
-      z2 = z[:, 2 * self.units:3 * self.units]
-      z3 = z[:, 3 * self.units:]
-
-      i = self.recurrent_activation(z0)
-      f = self.recurrent_activation(z1)
-      c = f * c_tm1 + i * self.activation(z2)
-      o = self.recurrent_activation(z3)
-    else:
-      if self.implementation == 0:
-        x_i = inputs[:, :self.units]
-        x_f = inputs[:, self.units:2 * self.units]
-        x_c = inputs[:, 2 * self.units:3 * self.units]
-        x_o = inputs[:, 3 * self.units:]
-      elif self.implementation == 1:
-        x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
-        x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
-        x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
-        x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
-      else:
-        raise ValueError('Unknown `implementation` mode.')
-
-      i = self.recurrent_activation(x_i + K.dot(h_tm1 * rec_dp_mask[0],
-                                                self.recurrent_kernel_i))
-      f = self.recurrent_activation(x_f + K.dot(h_tm1 * rec_dp_mask[1],
-                                                self.recurrent_kernel_f))
-      c = f * c_tm1 + i * self.activation(
-          x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c))
-      o = self.recurrent_activation(x_o + K.dot(h_tm1 * rec_dp_mask[3],
-                                                self.recurrent_kernel_o))
-    h = o * self.activation(c)
-    if 0 < self.dropout + self.recurrent_dropout:
-      h._uses_learning_phase = True
-    return h, [h, c]
-
-  def get_config(self):
-    config = {
-        'units': self.units,
-        'activation': activations.serialize(self.activation),
-        'recurrent_activation':
-            activations.serialize(self.recurrent_activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'recurrent_initializer':
-            initializers.serialize(self.recurrent_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'unit_forget_bias': self.unit_forget_bias,
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'recurrent_regularizer':
-            regularizers.serialize(self.recurrent_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
-        'activity_regularizer':
-            regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'recurrent_constraint':
-            constraints.serialize(self.recurrent_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint),
-        'dropout': self.dropout,
-        'recurrent_dropout': self.recurrent_dropout
-    }
-    base_config = super(LSTM, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
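[Editor's note] The deletions above remove the body of the old monolithic
LSTM layer (a `Recurrent` subclass): its `preprocess_input`/`get_constants`/
`step` API and the `implementation` 0/1/2 branches are superseded in this
patch by an LSTMCell driven by the generic RNN layer. For readers tracing the
removed math, here is a minimal NumPy sketch of the fused-gate step that the
`implementation == 2` branch computed; dropout masks are omitted and plain
sigmoid stands in for the default hard_sigmoid recurrent activation, so the
names and details are illustrative rather than the new API:

    import numpy as np

    def sigmoid(x):
      return 1.0 / (1.0 + np.exp(-x))

    def lstm_step(x_t, h_tm1, c_tm1, kernel, recurrent_kernel, bias, units):
      # One fused matmul yields all four gate pre-activations at once.
      z = x_t @ kernel + h_tm1 @ recurrent_kernel + bias  # (batch, 4 * units)
      z0, z1, z2, z3 = (z[:, k * units:(k + 1) * units] for k in range(4))
      i = sigmoid(z0)                    # input gate
      f = sigmoid(z1)                    # forget gate
      c = f * c_tm1 + i * np.tanh(z2)    # new cell state
      o = sigmoid(z3)                    # output gate
      h = o * np.tanh(c)                 # new hidden state
      return h, c

Mode 1 computed the same gates with four per-gate matmuls, and mode 0
precomputed the input projections over all timesteps in `preprocess_input`.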
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
new file mode 100644
index 00000000000..b1f89a30bb3
--- /dev/null
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py
@@ -0,0 +1,378 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for recurrent layers functionality other than GRU, LSTM, SimpleRNN.
+
+See also: lstm_test.py, gru_test.py, simplernn_test.py.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras._impl import keras
+from tensorflow.python.platform import test
+
+
+class RNNTest(test.TestCase):
+
+  def test_minimal_rnn_cell_non_layer(self):
+
+    class MinimalRNNCell(object):
+
+      def __init__(self, units, input_dim):
+        self.units = units
+        self.state_size = units
+        self.kernel = keras.backend.variable(
+            np.random.random((input_dim, units)))
+
+      def call(self, inputs, states):
+        prev_output = states[0]
+        output = keras.backend.dot(inputs, self.kernel) + prev_output
+        return output, [output]
+
+    with self.test_session():
+      # Basic test case.
+      cell = MinimalRNNCell(32, 5)
+      x = keras.Input((None, 5))
+      layer = keras.layers.RNN(cell)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+      # Test stacking.
+      cells = [MinimalRNNCell(8, 5),
+               MinimalRNNCell(32, 8),
+               MinimalRNNCell(32, 32)]
+      layer = keras.layers.RNN(cells)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+  def test_minimal_rnn_cell_non_layer_multiple_states(self):
+
+    class MinimalRNNCell(object):
+
+      def __init__(self, units, input_dim):
+        self.units = units
+        self.state_size = (units, units)
+        self.kernel = keras.backend.variable(
+            np.random.random((input_dim, units)))
+
+      def call(self, inputs, states):
+        prev_output_1 = states[0]
+        prev_output_2 = states[1]
+        output = keras.backend.dot(inputs, self.kernel)
+        output += prev_output_1
+        output -= prev_output_2
+        return output, [output * 2, output * 3]
+
+    with self.test_session():
+      # Basic test case.
+      cell = MinimalRNNCell(32, 5)
+      x = keras.Input((None, 5))
+      layer = keras.layers.RNN(cell)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+      # Test stacking.
+      cells = [MinimalRNNCell(8, 5),
+               MinimalRNNCell(16, 8),
+               MinimalRNNCell(32, 16)]
+      layer = keras.layers.RNN(cells)
+      assert layer.cell.state_size == (32, 32, 16, 16, 8, 8)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+  def test_minimal_rnn_cell_layer(self):
+
+    class MinimalRNNCell(keras.layers.Layer):
+
+      def __init__(self, units, **kwargs):
+        self.units = units
+        self.state_size = units
+        super(MinimalRNNCell, self).__init__(**kwargs)
+
+      def build(self, input_shape):
+        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
+                                      initializer='uniform',
+                                      name='kernel')
+        self.recurrent_kernel = self.add_weight(
+            shape=(self.units, self.units),
+            initializer='uniform',
+            name='recurrent_kernel')
+        self.built = True
+
+      def call(self, inputs, states):
+        prev_output = states[0]
+        h = keras.backend.dot(inputs, self.kernel)
+        output = h + keras.backend.dot(prev_output, self.recurrent_kernel)
+        return output, [output]
+
+      def get_config(self):
+        config = {'units': self.units}
+        base_config = super(MinimalRNNCell, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    with self.test_session():
+      # Test basic case.
+      x = keras.Input((None, 5))
+      cell = MinimalRNNCell(32)
+      layer = keras.layers.RNN(cell)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+      # Test basic case serialization.
+      x_np = np.random.random((6, 5, 5))
+      y_np = model.predict(x_np)
+      weights = model.get_weights()
+      config = layer.get_config()
+      with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
+        layer = keras.layers.RNN.from_config(config)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.set_weights(weights)
+      y_np_2 = model.predict(x_np)
+      self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+      # Test stacking.
+      cells = [MinimalRNNCell(8),
+               MinimalRNNCell(12),
+               MinimalRNNCell(32)]
+      layer = keras.layers.RNN(cells)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))
+
+      # Test stacked RNN serialization.
+      x_np = np.random.random((6, 5, 5))
+      y_np = model.predict(x_np)
+      weights = model.get_weights()
+      config = layer.get_config()
+      with keras.utils.CustomObjectScope({'MinimalRNNCell': MinimalRNNCell}):
+        layer = keras.layers.RNN.from_config(config)
+      y = layer(x)
+      model = keras.models.Model(x, y)
+      model.set_weights(weights)
+      y_np_2 = model.predict(x_np)
+      self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+  def test_rnn_cell_with_constants_layer(self):
+
+    class RNNCellWithConstants(keras.layers.Layer):
+
+      def __init__(self, units, **kwargs):
+        self.units = units
+        self.state_size = units
+        super(RNNCellWithConstants, self).__init__(**kwargs)
+
+      def build(self, input_shape):
+        if not isinstance(input_shape, list):
+          raise TypeError('expected a list of shapes: [inputs, constants]')
+        [input_shape, constant_shape] = input_shape
+        # Will (and should) raise if more than one constant is passed.
+
+        self.input_kernel = self.add_weight(
+            shape=(input_shape[-1], self.units),
+            initializer='uniform',
+            name='kernel')
+        self.recurrent_kernel = self.add_weight(
+            shape=(self.units, self.units),
+            initializer='uniform',
+            name='recurrent_kernel')
+        self.constant_kernel = self.add_weight(
+            shape=(constant_shape[-1], self.units),
+            initializer='uniform',
+            name='constant_kernel')
+        self.built = True
+
+      def call(self, inputs, states, constants):
+        [prev_output] = states
+        [constant] = constants
+        h_input = keras.backend.dot(inputs, self.input_kernel)
+        h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+        h_const = keras.backend.dot(constant, self.constant_kernel)
+        output = h_input + h_state + h_const
+        return output, [output]
+
+      def get_config(self):
+        config = {'units': self.units}
+        base_config = super(RNNCellWithConstants, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    with self.test_session():
+      # Test basic case.
+      x = keras.Input((None, 5))
+      c = keras.Input((3,))
+      cell = RNNCellWithConstants(32)
+      layer = keras.layers.RNN(cell)
+      y = layer(x, constants=c)
+      model = keras.models.Model([x, c], y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(
+          [np.zeros((6, 5, 5)), np.zeros((6, 3))],
+          np.zeros((6, 32))
+      )
+
+    with self.test_session():
+      # Test basic case serialization.
+      x_np = np.random.random((6, 5, 5))
+      c_np = np.random.random((6, 3))
+      y_np = model.predict([x_np, c_np])
+      weights = model.get_weights()
+      config = layer.get_config()
+      custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+      with keras.utils.CustomObjectScope(custom_objects):
+        layer = keras.layers.RNN.from_config(config.copy())
+      y = layer(x, constants=c)
+      model = keras.models.Model([x, c], y)
+      model.set_weights(weights)
+      y_np_2 = model.predict([x_np, c_np])
+      self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+    with self.test_session():
+      # Test flat list inputs.
+      with keras.utils.CustomObjectScope(custom_objects):
+        layer = keras.layers.RNN.from_config(config.copy())
+      y = layer([x, c])
+      model = keras.models.Model([x, c], y)
+      model.set_weights(weights)
+      y_np_3 = model.predict([x_np, c_np])
+      self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
+  def test_rnn_cell_with_constants_layer_passing_initial_state(self):
+
+    class RNNCellWithConstants(keras.layers.Layer):
+
+      def __init__(self, units, **kwargs):
+        self.units = units
+        self.state_size = units
+        super(RNNCellWithConstants, self).__init__(**kwargs)
+
+      def build(self, input_shape):
+        if not isinstance(input_shape, list):
+          raise TypeError('expected a list of shapes: [inputs, constants]')
+        [input_shape, constant_shape] = input_shape
+        # Will (and should) raise if more than one constant is passed.
+
+        self.input_kernel = self.add_weight(
+            shape=(input_shape[-1], self.units),
+            initializer='uniform',
+            name='kernel')
+        self.recurrent_kernel = self.add_weight(
+            shape=(self.units, self.units),
+            initializer='uniform',
+            name='recurrent_kernel')
+        self.constant_kernel = self.add_weight(
+            shape=(constant_shape[-1], self.units),
+            initializer='uniform',
+            name='constant_kernel')
+        self.built = True
+
+      def call(self, inputs, states, constants):
+        [prev_output] = states
+        [constant] = constants
+        h_input = keras.backend.dot(inputs, self.input_kernel)
+        h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+        h_const = keras.backend.dot(constant, self.constant_kernel)
+        output = h_input + h_state + h_const
+        return output, [output]
+
+      def get_config(self):
+        config = {'units': self.units}
+        base_config = super(RNNCellWithConstants, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))
+
+    with self.test_session():
+      # Test basic case.
+      x = keras.Input((None, 5))
+      c = keras.Input((3,))
+      s = keras.Input((32,))
+      cell = RNNCellWithConstants(32)
+      layer = keras.layers.RNN(cell)
+      y = layer(x, initial_state=s, constants=c)
+      model = keras.models.Model([x, s, c], y)
+      model.compile(optimizer='rmsprop', loss='mse')
+      model.train_on_batch(
+          [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
+          np.zeros((6, 32))
+      )
+
+    with self.test_session():
+      # Test basic case serialization.
+      x_np = np.random.random((6, 5, 5))
+      s_np = np.random.random((6, 32))
+      c_np = np.random.random((6, 3))
+      y_np = model.predict([x_np, s_np, c_np])
+      weights = model.get_weights()
+      config = layer.get_config()
+      custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+      with keras.utils.CustomObjectScope(custom_objects):
+        layer = keras.layers.RNN.from_config(config.copy())
+      y = layer(x, initial_state=s, constants=c)
+      model = keras.models.Model([x, s, c], y)
+      model.set_weights(weights)
+      y_np_2 = model.predict([x_np, s_np, c_np])
+      self.assertAllClose(y_np, y_np_2, atol=1e-4)
+
+      # Verify that the initial state is actually used.
+      y_np_2_different_s = model.predict([x_np, s_np + 10., c_np])
+      with self.assertRaises(AssertionError):
+        self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4)
+
+    with self.test_session():
+      # Test flat list inputs.
+      with keras.utils.CustomObjectScope(custom_objects):
+        layer = keras.layers.RNN.from_config(config.copy())
+      y = layer([x, s, c])
+      model = keras.models.Model([x, s, c], y)
+      model.set_weights(weights)
+      y_np_3 = model.predict([x_np, s_np, c_np])
+      self.assertAllClose(y_np, y_np_3, atol=1e-4)
+
+  def test_stacked_rnn_attributes(self):
+    cells = [keras.layers.LSTMCell(3),
+             keras.layers.LSTMCell(3, kernel_regularizer='l2')]
+    layer = keras.layers.RNN(cells)
+    layer.build((None, None, 5))
+
+    # Test regularization losses
+    assert len(layer.losses) == 1
+
+    # Test weights
+    assert len(layer.trainable_weights) == 6
+    cells[0].trainable = False
+    assert len(layer.trainable_weights) == 3
+    assert len(layer.non_trainable_weights) == 3
+
+    # Test `get_losses_for`
+    x = keras.Input((None, 5))
+    y = keras.backend.sum(x)
+    cells[0].add_loss(y, inputs=x)
+    assert layer.get_losses_for(x) == [y]
+
+
+if __name__ == '__main__':
+  test.main()
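[Editor's note] The tests above pin down the cell contract that the new RNN
wrapper expects: any object with a `state_size` attribute and a
`call(inputs, states)` method returning `(output, new_states)` qualifies,
whether or not it is a Layer, and a list of such cells can be stacked. A
usage sketch against the APIs this patch introduces, mirroring the test setup
above:

    import numpy as np
    from tensorflow.python.keras._impl import keras

    cell = keras.layers.LSTMCell(32)  # built-in cells satisfy the contract
    layer = keras.layers.RNN(cell)    # the generic driver loops over time
    x = keras.Input((None, 5))        # (timesteps, features)
    y = layer(x)                      # final output, shape (batch, 32)
    model = keras.models.Model(x, y)
    model.compile(optimizer='rmsprop', loss='mse')
    model.train_on_batch(np.zeros((6, 5, 5)), np.zeros((6, 32)))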
diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
index 9833485236b..7edebdacd07 100644
--- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
+++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py
@@ -156,8 +156,10 @@ class SimpleRNNLayerTest(test.TestCase):
           activity_regularizer='l1')
       layer.build((None, None, 2))
       self.assertEqual(len(layer.losses), 3)
-      layer(keras.backend.variable(np.ones((2, 3, 2))))
-      self.assertEqual(len(layer.losses), 4)
+
+      x = keras.backend.variable(np.ones((2, 3, 2)))
+      layer(x)
+      self.assertEqual(len(layer.get_losses_for(x)), 1)
 
   def test_constraints_SimpleRNN(self):
     embedding_dim = 4
@@ -175,9 +177,9 @@ class SimpleRNNLayerTest(test.TestCase):
           recurrent_constraint=r_constraint,
           bias_constraint=b_constraint)
       layer.build((None, None, embedding_dim))
-      self.assertEqual(layer.kernel.constraint, k_constraint)
-      self.assertEqual(layer.recurrent_kernel.constraint, r_constraint)
-      self.assertEqual(layer.bias.constraint, b_constraint)
+      self.assertEqual(layer.cell.kernel.constraint, k_constraint)
+      self.assertEqual(layer.cell.recurrent_kernel.constraint, r_constraint)
+      self.assertEqual(layer.cell.bias.constraint, b_constraint)
 
   def test_with_masking_layer_SimpleRNN(self):
     layer_class = keras.layers.SimpleRNN
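[Editor's note] The two test updates above track behavioral changes from the
cell refactor: weights and their constraints now hang off `layer.cell` rather
than the layer itself, and an activity-regularization loss is recorded
against the tensor that produced it, so it is fetched with
`layer.get_losses_for(x)` instead of being counted in the flat `layer.losses`
list. A sketch of the assumed split between unconditional and conditional
losses:

    unconditional = layer.get_losses_for(None)  # weight regularizers
    conditional = layer.get_losses_for(x)       # activity term(s) for x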
diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py
index acf0a5e1799..b94bf8f0f67 100644
--- a/tensorflow/python/keras/layers/__init__.py
+++ b/tensorflow/python/keras/layers/__init__.py
@@ -134,6 +134,11 @@ from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool2D
 from tensorflow.python.keras._impl.keras.layers.pooling import GlobalMaxPool3D
 
 # Recurrent layers.
+from tensorflow.python.keras._impl.keras.layers.recurrent import RNN
+from tensorflow.python.keras._impl.keras.layers.recurrent import StackedRNNCells
+from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNNCell
+from tensorflow.python.keras._impl.keras.layers.recurrent import GRUCell
+from tensorflow.python.keras._impl.keras.layers.recurrent import LSTMCell
 from tensorflow.python.keras._impl.keras.layers.recurrent import SimpleRNN
 from tensorflow.python.keras._impl.keras.layers.recurrent import GRU
 from tensorflow.python.keras._impl.keras.layers.recurrent import LSTM
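[Editor's note] These exports make the cell-level API public under
tf.keras.layers. A stacking sketch mirroring test_stacked_rnn_attributes
above -- a list of cells handed to RNN is run as one composite cell
(presumably via the StackedRNNCells export), with the last cell's units
determining the output size:

    cells = [keras.layers.LSTMCell(8), keras.layers.LSTMCell(32)]
    layer = keras.layers.RNN(cells)
    x = keras.Input((None, 5))
    y = layer(x)  # output of the last (32-unit) cell: shape (batch, 32)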
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 8c8d774b754..c71e8382e91 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -642,7 +642,7 @@ class Layer(object):
             for output in output_list:
               with ops.name_scope('ActivityRegularizer'):
                 activity_regularization = self._activity_regularizer(output)
-              self.add_loss(activity_regularization)
+              self.add_loss(activity_regularization, inputs=inputs)
 
         if not in_deferred_mode:
           # TODO(fchollet): consider how masking will work with deferred mode.
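[Editor's note] This one-line change is what the Keras-side tests rely on:
passing `inputs=inputs` to add_loss records the activity-regularization term
as conditional on those inputs, so `get_losses_for(inputs)` returns exactly
the activity losses a given tensor produced -- the behavior asserted by the
new testActivityRegularizer below. A hedged sketch of the distinction, with
illustrative names:

    layer.add_loss(weight_decay)             # unconditional loss
    layer.add_loss(activity_term, inputs=x)  # conditional on x
    assert activity_term in layer.get_losses_for(x)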
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 71eff2f9657..7ddfe37827d 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -574,6 +574,13 @@ class BaseLayerTest(test.TestCase):
       self.assertEqual(3, result['label'].numpy())
       self.assertEqual(4.0, result['logits'].numpy())
 
+  def testActivityRegularizer(self):
+    regularizer = math_ops.reduce_sum
+    layer = base_layers.Layer(activity_regularizer=regularizer)
+    x = array_ops.placeholder('int32')
+    layer.apply(x)
+    self.assertEqual(len(layer.get_losses_for(x)), 1)
+
 
 class NetworkTest(test.TestCase):
 
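[Editor's note] The remaining hunks update the API "golden" files. These
.pbtxt files pin TensorFlow's public Python surface, and an API-compatibility
test compares the live API against them; consequently the new symbols (RNN,
StackedRNNCells, SimpleRNNCell, GRUCell, LSTMCell), the Recurrent -> RNN base
class swap, the changed __init__ and get_losses_for signatures, and the
removal of get_constants/preprocess_input/step each require a matching golden
update below. The goldens are regenerated by running the API compatibility
test with its update-goldens option; the exact invocation depends on the
tooling of the day.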
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
new file mode 100644
index 00000000000..763184899ca
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -0,0 +1,179 @@
+path: "tensorflow.keras.layers.GRUCell"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.GRUCell\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "scope_name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
index 92373992548..889f2cbc234 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -1,14 +1,34 @@
 path: "tensorflow.keras.layers.GRU"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.GRU\'>"
-  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
   is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "activation"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "bias_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dropout"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "dtype"
     mtype: "<type \'property\'>"
@@ -17,6 +37,10 @@ tf_class {
     name: "graph"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "implementation"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "inbound_nodes"
     mtype: "<type \'property\'>"
@@ -33,6 +57,18 @@ tf_class {
     name: "input_shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "kernel_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_regularizer"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "losses"
     mtype: "<type \'property\'>"
@@ -65,10 +101,34 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "recurrent_activation"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_dropout"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_regularizer"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "scope_name"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "states"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"
@@ -77,10 +137,18 @@ tf_class {
     name: "trainable_weights"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "units"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "updates"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "use_bias"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "variables"
     mtype: "<type \'property\'>"
@@ -91,7 +159,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "add_loss"
@@ -137,10 +205,6 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "get_constants"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -159,7 +223,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "get_output_at"
@@ -181,10 +245,6 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "preprocess_input"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "reset_states"
     argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -193,8 +253,4 @@ tf_class {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "step"
-    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
-  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
new file mode 100644
index 00000000000..4ce7c34f6c7
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -0,0 +1,179 @@
+path: "tensorflow.keras.layers.LSTMCell"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.LSTMCell\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "scope_name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 20935e2f99a..e1a1d0d58ec 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -1,14 +1,34 @@
 path: "tensorflow.keras.layers.LSTM"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.LSTM\'>"
-  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
   is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "activation"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "bias_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dropout"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "dtype"
     mtype: "<type \'property\'>"
@@ -17,6 +37,10 @@ tf_class {
     name: "graph"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "implementation"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "inbound_nodes"
     mtype: "<type \'property\'>"
@@ -33,6 +57,18 @@ tf_class {
     name: "input_shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "kernel_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_regularizer"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "losses"
     mtype: "<type \'property\'>"
@@ -65,10 +101,34 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "recurrent_activation"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_dropout"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_regularizer"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "scope_name"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "states"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"
@@ -77,10 +137,22 @@ tf_class {
     name: "trainable_weights"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "unit_forget_bias"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "units"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "updates"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "use_bias"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "variables"
     mtype: "<type \'property\'>"
@@ -91,7 +163,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+    argspec: "args=[\'self\', \'units\', \'activation\', \'recurrent_activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'unit_forget_bias\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'implementation\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'hard_sigmoid\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'1\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "add_loss"
@@ -137,10 +209,6 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "get_constants"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -159,7 +227,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "get_output_at"
@@ -181,10 +249,6 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "preprocess_input"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "reset_states"
     argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -193,8 +257,4 @@ tf_class {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "step"
-    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
-  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
new file mode 100644
index 00000000000..c7c9b10f22d
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -0,0 +1,191 @@
+path: "tensorflow.keras.layers.RNN"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "scope_name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "states"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'activity_regularizer\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'initial_state\', \'constants\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_initial_state"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reset_states"
+    argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
new file mode 100644
index 00000000000..10c7f8867cb
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -0,0 +1,179 @@
+path: "tensorflow.keras.layers.SimpleRNNCell"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.SimpleRNNCell\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "scope_name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'states\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index f4148fcc230..588df21088f 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -1,14 +1,34 @@
 path: "tensorflow.keras.layers.SimpleRNN"
 tf_class {
   is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.SimpleRNN\'>"
-  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.Recurrent\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.RNN\'>"
   is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
   is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
   is_instance: "<type \'object\'>"
+  member {
+    name: "activation"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "activity_regularizer"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "bias_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "bias_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dropout"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "dtype"
     mtype: "<type \'property\'>"
@@ -33,6 +53,18 @@ tf_class {
     name: "input_shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "kernel_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "kernel_regularizer"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "losses"
     mtype: "<type \'property\'>"
@@ -65,10 +97,30 @@ tf_class {
     name: "output_shape"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "recurrent_constraint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_dropout"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_initializer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "recurrent_regularizer"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "scope_name"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "states"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "trainable_variables"
     mtype: "<type \'property\'>"
@@ -77,10 +129,18 @@ tf_class {
     name: "trainable_weights"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "units"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "updates"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "use_bias"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "variables"
     mtype: "<type \'property\'>"
@@ -91,7 +151,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\'], "
+    argspec: "args=[\'self\', \'units\', \'activation\', \'use_bias\', \'kernel_initializer\', \'recurrent_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'recurrent_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'recurrent_constraint\', \'bias_constraint\', \'dropout\', \'recurrent_dropout\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'tanh\', \'True\', \'glorot_uniform\', \'orthogonal\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'0.0\', \'0.0\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
   }
   member_method {
     name: "add_loss"
@@ -137,10 +197,6 @@ tf_class {
     name: "get_config"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "get_constants"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "get_initial_state"
     argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
@@ -159,7 +215,7 @@ tf_class {
   }
   member_method {
     name: "get_losses_for"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "get_output_at"
@@ -181,10 +237,6 @@ tf_class {
     name: "get_weights"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "preprocess_input"
-    argspec: "args=[\'self\', \'inputs\', \'training\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "reset_states"
     argspec: "args=[\'self\', \'states\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -193,8 +245,4 @@ tf_class {
     name: "set_weights"
     argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
   }
-  member_method {
-    name: "step"
-    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=None, defaults=None"
-  }
 }
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
new file mode 100644
index 00000000000..5779e413422
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -0,0 +1,183 @@
+path: "tensorflow.keras.layers.StackedRNNCells"
+tf_class {
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.layers.recurrent.StackedRNNCells\'>"
+  is_instance: "<class \'tensorflow.python.keras._impl.keras.engine.topology.Layer\'>"
+  is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "activity_regularizer"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "dtype"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "inbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "input_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "losses"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "non_trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "outbound_nodes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_mask"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "output_shape"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "scope_name"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "state_size"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "trainable_weights"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "updates"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "variables"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "weights"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'cells\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "add_loss"
+    argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_update"
+    argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "add_variable"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "add_weight"
+    argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
+  }
+  member_method {
+    name: "apply"
+    argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "build"
+    argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call"
+    argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "compute_mask"
+    argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "count_params"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "from_config"
+    argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_config"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_input_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_losses_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "get_output_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_mask_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_output_shape_at"
+    argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_updates_for"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_weights"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "set_weights"
+    argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
index 8466c3e0390..fe336c4be5a 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.pbtxt
@@ -140,6 +140,10 @@ tf_module {
     name: "GRU"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "GRUCell"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "GaussianDropout"
     mtype: "<type \'type\'>"
@@ -208,6 +212,10 @@ tf_module {
     name: "LSTM"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "LSTMCell"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "Lambda"
     mtype: "<type \'type\'>"
@@ -272,6 +280,10 @@ tf_module {
     name: "Permute"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "RNN"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "RepeatVector"
     mtype: "<type \'type\'>"
@@ -292,6 +304,10 @@ tf_module {
     name: "SimpleRNN"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "SimpleRNNCell"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "SpatialDropout1D"
     mtype: "<type \'type\'>"
@@ -304,6 +320,10 @@ tf_module {
     name: "SpatialDropout3D"
     mtype: "<type \'type\'>"
   }
+  member {
+    name: "StackedRNNCells"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "ThresholdedReLU"
     mtype: "<type \'type\'>"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index f1c207f9b68..8d4e4c23dc3 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -98,7 +98,8 @@ do_pylint() {
 "^tensorflow/contrib/eager/python/evaluator\.py.*\[E0202.*method-hidden "\
 "^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
 "^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
-"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable"
+"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
+"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition"
 
   echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
 

From 8bb665ae1c8f2aedd479b5bfe2403ac54e37319e Mon Sep 17 00:00:00 2001
From: Jianwei Xie <xiejw@google.com>
Date: Wed, 8 Nov 2017 15:19:12 -0800
Subject: [PATCH 052/115] Improve usability of TPUEstimator.

1) Log how many batches to enqueue. The old message was confusing.
2) If the input pipeline has a queue runner, log a warning (legacy mode) or error out (new mode).
3) If the input pipeline has summaries, log a warning (legacy mode) or error out (new mode).

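For intuition, a minimal Python sketch of the warn-vs-raise gating this patch
adds (a hypothetical standalone function, not the estimator code itself;
`wrap_input_fn_into_while_loop` stands in for the module-level
_WRAP_INPUT_FN_INTO_WHILE_LOOP flag, and the two boolean arguments stand in
for the graph-collection checks):

    import logging

    def validate_input_pipeline(has_queue_runners, has_summaries,
                                wrap_input_fn_into_while_loop):
      """Raises (new mode) or warns (legacy mode) on unsupported input ops."""
      err_msg = None
      if has_queue_runners:
        err_msg = ('Input pipeline contains one or more QueueRunners; '
                   'convert it to use tf.data instead.')
      elif has_summaries:
        # QueueRunners register summary ops themselves, hence the elif: this
        # branch only fires for pipelines that add their own summaries.
        err_msg = 'Input pipeline contains tf.summary operations.'
      if err_msg:
        if wrap_input_fn_into_while_loop:
          raise RuntimeError(err_msg)  # new mode: fail fast
        logging.warning(err_msg)       # legacy mode: do not break user code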
PiperOrigin-RevId: 175073856
---
 .../contrib/tpu/python/tpu/tpu_estimator.py   | 35 +++++++++++++++++--
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 5a3b8314291..16d712af9e2 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -535,13 +535,15 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
           session, self._dequeue_ops)
 
   def before_run(self, run_context):
-    logging.info('Enqueue next batch of data to infeed.')
-
     iterations = run_context.session.run(self._iterations_per_loop_var)
+
+    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
+
     self._infeed_thd_controller.send_next_batch_signal(iterations)
     if self._dequeue_ops is not None:
       # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop.
-      logging.info('Dequeue next batch of data from outfeed.')
+      logging.info(
+          'Dequeue next (%d) batch(es) of data from outfeed.', iterations)
       self._outfeed_thd_controller.send_next_batch_signal(iterations)
 
   def end(self, session):
@@ -842,6 +844,8 @@ class _InputPipeline(object):
     # structure is recorded.
     enqueue_ops = self._invoke_input_fn_and_record_structure()
 
+    self._validate_input_pipeline()
+
     def dequeue_fn():
       """dequeue_fn is used by TPU to retrieve the tensors."""
       values = self._infeed_queue.generate_dequeue_op()
@@ -920,6 +924,31 @@ class _InputPipeline(object):
       else:
         return enqueue_fn()
 
+  def _validate_input_pipeline(self):
+    # Perform some sanity checks to surface user-friendly information. We
+    # should error out to give users a better error message, but if
+    # _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot
+    # break user code, so we log a warning instead.
+    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
+      err_msg = ('Input pipeline contains one or more QueueRunners. '
+                 'These are not supported via TPUEstimator. You must convert '
+                 'your input pipeline to use `tf.data` instead (see '
+                 'https://www.tensorflow.org/programmers_guide/datasets for '
+                 'instructions).')
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        raise RuntimeError(err_msg)
+      else:
+        logging.warn(err_msg)
+    elif ops.get_default_graph().get_collection(ops.GraphKeys.SUMMARIES):
+      # A QueueRunner registers summary ops by default, so we use elif here
+      # to run this check only for Dataset input pipelines.
+      err_msg = ('Input pipeline contains `tf.summary` operations. '
+                 'These are not currently supported.')
+      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+        raise RuntimeError(err_msg)
+      else:
+        logging.warn(err_msg)
+
 
 class _ModelFnWrapper(object):
   """A `model_fn` wrapper.

From 12d6b450b2be345b3848efd8d623b1507a2c630f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 15:24:01 -0800
Subject: [PATCH 053/115] Hlo parser: support window and convolution.

Also, to make the text format easier to write and unambiguous:
- Print "window={}" around the window attribute; rename the "window" sub attribute to "size";
- Print the dim_lables in logical order, instead of physical order.
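As a rough illustration of the new window syntax, a hedged Python sketch (not
the C++ printer itself; the dict layout and helper name are invented):

    # Hypothetical sketch: render a window-proto-like dict in the new syntax,
    # where braces wrap the attribute and 'size' replaces the old 'window'.
    def window_to_string(window):
      parts = ['size=' + 'x'.join(str(d) for d in window['size'])]
      if 'stride' in window:
        parts.append('stride=' + 'x'.join(str(d) for d in window['stride']))
      if 'pad' in window:
        parts.append('pad=' + 'x'.join('%d_%d' % (lo, hi)
                                       for lo, hi in window['pad']))
      return 'window={' + ' '.join(parts) + '}'

    print(window_to_string({'size': [2, 2], 'stride': [1, 1],
                            'pad': [(0, 1), (0, 1)]}))
    # -> window={size=2x2 stride=1x1 pad=0_1x0_1}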

PiperOrigin-RevId: 175074526
---
 .../compiler/xla/service/hlo_instruction.cc   |  10 +-
 .../compiler/xla/tools/parser/README.md       |  16 +-
 .../compiler/xla/tools/parser/hlo_lexer.cc    |  65 +-
 .../compiler/xla/tools/parser/hlo_lexer.h     |   6 +-
 .../compiler/xla/tools/parser/hlo_parser.cc   | 589 ++++++++++++++----
 .../xla/tools/parser/hlo_parser_test.cc       | 120 ++++
 .../compiler/xla/tools/parser/hlo_token.h     |   3 +
 tensorflow/compiler/xla/window_util.cc        |  26 +-
 8 files changed, 690 insertions(+), 145 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5107ac782d7..ee98c3fabc5 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1850,7 +1850,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
     extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
   }
   if (window_ != nullptr) {
-    extra.push_back(window_util::ToString(*window_));
+    extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
   }
   if (padding_config_ != nullptr) {
     extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
@@ -2856,13 +2856,7 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
   const auto append_dims = [&](const std::vector<string>& dims,
                                const Shape& shape) {
     CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
-    for (int64 logical = 0; logical < dims.size(); ++logical) {
-      int64 physical = logical;
-      if (!shape.layout().minor_to_major().empty()) {
-        physical = LayoutUtil::Major(shape.layout(), logical);
-      }
-      result += dims[physical];
-    }
+    StrAppend(&result, Join(dims, ""));
   };
 
   // lhs_dims[i] is the symbol of the logical dimension i for the lhs
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md
index 2c864d77a20..986041caf61 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/tools/parser/README.md
@@ -43,14 +43,22 @@ operand
   : shape name
   ;
 
-extra_attributes
+attributes
   : /*empty*/
-  | ',' extra_attribute
-  | ',' extra_attribute extra_attributes
+  | ',' attribute
+  | ',' attribute attributes
   ;
-extra_attribute
+attribute
   : attribute_name attribute_value
   ;
+attribute_value
+  : kInt
+  | kName
+  | [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} /*dim_labels_pattern*/
+  | [0-9]+(x[0-9]+)+                     /*dxd_pattern*/
+  | [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*       /*window_pad_pattern*/
+  | '{' sub_attributes '}'
+  ;
 
 param_list
   : '(' param_list1 ')'
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
index d104ff34601..f70386411cf 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
@@ -122,7 +122,7 @@ TokKind HloLexer::LexToken() {
           current_ptr_++;
           return TokKind::kArrow;
         }
-        return LexDigitOrNegative();
+        return LexNumberOrPattern();
       case '=':
         return TokKind::kEqual;
       case ',':
@@ -149,12 +149,15 @@ TokKind HloLexer::LexToken() {
   }
 }
 
-// Lex a shape, name, keyword, or opcode.
+// Lex a shape, name, keyword, opcode, attribute name, or the dim labels
+// pattern.
+//
 // shape    ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
 // name     ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
 // keyword  ::= HloModule, ENTRY, ...
 // opcode   ::= add, greater-than, ...
 // attribute_name ::= condition, body, dimensions, ...
+// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
 TokKind HloLexer::LexIdentifier() {
   {
     auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
@@ -220,6 +223,16 @@ TokKind HloLexer::LexIdentifier() {
     return TokKind::kOpcode;
   }
 
+  {
+    auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
+    static LazyRE2 dim_labels_pattern = {
+        R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
+    if (RE2::Consume(&consumable, *dim_labels_pattern)) {
+      current_ptr_ = consumable.begin();
+      str_val_.assign(token_start_, current_ptr_);
+      return TokKind::kDimLabels;
+    }
+  }
   current_ptr_ = token_start_ + 1;
   return TokKind::kError;
 }
@@ -240,15 +253,20 @@ TokKind HloLexer::LexPercent() {
   return TokKind::kError;
 }
 
-// Lex integer and floating-point values, and -inf.
-// int             [-]?[0-9]+
-// fp with exp     [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
-// fp without exp  [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
-// negative inf    -inf
-TokKind HloLexer::LexDigitOrNegative() {
+// Lex integer and floating-point values, -inf, and patterns for dim labels,
+// dxd (e.g. 1x2x3), and window pad.
+//
+// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
+// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
+// dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
+// dxd_pattern ::= [0-9]+(x[0-9]+)+
+// window_pad_pattern ::= [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
+// int ::=  [-]?[0-9]+
+// negative inf ::= '-inf'
+TokKind HloLexer::LexNumberOrPattern() {
   auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
   static LazyRE2 float_pattern = {
-      R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
+      R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
   if (RE2::Consume(&consumable, *float_pattern)) {
     current_ptr_ = consumable.begin();
     tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
@@ -256,6 +274,29 @@ TokKind HloLexer::LexDigitOrNegative() {
     return TokKind::kDecimal;
   }
 
+  static LazyRE2 dim_labels_pattern = {
+      R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
+  static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
+  static LazyRE2 pad_pattern = {R"([0-9]+_[0-9]+(x[0-9]+_[0-9]+)*)"};
+
+  if (RE2::Consume(&consumable, *dim_labels_pattern)) {
+    current_ptr_ = consumable.begin();
+    str_val_.assign(token_start_, current_ptr_);
+    return TokKind::kDimLabels;
+  }
+
+  if (RE2::Consume(&consumable, *dxd_pattern)) {
+    current_ptr_ = consumable.begin();
+    str_val_.assign(token_start_, current_ptr_);
+    return TokKind::kDxD;
+  }
+
+  if (RE2::Consume(&consumable, *pad_pattern)) {
+    current_ptr_ = consumable.begin();
+    str_val_.assign(token_start_, current_ptr_);
+    return TokKind::kWindowPad;
+  }
+
   static LazyRE2 int_pattern = {R"([-]?\d+)"};
   if (RE2::Consume(&consumable, *int_pattern)) {
     current_ptr_ = consumable.begin();
@@ -350,6 +391,12 @@ string TokKindToString(TokKind kind) {
       return "kName";
     case TokKind::kAttributeName:
       return "kAttributeName";
+    case TokKind::kDimLabels:
+      return "kDimLabels";
+    case TokKind::kDxD:
+      return "kDxD";
+    case TokKind::kWindowPad:
+      return "kWindowPad";
     case TokKind::kShape:
       return "kShape";
     case TokKind::kOpcode:
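The three new token classes are plain regular expressions. A quick sanity
sketch (Python `re`, using the same pattern literals the lexer compiles with
RE2):

    import re

    DIM_LABELS = re.compile(r'[0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}')
    DXD = re.compile(r'[0-9]+(x[0-9]+)+')
    WINDOW_PAD = re.compile(r'[0-9]+_[0-9]+(x[0-9]+_[0-9]+)*')

    assert DIM_LABELS.fullmatch('0bf_0io->0bf')  # lexed as kDimLabels
    assert DXD.fullmatch('1x2x3')                # lexed as kDxD
    assert WINDOW_PAD.fullmatch('0_0x3_3')       # lexed as kWindowPad
    assert not DXD.fullmatch('5')                # a bare integer stays kInt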
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
index 3b9efcb92d0..74e6829180a 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
@@ -37,11 +37,15 @@ class HloLexer {
   }
 
   TokKind Lex() { return current_kind_ = LexToken(); }
+
   TokKind GetKind() const { return current_kind_; }
   string GetStrVal() const {
     switch (GetKind()) {
       case TokKind::kName:
       case TokKind::kAttributeName:
+      case TokKind::kDimLabels:
+      case TokKind::kDxD:
+      case TokKind::kWindowPad:
         return str_val_;
       default:
         LOG(FATAL) << "This token does not have string value";
@@ -92,7 +96,7 @@ class HloLexer {
   TokKind LexPercent();
   TokKind LexShape();
   TokKind LexConstant();
-  TokKind LexDigitOrNegative();
+  TokKind LexNumberOrPattern();
   TokKind LexComment();
 
   const tensorflow::StringPiece buf_;
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 6c2e37e3b5c..f1e987cb15c 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -28,6 +28,9 @@ namespace tools {
 namespace {
 
 using tensorflow::StringPiece;
+using tensorflow::gtl::optional;
+using tensorflow::str_util::Split;
+using tensorflow::str_util::SplitAndParseAsInts;
 using tensorflow::strings::Printf;
 using tensorflow::strings::StrAppend;
 using tensorflow::strings::StrCat;
@@ -57,8 +60,6 @@ class HloParser {
   bool ParseInstructionList(HloComputation::Builder* builder,
                             string* root_name);
   bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
-  bool ParseSharding(HloInstruction* instruction);
-  bool ParseControlPredecessors(HloInstruction* instruction);
   bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
   bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
   bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
@@ -78,10 +79,55 @@ class HloParser {
   bool ParseOperands(std::vector<HloInstruction*>* operands,
                      const int expected_size);
 
-  template <typename T>
-  bool ParseExtraAttribute(T* value, const string& expected_attribute);
-  template <typename T>
-  bool ParseAttributeValue(T* value);
+  // Types of attributes.
+  enum class AttrTy {
+    kInt64,
+    kHloComputation,
+    kWindow,
+    kConvolutionDimensionNumbers,
+    kSharding,
+    kInstructionList,
+  };
+
+  struct AttrConfig {
+    bool required;     // whether it's required or optional
+    AttrTy attr_type;  // what type it is
+    void* result;      // where to store the parsed result.
+  };
+
+  // Parses attributes, given the names and configs of the attributes. Each
+  // parsed result is passed back through the result pointer in the
+  // corresponding AttrConfig. Note that the result pointer must point to an
+  // optional<T>-typed variable that outlives this function. Returns false on
+  // error. You should not use any of the results if this function failed.
+  //
+  // Example usage:
+  //
+  //  std::unordered_map<string, AttrConfig> attrs;
+  //  optional<int64> foo;
+  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
+  //  optional<Window> bar;
+  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
+  //  if (!ParseAttributes(attrs)) {
+  //    return false; // Do not use 'foo' or 'bar' if this failed.
+  //  }
+  //  // Do something with 'bar'.
+  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
+  //
+  bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
+
+  // Parses a name and finds the corresponding hlo computation.
+  bool ParseComputationName(HloComputation** value);
+  // Parses a list of names and finds the corresponding hlo instructions.
+  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
+  bool ParseWindow(Window* window);
+  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
+  bool ParseSharding(OpSharding* sharding);
+
+  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
+  bool ParseDxD(const string& name, std::vector<int64>* result);
+  // Parses the window's pad sub-attribute, e.g., pad=0_0x3_3.
+  bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
 
   bool ParseParamList();
   bool ParseName(string* result);
@@ -214,7 +260,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
                     "expects '}' at the end of instruction list.");
 }
 
-// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)*
+// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
 bool HloParser::ParseInstruction(HloComputation::Builder* builder,
                                  string* root_name) {
   string name;
@@ -230,6 +276,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
   if (is_root) {
     *root_name = name;
   }
+
+  // Add optional attributes.
+  std::unordered_map<string, AttrConfig> attrs;
+  optional<OpSharding> sharding;
+  attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
+  optional<std::vector<HloInstruction*>> predecessors;
+  attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
+                                   &predecessors};
+
   HloInstruction* instruction;
   switch (opcode) {
     case HloOpcode::kParameter: {
@@ -237,7 +292,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       if (!ParseToken(TokKind::kLparen,
                       "expects '(' before parameter number") ||
           !ParseInt64(&parameter_number) ||
-          !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) {
+          !ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
@@ -249,7 +305,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       if (!ParseToken(TokKind::kLparen,
                       "expects '(' before constant literal") ||
           !ParseLiteral(&literal, shape) ||
-          !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) {
+          !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
@@ -275,7 +332,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
     case HloOpcode::kSin:
     case HloOpcode::kSort:
     case HloOpcode::kTanh: {
-      if (!ParseOperands(&operands, /*expected_size=*/1)) {
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
@@ -305,7 +363,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
     case HloOpcode::kShiftLeft:
     case HloOpcode::kShiftRightArithmetic:
     case HloOpcode::kShiftRightLogical: {
-      if (!ParseOperands(&operands, /*expected_size=*/2)) {
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(HloInstruction::CreateBinary(
@@ -315,7 +374,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
     // Ternary ops.
     case HloOpcode::kClamp:
     case HloOpcode::kSelect: {
-      if (!ParseOperands(&operands, /*expected_size=*/3)) {
+      if (!ParseOperands(&operands, /*expected_size=*/3) ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(HloInstruction::CreateTernary(
@@ -324,7 +384,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
     }
     // Other supported ops.
     case HloOpcode::kConvert: {
-      if (!ParseOperands(&operands, /*expected_size=*/1)) {
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
@@ -332,7 +393,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       break;
     }
     case HloOpcode::kCrossReplicaSum: {
-      if (!ParseOperands(&operands, /*expected_size=*/1)) {
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
@@ -340,7 +402,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       break;
     }
     case HloOpcode::kReshape: {
-      if (!ParseOperands(&operands, /*expected_size=*/1)) {
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
@@ -348,7 +411,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       break;
     }
     case HloOpcode::kTuple: {
-      if (!ParseOperands(&operands)) {
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
       instruction =
@@ -356,70 +419,99 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       break;
     }
     case HloOpcode::kWhile: {
-      HloComputation* condition;
-      HloComputation* body;
+      optional<HloComputation*> condition;
+      optional<HloComputation*> body;
+      attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
+                            &condition};
+      attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
       if (!ParseOperands(&operands, /*expected_size=*/1) ||
-          !ParseExtraAttribute(&condition,
-                               /*expected_attribute=*/"condition") ||
-          !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) {
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(HloInstruction::CreateWhile(
-          shape, condition, body, /*init=*/operands[0]));
+          shape, *condition, *body, /*init=*/operands[0]));
       break;
     }
     case HloOpcode::kRecv: {
-      int64 channel_id;
+      optional<int64> channel_id;
+      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
       if (!ParseOperands(&operands, /*expected_size=*/0) ||
-          !ParseExtraAttribute(&channel_id,
-                               /*expected_attribute=*/"channel_id")) {
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
-          HloInstruction::CreateRecv(shape, channel_id));
+          HloInstruction::CreateRecv(shape, *channel_id));
       break;
     }
     case HloOpcode::kSend: {
-      int64 channel_id;
+      optional<int64> channel_id;
+      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
       if (!ParseOperands(&operands, /*expected_size=*/1) ||
-          !ParseExtraAttribute(&channel_id,
-                               /*expected_attribute=*/"channel_id")) {
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
-          HloInstruction::CreateSend(operands[0], channel_id));
+          HloInstruction::CreateSend(operands[0], *channel_id));
       break;
     }
     case HloOpcode::kGetTupleElement: {
-      int64 index;
+      optional<int64> index;
+      attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
       if (!ParseOperands(&operands, /*expected_size=*/1) ||
-          !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) {
+          !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
-          HloInstruction::CreateGetTupleElement(shape, operands[0], index));
+          HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
       break;
     }
     case HloOpcode::kCall: {
-      HloComputation* to_apply;
-      if (!ParseOperands(&operands) ||
-          !ParseExtraAttribute(&to_apply,
-                               /*expected_attribute=*/"to_apply")) {
+      optional<HloComputation*> to_apply;
+      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+                           &to_apply};
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
-          HloInstruction::CreateCall(shape, operands, to_apply));
+          HloInstruction::CreateCall(shape, operands, *to_apply));
+      break;
+    }
+    case HloOpcode::kReduceWindow: {
+      optional<HloComputation*> reduce_computation;
+      optional<Window> window;
+      attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
+      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+                           &reduce_computation};
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
+          shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
+          *reduce_computation));
+      break;
+    }
+    case HloOpcode::kConvolution: {
+      optional<Window> window;
+      optional<ConvolutionDimensionNumbers> dnums;
+      attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
+      attrs["dim_labels"] = {/*required=*/true,
+                             AttrTy::kConvolutionDimensionNumbers, &dnums};
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
+          shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
       break;
     }
     case HloOpcode::kBroadcast:
     case HloOpcode::kCustomCall:
     case HloOpcode::kConcatenate:
     case HloOpcode::kReducePrecision:
-    case HloOpcode::kConvolution:
     case HloOpcode::kMap:
     case HloOpcode::kPad:
     case HloOpcode::kReduce:
-    case HloOpcode::kReduceWindow:
     case HloOpcode::kSelectAndScatter:
     case HloOpcode::kReverse:
     case HloOpcode::kRng:
@@ -438,43 +530,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
                                HloOpcodeString(opcode)));
   }
 
-  bool has_sharding = false;
-  bool has_control = false;
-  while (EatIfPresent(TokKind::kComma)) {
-    string attribute_name;
-    if (!ParseAttributeName(&attribute_name)) {
-      return TokenError("expects ', sharding=' or ', control-predecessors='");
-    }
-
-    if (attribute_name == "sharding") {
-      // Parse "sharding=".
-      if (has_sharding) {
-        return TokenError("expects at most 1 'sharding='");
+  // Add common attrs (sharding, control predecessors) to the instruction, if
+  // they were seen.
+  if (sharding) {
+    instruction->set_sharding(
+        HloSharding::FromProto(sharding.value()).ValueOrDie());
+  }
+  if (predecessors) {
+    for (auto* pre : *predecessors) {
+      Status status = pre->AddControlDependencyTo(instruction);
+      if (!status.ok()) {
+        return TokenError(StrCat("error adding control dependency for: ", name,
+                                 " status: ", status.ToString()));
       }
-      has_sharding = true;
-      if (!ParseSharding(instruction)) {
-        return false;
-      }
-    } else if (attribute_name == "control-predecessors") {
-      // Parse "control-predecessors"
-      if (has_control) {
-        return TokenError("expects at most 1 'control-predecessors='");
-      }
-      has_control = true;
-      if (!ParseControlPredecessors(instruction)) {
-        return false;
-      }
-    } else {
-      return TokenError(StrCat("unexpected attribute: ", attribute_name));
     }
   }
-
   return AddInstruction(name, instruction);
 }
 
 // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
 // dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
-bool HloParser::ParseSharding(HloInstruction* instruction) {
+bool HloParser::ParseSharding(OpSharding* sharding) {
   if (!ParseToken(TokKind::kLbrace,
                   "expected '{' to start sharding attribute")) {
     return false;
@@ -545,7 +621,6 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
     }
   }
 
-  OpSharding sharding;
   if (replicated) {
     if (!devices.empty()) {
       return TokenError(
@@ -555,7 +630,7 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
       return TokenError(
           "replicated shardings should not have any tile shape set");
     }
-    sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
+    sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
   } else if (maximal) {
     if (devices.size() != 1) {
       return TokenError(
@@ -564,8 +639,8 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
     if (!ShapeUtil::Equal(tile_shape, Shape())) {
       return TokenError("maximal shardings should not have any tile shape set");
     }
-    sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
-    sharding.add_tile_assignment_devices(devices[0]);
+    sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
+    sharding->add_tile_assignment_devices(devices[0]);
   } else {
     if (devices.size() <= 1) {
       return TokenError(
@@ -579,47 +654,43 @@ bool HloParser::ParseSharding(HloInstruction* instruction) {
           "non-maximal shardings must have a tile assignment list including "
           "dimensions");
     }
-    sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
-    *sharding.mutable_tile_shape() = tile_shape;
+    sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
+    *sharding->mutable_tile_shape() = tile_shape;
     for (int64 dim : tile_assignment_dimensions) {
-      sharding.add_tile_assignment_dimensions(dim);
+      sharding->add_tile_assignment_dimensions(dim);
     }
     for (int64 device : devices) {
-      sharding.add_tile_assignment_devices(device);
+      sharding->add_tile_assignment_devices(device);
     }
   }
 
-  instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
   lexer_.Lex();
   return true;
 }
 
 // '{' name+ '}'
-bool HloParser::ParseControlPredecessors(HloInstruction* instruction) {
+bool HloParser::ParseInstructionNames(
+    std::vector<HloInstruction*>* instructions) {
   if (!ParseToken(TokKind::kLbrace,
-                  "expects '{' at the beginning of control predecessors")) {
+                  "expects '{' at the beginning of instruction name list")) {
     return false;
   }
   do {
     string name;
     if (!ParseName(&name)) {
-      return TokenError("expects a control predecessor");
+      return TokenError("expects a instruction name");
     }
-    HloInstruction* pre =
+    HloInstruction* instr =
         tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
-    if (!pre) {
+    if (!instr) {
       return TokenError(
-          StrCat("control predecessor ", name, " is not defined: "));
-    }
-    Status status = pre->AddControlDependencyTo(instruction);
-    if (!status.ok()) {
-      return TokenError(StrCat("error adding control dependency for: ", name,
-                               " status: ", status.ToString()));
+          Printf("instruction '%s' is not defined", name.c_str()));
     }
+    instructions->push_back(instr);
   } while (EatIfPresent(TokKind::kComma));
 
   return ParseToken(TokKind::kRbrace,
-                    "expects '}' at the end of control predecessors");
+                    "expects '}' at the end of instruction name list");
 }
 
 bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
@@ -957,28 +1028,95 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
   return true;
 }
 
-// extra_attribute ::= ',' attribute_name value
-template <typename T>
-bool HloParser::ParseExtraAttribute(T* value,
-                                    const string& expected_attribute) {
-  if (!ParseToken(TokKind::kComma,
-                  "expects ',' in front of an extra attribute")) {
-    return false;
+bool HloParser::ParseAttributes(
+    const std::unordered_map<string, AttrConfig>& attrs) {
+  std::unordered_set<string> seen_attrs;
+  while (EatIfPresent(TokKind::kComma)) {
+    string name;
+    if (!ParseAttributeName(&name)) {
+      return TokenError("error parsing attributes");
+    }
+    VLOG(1) << "Parsing attribute " << name;
+    if (!seen_attrs.insert(name).second) {
+      return TokenError(Printf("attribute %s already exists", name.c_str()));
+    }
+    auto attr_it = attrs.find(name);
+    if (attr_it == attrs.end()) {
+      return TokenError(Printf("unexpected attribute %s", name.c_str()));
+    }
+    AttrTy attr_type = attr_it->second.attr_type;
+    void* attr_out_ptr = attr_it->second.result;
+    bool success = [&] {
+      switch (attr_type) {
+        case AttrTy::kInt64: {
+          int64 result;
+          if (!ParseInt64(&result)) {
+            return false;
+          }
+          static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
+          return true;
+        }
+        case AttrTy::kHloComputation: {
+          HloComputation* result;
+          if (!ParseComputationName(&result)) {
+            return false;
+          }
+          static_cast<optional<HloComputation*>*>(attr_out_ptr)
+              ->emplace(result);
+          return true;
+        }
+        case AttrTy::kWindow: {
+          Window result;
+          if (!ParseWindow(&result)) {
+            return false;
+          }
+          static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
+          return true;
+        }
+        case AttrTy::kConvolutionDimensionNumbers: {
+          ConvolutionDimensionNumbers result;
+          if (!ParseConvolutionDimensionNumbers(&result)) {
+            return false;
+          }
+          static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
+              ->emplace(result);
+          return true;
+        }
+        case AttrTy::kSharding: {
+          OpSharding sharding;
+          if (!ParseSharding(&sharding)) {
+            return false;
+          }
+          static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
+          return true;
+        }
+        case AttrTy::kInstructionList: {
+          std::vector<HloInstruction*> result;
+          if (!ParseInstructionNames(&result)) {
+            return false;
+          }
+          static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
+              ->emplace(result);
+          return true;
+        }
+      }
+    }();
+    if (!success) {
+      return TokenError(Printf("error parsing attribute %s", name.c_str()));
+    }
   }
-  string attribute_name;
-  if (!ParseAttributeName(&attribute_name) &&
-      attribute_name != expected_attribute) {
-    return TokenError(StrCat("expects attribute name: ", expected_attribute));
-  }
-  if (!ParseAttributeValue(value)) {
-    return TokenError(
-        StrCat("expects value for attribute: ", expected_attribute));
+  // Check that all required attrs were seen.
+  for (const auto& attr_it : attrs) {
+    if (attr_it.second.required &&
+        seen_attrs.find(attr_it.first) == seen_attrs.end()) {
+      return TokenError(Printf("attribute %s is expected but not seen",
+                               attr_it.first.c_str()));
+    }
   }
   return true;
 }
 
-template <>
-bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
+bool HloParser::ParseComputationName(HloComputation** value) {
   string name;
   if (!ParseName(&name)) {
     return TokenError("expects computation name");
@@ -990,9 +1128,191 @@ bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
   return true;
 }
 
-template <>
-bool HloParser::ParseAttributeValue<int64>(int64* value) {
-  return ParseInt64(value);
+// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
+// The sub-attributes can appear in any order. 'size=' is required; others are
+// optional.
+bool HloParser::ParseWindow(Window* window) {
+  if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
+    return false;
+  }
+
+  std::vector<int64> size;
+  std::vector<int64> stride;
+  std::vector<std::vector<int64>> pad;
+  std::vector<int64> lhs_dilate;
+  std::vector<int64> rhs_dilate;
+  while (lexer_.GetKind() != TokKind::kRbrace) {
+    string field_name;
+    if (!ParseAttributeName(&field_name)) {
+      return TokenError("expects sub-attributes in window");
+    }
+    bool ok = [&] {
+      if (field_name == "size") {
+        return ParseDxD("size", &size);
+      }
+      if (field_name == "stride") {
+        return ParseDxD("stride", &stride);
+      }
+      if (field_name == "lhs_dilate") {
+        return ParseDxD("lhs_dilate", &lhs_dilate);
+      }
+      if (field_name == "rhs_dilate") {
+        return ParseDxD("rhs_dilate", &rhs_dilate);
+      }
+      if (field_name == "pad") {
+        return ParseWindowPad(&pad);
+      }
+      return TokenError(StrCat("unexpected attribute name: ", field_name));
+    }();
+    if (!ok) {
+      return false;
+    }
+  }
+
+  if (size.empty()) {
+    return TokenError(
+        "sub-attribute 'size=' is required in the window attribute");
+  }
+  if (!stride.empty() && stride.size() != size.size()) {
+    return TokenError("expects 'stride=' has the same size as 'size='");
+  }
+  if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
+    return TokenError("expects 'lhs_dilate=' has the same size as 'size='");
+  }
+  if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
+    return TokenError("expects 'rhs_dilate=' has the same size as 'size='");
+  }
+  if (!pad.empty() && pad.size() != size.size()) {
+    return TokenError("expects 'pad=' has the same size as 'size='");
+  }
+
+  for (int i = 0; i < size.size(); i++) {
+    window->add_dimensions()->set_size(size[i]);
+    if (!pad.empty()) {
+      window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
+      window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
+    }
+    // If some field is not present, it has the default value.
+    window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
+    window->mutable_dimensions(i)->set_base_dilation(
+        lhs_dilate.empty() ? 1 : lhs_dilate[i]);
+    window->mutable_dimensions(i)->set_window_dilation(
+        rhs_dilate.empty() ? 1 : rhs_dilate[i]);
+  }
+  return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
+}
+
+// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
+// The string looks like "dim_labels=0bf_0io->0bf".
+bool HloParser::ParseConvolutionDimensionNumbers(
+    ConvolutionDimensionNumbers* dnums) {
+  if (lexer_.GetKind() != TokKind::kDimLabels) {
+    return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
+  }
+  string str = lexer_.GetStrVal();
+
+  // The str is expected to have 3 items, lhs, rhs, out, and it must look like
+  // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
+  // So we replace the "->" with "_" and then split on "_".
+  str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
+                                            /*newsub=*/"_",
+                                            /*replace_all=*/false);
+  std::vector<string> lhs_rhs_out = Split(str, "_");
+  if (lhs_rhs_out.size() != 3) {
+    LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
+               << str;
+  }
+
+  const int64 rank = lhs_rhs_out[0].length();
+  if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
+    return TokenError(
+        "convolution lhs, rhs, and output must have the same rank");
+  }
+  if (rank < 3) {
+    return TokenError("convolution rank must >=3");
+  }
+
+  auto is_unique = [](string str) -> bool {
+    std::sort(str.begin(), str.end());
+    return std::unique(str.begin(), str.end()) == str.end();
+  };
+
+  // lhs
+  {
+    const string& lhs = lhs_rhs_out[0];
+    if (!is_unique(lhs)) {
+      return TokenError(
+          StrCat("expects unique lhs dimension numbers, but sees ", lhs));
+    }
+    for (int i = 0; i < rank - 2; i++) {
+      dnums->add_spatial_dimensions(-1);
+    }
+    for (int i = 0; i < rank; i++) {
+      char c = lhs[i];
+      if (c == 'b') {
+        dnums->set_input_batch_dimension(i);
+      } else if (c == 'f') {
+        dnums->set_input_feature_dimension(i);
+      } else if (c < '0' + rank && c >= '0') {
+        dnums->set_spatial_dimensions(c - '0', i);
+      } else {
+        return TokenError(
+            Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
+      }
+    }
+  }
+  // rhs
+  {
+    const string& rhs = lhs_rhs_out[1];
+    if (!is_unique(rhs)) {
+      return TokenError(
+          StrCat("expects unique rhs dimension numbers, but sees ", rhs));
+    }
+    for (int i = 0; i < rank - 2; i++) {
+      dnums->add_kernel_spatial_dimensions(-1);
+    }
+    for (int i = 0; i < rank; i++) {
+      char c = rhs[i];
+      if (c == 'i') {
+        dnums->set_kernel_input_feature_dimension(i);
+      } else if (c == 'o') {
+        dnums->set_kernel_output_feature_dimension(i);
+      } else if (c < '0' + rank && c >= '0') {
+        dnums->set_kernel_spatial_dimensions(c - '0', i);
+      } else {
+        return TokenError(
+            Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
+      }
+    }
+  }
+  // output
+  {
+    const string& out = lhs_rhs_out[2];
+    if (!is_unique(out)) {
+      return TokenError(
+          StrCat("expects unique output dimension numbers, but sees ", out));
+    }
+    for (int i = 0; i < rank; i++) {
+      char c = out[i];
+      if (c == 'b') {
+        dnums->set_output_batch_dimension(i);
+      } else if (c == 'f') {
+        dnums->set_output_feature_dimension(i);
+      } else if (c < '0' + rank && c >= '0') {
+        if (dnums->spatial_dimensions(c - '0') != i) {
+          return TokenError(
+              "output spatial dimensions should be the same as input spatial "
+              "dimensions");
+        }
+      } else {
+        return TokenError(
+            Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
+      }
+    }
+  }
+
+  lexer_.Lex();
+  return true;
 }
 
 // param_list ::= '(' param_list1 ')'
@@ -1070,6 +1390,55 @@ bool HloParser::ParseAttributeName(string* result) {
   return true;
 }
 
+bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
+  if (!result->empty()) {
+    return TokenError(
+        Printf("sub-attribute '%s=' already exists", name.c_str()));
+  }
+  // 1D
+  if (lexer_.GetKind() == TokKind::kInt) {
+    int64 number;
+    if (!ParseInt64(&number)) {
+      return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str()));
+    }
+    result->push_back(number);
+    return true;
+  }
+  // 2D or higher.
+  if (lexer_.GetKind() == TokKind::kDxD) {
+    string str = lexer_.GetStrVal();
+    if (!SplitAndParseAsInts(str, 'x', result)) {
+      return TokenError(
+          Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
+    }
+    lexer_.Lex();
+    return true;
+  }
+  return TokenError("expects token type kInt or kDxD");
+}
+
+bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
+  if (!pad->empty()) {
+    return TokenError("sub-attribute 'pad=' already exists");
+  }
+  if (lexer_.GetKind() != TokKind::kWindowPad) {
+    return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
+  }
+  string str = lexer_.GetStrVal();
+  std::vector<string> padding_str = Split(str, 'x');
+  for (int i = 0; i < padding_str.size(); i++) {
+    std::vector<int64> low_high;
+    if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
+        low_high.size() != 2) {
+      return TokenError(
+          "expects padding_low and padding_high separated by '_'");
+    }
+    pad->push_back(low_high);
+  }
+  lexer_.Lex();
+  return true;
+}
+
 bool HloParser::ParseOpcode(HloOpcode* result) {
   VLOG(1) << "ParseOpcode";
   if (lexer_.GetKind() != TokKind::kOpcode) {
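
For reference, ParseDxD and ParseWindowPad above do the decomposition of the
'ixj...' and 'low_high x low_high' textual forms. A standalone Python sketch of
the same splitting logic (an illustration only, not part of this patch; the
sample strings are taken from the ReduceWindow test below):

def parse_dxd(value):
  # "1x4x1x1" -> [1, 4, 1, 1]; a bare integer denotes a 1-D window.
  return [int(part) for part in value.split('x')]

def parse_window_pad(value):
  # "0_0x3_3" -> [[0, 0], [3, 3]]: one (low, high) pair per dimension.
  pairs = []
  for dim in value.split('x'):
    low, high = dim.split('_')
    pairs.append([int(low), int(high)])
  return pairs

assert parse_dxd('1x4x1x1') == [1, 4, 1, 1]
assert parse_window_pad('0_0x0_0x3_3x0_0') == [[0, 0], [0, 0], [3, 3], [0, 0]]
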
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 359256f0646..62b4385e76f 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -25,6 +25,7 @@ namespace tools {
 namespace {
 
 using tensorflow::StringPiece;
+using tensorflow::strings::StrCat;
 
 struct TestData {
   string test_name;
@@ -247,6 +248,39 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
   ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
 }
 
+)"
+},
+// reduce window
+{
+"ReduceWindow",
+R"(HloModule R4UnitWindow_module:
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
+  %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
+}
+
+)"
+},
+// convolution
+{
+"Convolution",
+R"(HloModule Convolve1D1Window_0_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
+}
+
 )"
 }
   });
@@ -427,6 +461,92 @@ ENTRY %ConstantWithExp.v4 () -> f32[] {
   // printed as "300".
 }
 
+TEST_F(HloParserTest, AttributesAnyOrder) {
+  const string original = R"(HloModule any_order_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+}
+
+)";
+  TF_EXPECT_OK(Parse(original).status());
+}
+
+TEST_F(HloParserTest, InvalidDimLabels) {
+  string prefix = R"(HloModule invalid_dim_labels_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
+  string suffix = R"(
+}
+
+)";
+
+  ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix))
+                      .status()
+                      .error_message(),
+                  "expects dim labels pattern");
+
+  ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix))
+                      .status()
+                      .error_message(),
+                  "must have the same rank");
+
+  ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=0bf_io0->b0f", suffix))
+                      .status()
+                      .error_message(),
+                  "output spatial dimensions should be the same as input "
+                  "spatial dimensions");
+}
+
+TEST_F(HloParserTest, UnexpectedAttribute) {
+  const string original = R"(HloModule unexpected_attr_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+  %recv = f32[] recv(), channel_id=15
+  ROOT %constant = f32[] constant(2.1)
+  %send = () send(f32[] %constant), channel_id=16, calls=%recv
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "unexpected attribute calls");
+}
+
+TEST_F(HloParserTest, MissingAttribute) {
+  const string original = R"(HloModule missing_attr_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+  %recv = f32[] recv(), channel_id=15
+  ROOT %constant = f32[] constant(-2.1)
+  %send = () send(f32[] %constant)
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "attribute channel_id is expected but not seen");
+}
+
+TEST_F(HloParserTest, PredecessorUndefined) {
+  const string original = R"(HloModule pre_not_found_module:
+
+ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
+  %recv = f32[] recv(), channel_id=15
+  ROOT %constant = f32[] constant(2.1)
+  %send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "'done' is not defined");
+}
+
 }  // namespace
 }  // namespace tools
 }  // namespace xla
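
As a companion to the dim_labels tests above, a small Python sketch of how the
lhs half of a label string such as 'b0f' decodes into dimension numbers,
mirroring the loop in ParseConvolutionDimensionNumbers ('b' marks the batch
dimension, 'f' the feature dimension, digits name spatial dimensions);
decode_lhs is an illustrative name, not part of the patch:

def decode_lhs(labels):
  # One spatial slot per dimension beyond batch and feature.
  rank = len(labels)
  spatial = [-1] * (rank - 2)
  batch = feature = None
  for i, c in enumerate(labels):
    if c == 'b':
      batch = i
    elif c == 'f':
      feature = i
    else:
      spatial[int(c)] = i  # digit k: spatial dimension k lives at index i

  return batch, feature, spatial

# 'b0f': batch at 0, spatial dimension 0 at 1, feature at 2.
assert decode_lhs('b0f') == (0, 2, [1])
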
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h
index 9c2069e7568..15ab8b1cccf 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h
@@ -57,6 +57,9 @@ enum class TokKind {
   // Typed tokens.
   kName,           // %foo
   kAttributeName,  // dimensions=
+  kDimLabels,      // [0-9bf]+_[0-9io]+->[0-9bf]+
+  kDxD,            // [0-9]+(x[0-9]+)+
+  kWindowPad,      // [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
   kShape,          // f32[2,3]{1,0}
   kOpcode,         // add
   kInt,            // 42
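
Expressed as Python regular expressions (a sketch of the comments above; the
actual lexer is hand-written C++), the three new token kinds accept:

import re

DIM_LABELS = re.compile(r'[0-9bf]+_[0-9io]+->[0-9bf]+')
DXD = re.compile(r'[0-9]+(?:x[0-9]+)+')
WINDOW_PAD = re.compile(r'[0-9]+_[0-9]+(?:x[0-9]+_[0-9]+)*')

assert DIM_LABELS.fullmatch('b0f_0io->b0f')
assert DXD.fullmatch('1x4x1x1')
assert WINDOW_PAD.fullmatch('0_0x0_0x3_3x0_0')
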
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index 23161873a0b..6f7f1479b90 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -26,8 +26,8 @@ namespace xla {
 namespace window_util {
 
 /* static */ string ToString(const WindowDimension& dim) {
-  using tensorflow::strings::StrCat;
   using tensorflow::strings::StrAppend;
+  using tensorflow::strings::StrCat;
   string str = StrCat("(size=", dim.size());
   if (dim.stride() != 1) {
     StrAppend(&str, ",stride=", dim.stride());
@@ -49,22 +49,22 @@ namespace window_util {
 }
 
 string ToString(const Window& window) {
-  using tensorflow::strings::StrCat;
   using tensorflow::strings::StrAppend;
+  using tensorflow::strings::StrCat;
 
   string str;
-  const auto add_field = [&](
-      const char* heading,
-      std::function<string(const WindowDimension&)> format) {
-    StrAppend(&str, heading, "=");
-    const char* prefix = "";
-    for (const auto& window_dimension : window.dimensions()) {
-      StrAppend(&str, prefix, format(window_dimension));
-      prefix = "x";
-    }
-  };
+  const auto add_field =
+      [&](const char* heading,
+          std::function<string(const WindowDimension&)> format) {
+        StrAppend(&str, heading, "=");
+        const char* prefix = "";
+        for (const auto& window_dimension : window.dimensions()) {
+          StrAppend(&str, prefix, format(window_dimension));
+          prefix = "x";
+        }
+      };
 
-  add_field("window",
+  add_field("size",
             [](const WindowDimension& dim) { return StrCat(dim.size()); });
   if (HasStride(window)) {
     add_field(" stride",

From 35febc0cc9c27d57e574dc6a3bd634f9611feb60 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 15:24:05 -0800
Subject: [PATCH 054/115] Add a --all_tensor_names option, which is useful when
 one only wants to know the tensor names. It is especially useful when some of
 the tensors are huge. Also update the usage description.

PiperOrigin-RevId: 175074541
---
 tensorflow/python/tools/inspect_checkpoint.py | 23 +++++++++++++++----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index 47a74e5abfb..8716058e619 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -29,7 +29,8 @@ from tensorflow.python.platform import flags
 FLAGS = None
 
 
-def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
+def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
+                                     all_tensor_names):
   """Prints tensors in a checkpoint file.
 
   If no `tensor_name` is provided, prints the tensor names and shapes
@@ -41,14 +42,16 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
     file_name: Name of the checkpoint file.
     tensor_name: Name of the tensor in the checkpoint file to print.
     all_tensors: Boolean indicating whether to print all tensors.
+    all_tensor_names: Boolean indicating whether to print all tensor names.
   """
   try:
     reader = pywrap_tensorflow.NewCheckpointReader(file_name)
-    if all_tensors:
+    if all_tensors or all_tensor_names:
       var_to_shape_map = reader.get_variable_to_shape_map()
       for key in sorted(var_to_shape_map):
         print("tensor_name: ", key)
-        print(reader.get_tensor(key))
+        if all_tensors:
+          print(reader.get_tensor(key))
     elif not tensor_name:
       print(reader.debug_string().decode("utf-8"))
     else:
@@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str):
 def main(unused_argv):
   if not FLAGS.file_name:
     print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
-          "[--tensor_name=tensor_to_print]")
+          "[--tensor_name=tensor_to_print] "
+          "[--all_tensors] "
+          "[--all_tensor_names] "
+          "[--printoptions]")
     sys.exit(1)
   else:
     print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
-                                     FLAGS.all_tensors)
+                                     FLAGS.all_tensors, FLAGS.all_tensor_names)
 
 
 if __name__ == "__main__":
@@ -130,6 +136,13 @@ if __name__ == "__main__":
       type="bool",
       default=False,
       help="If True, print the values of all the tensors.")
+  parser.add_argument(
+      "--all_tensor_names",
+      nargs="?",
+      const=True,
+      type="bool",
+      default=False,
+      help="If True, print the names of all the tensors.")
   parser.add_argument(
       "--printoptions",
       nargs="*",

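A hypothetical way to exercise the new flag; the checkpoint path below is made
up, and listing names only avoids materializing large tensor values:

# Hypothetical CLI invocation (path made up):
#   python inspect_checkpoint.py --file_name=/tmp/model.ckpt --all_tensor_names
# The same, programmatically:
from tensorflow.python.tools import inspect_checkpoint as chkp

chkp.print_tensors_in_checkpoint_file(
    "/tmp/model.ckpt",      # hypothetical checkpoint path
    tensor_name="",         # no single tensor selected
    all_tensors=False,      # do not print values
    all_tensor_names=True)  # print every tensor name
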
From a6babd6a4f6462e805be946bf6b352b2e4248794 Mon Sep 17 00:00:00 2001
From: Mark Heffernan <meheff@google.com>
Date: Wed, 8 Nov 2017 15:35:27 -0800
Subject: [PATCH 055/115] Move MakeFakeLiteral from client/lib/testing.h to
 tests/test_utils.h. Also remove superfluous literal creation methods in that
 file, and replace them with the existing ones in the Literal class.

Also, optionally print layout in Literal::ToString.

PiperOrigin-RevId: 175076277
---
 tensorflow/compiler/xla/client/lib/BUILD      |   1 +
 tensorflow/compiler/xla/client/lib/testing.cc |  57 +--------
 tensorflow/compiler/xla/client/lib/testing.h  |   4 -
 tensorflow/compiler/xla/literal_util.cc       |  22 +++-
 tensorflow/compiler/xla/literal_util.h        |   2 +-
 tensorflow/compiler/xla/service/BUILD         |   2 -
 .../compiler/xla/service/hlo_cse_test.cc      |  24 ++--
 .../xla/service/layout_assignment_test.cc     |  32 ++---
 tensorflow/compiler/xla/tests/BUILD           |   3 +-
 .../xla/tests/client_library_test_base.h      |   6 +-
 tensorflow/compiler/xla/tests/client_test.cc  |   4 +-
 .../xla/tests/compilation_cache_test.cc       |   8 +-
 .../xla/tests/compute_constant_test.cc        |   4 +-
 .../compiler/xla/tests/dot_operation_test.cc  |  25 ++--
 .../xla/tests/local_client_execute_test.cc    |  10 +-
 tensorflow/compiler/xla/tests/map_test.cc     |   8 +-
 tensorflow/compiler/xla/tests/test_utils.cc   | 120 ++++++++++++++++++
 tensorflow/compiler/xla/tests/test_utils.h    |  64 ++--------
 tensorflow/compiler/xla/tools/BUILD           |   1 +
 .../compiler/xla/tools/replay_computation.cc  |   1 +
 20 files changed, 209 insertions(+), 189 deletions(-)
 create mode 100644 tensorflow/compiler/xla/tests/test_utils.cc

diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index ee346820879..fca2bf2688c 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -44,6 +44,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:computation",
         "//tensorflow/compiler/xla/client:computation_builder",
         "//tensorflow/compiler/xla/client:global_data",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:lib",
     ],
 )
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
index e6645e4941b..d936bd870b8 100644
--- a/tensorflow/compiler/xla/client/lib/testing.cc
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -48,62 +49,6 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
 
 }  // namespace
 
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
-  if (ShapeUtil::IsTuple(shape)) {
-    std::vector<std::unique_ptr<Literal>> elements;
-    for (const Shape& element_shape : shape.tuple_shapes()) {
-      TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
-                          MakeFakeLiteral(element_shape));
-      elements.push_back(std::move(element));
-    }
-    return Literal::MakeTupleOwned(std::move(elements));
-  }
-  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
-  std::minstd_rand0 engine;
-  switch (shape.element_type()) {
-    case F32: {
-      std::uniform_real_distribution<float> generator(0.0f, 1.0f);
-      TF_CHECK_OK(literal->Populate<float>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    case S32: {
-      std::uniform_int_distribution<int32> generator(
-          std::numeric_limits<int32>::lowest(),
-          std::numeric_limits<int32>::max());
-      TF_CHECK_OK(literal->Populate<int32>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    case S64: {
-      std::uniform_int_distribution<int64> generator(
-          std::numeric_limits<int64>::lowest(),
-          std::numeric_limits<int64>::max());
-      TF_CHECK_OK(literal->Populate<int64>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    case PRED: {
-      std::uniform_int_distribution<int> generator(0, 1);
-      TF_CHECK_OK(literal->Populate<bool>(
-          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
-            return generator(engine);
-          }));
-      break;
-    }
-    default:
-      return Unimplemented("Unsupported type for fake literal generation: %s",
-                           ShapeUtil::HumanString(shape).c_str());
-  }
-  return std::move(literal);
-}
-
 std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
                                               Client* client) {
   if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) {
diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h
index b5c4393dcc3..7e640d1307e 100644
--- a/tensorflow/compiler/xla/client/lib/testing.h
+++ b/tensorflow/compiler/xla/client/lib/testing.h
@@ -26,10 +26,6 @@ limitations under the License.
 
 namespace xla {
 
-// Generates fake data in a literal of the given shape, or returns an error
-// status if the element type is currently unhandled for fake data generation.
-StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
-
 // Generates fake data of the given shape on the device or dies. The fake data
 // is created by performing a computation on the device rather than transferring
 // data from the host to the device.
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index fda791401d5..0cb2223ae5a 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -569,9 +569,17 @@ int64 Literal::LinearIndex(
   return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index);
 }
 
-string Literal::ToString() const {
+string Literal::ToString(bool print_layout) const {
   std::vector<string> pieces;
 
+  auto shape_to_string = [print_layout](const Shape& shape) {
+    if (print_layout) {
+      return ShapeUtil::HumanStringWithLayout(shape);
+    } else {
+      return ShapeUtil::HumanString(shape);
+    }
+  };
+
   auto element_to_string =
       [this](tensorflow::gtl::ArraySlice<int64> indices) -> string {
     PrimitiveType element_type = shape().element_type();
@@ -585,7 +593,7 @@ string Literal::ToString() const {
 
   // TODO(b/32894291): refactor this code to reduce code duplication.
   if (ShapeUtil::IsTuple(shape())) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" (\n");
     pieces.push_back(tensorflow::str_util::Join(
         tuple_literals(), ",\n", [](string* out, const Literal& element) {
@@ -601,7 +609,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 2) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back("  { ");
@@ -613,7 +621,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 3) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(i0 > 0 ? ",\n{" : "{");
@@ -628,7 +636,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("\n}");
   } else if (ShapeUtil::Rank(shape()) == 4) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(tensorflow::strings::Printf("  {  /*i0=%lld*/\n", i0));
@@ -649,7 +657,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else if (ShapeUtil::Rank(shape()) == 5) {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {\n");
     for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) {
       pieces.push_back(tensorflow::strings::Printf("  {  /*i0=%lld*/\n", i0));
@@ -676,7 +684,7 @@ string Literal::ToString() const {
     }
     pieces.push_back("}");
   } else {
-    pieces.push_back(ShapeUtil::HumanString(shape()));
+    pieces.push_back(shape_to_string(shape()));
     pieces.push_back(" {...}");
   }
 
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index a1e288829f2..667f926c464 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -450,7 +450,7 @@ class Literal {
   tensorflow::Status ValidateLiteral() const;
 
   // Returns a string representation of the literal value.
-  string ToString() const;
+  string ToString(bool print_layout = false) const;
 
   // Invokes the "per cell" callback for each element in the provided
   // literal with the element's indices and a string representation of
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index c6f6c6c38bc..7cf24641b5b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1780,7 +1780,6 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
 )
@@ -1851,7 +1850,6 @@ tf_cc_test(
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:literal_test_util",
         "//tensorflow/compiler/xla/tests:test_utils",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
         "//tensorflow/core:lib",
     ],
 )
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
index 7c4626e78a3..3601a790c44 100644
--- a/tensorflow/compiler/xla/service/hlo_cse_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
   // Test that two identical constants with different layouts are commoned if
   // the pass is not layout sensitive.
   auto builder = HloComputation::Builder(TestName());
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{0, 1})));
-  auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{1, 0})));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
       constant1->shape(), HloOpcode::kAdd, constant1, constant2));
 
@@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
   // Test that two identical constants with different layouts are *not* commoned
   // if the pass is layout sensitive.
   auto builder = HloComputation::Builder(TestName());
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{0, 1})));
-  auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   /*minor_to_major=*/{1, 0})));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant2 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto add = builder.AddInstruction(HloInstruction::CreateBinary(
       constant1->shape(), HloOpcode::kAdd, constant1, constant2));
 
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index c39ff522300..d51c0d1dfb7 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) {
   std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
   for (auto& minor_to_major : minor_to_majors) {
     auto builder = HloComputation::Builder(TestName());
-    auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
-        {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
-    auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
-        {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
+    auto constant_literal1 = Literal::CreateR2WithLayout<float>(
+        {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major));
+    auto constant_literal2 = Literal::CreateR2WithLayout<float>(
+        {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major));
     Shape ashape = constant_literal1->shape();
 
     auto constant1 = builder.AddInstruction(
@@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
   // Verify the layouts of a tuple are assigned properly (the element layouts
   // match their source).
   auto builder = HloComputation::Builder(TestName());
-  auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {0, 1})));
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {1, 0})));
+  auto constant0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto tuple = builder.AddInstruction(
       HloInstruction::CreateTuple({constant0, constant1}));
 
@@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) {
 TEST_F(LayoutAssignmentTest, TupleSelect) {
   // Verify layouts of a select with tuple operands is assigned properly.
   auto builder = HloComputation::Builder(TestName());
-  auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {0, 1})));
-  auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
-      test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
-                                                   {1, 0})));
+  auto constant0 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
+  auto constant1 = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR2WithLayout<float>(
+          {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0}))));
   auto tuple0 = builder.AddInstruction(
       HloInstruction::CreateTuple({constant0, constant1}));
   auto tuple1 = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 4e1be24b61c..2333a30ad58 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -61,13 +61,14 @@ generate_backend_test_macros()
 
 cc_library(
     name = "test_utils",
-    testonly = True,
+    srcs = ["test_utils.cc"],
     hdrs = ["test_utils.h"],
     deps = [
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:lib",
     ],
 )
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 7cfc276ec19..2c37466ff20 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -469,8 +469,7 @@ template <typename NativeT>
 std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
     const int width, NativeT min_value, NativeT max_value, uint32 seed) {
   std::vector<NativeT> result(width);
-  test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
-                                                       seed);
+  PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
   for (int i = 0; i < width; ++i) {
     result[i] = generator.get();
   }
@@ -482,8 +481,7 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
     const int rows, const int cols, NativeT min_value, NativeT max_value,
     uint32 seed) {
   auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
-  test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
-                                                       seed);
+  PseudorandomGenerator<NativeT> generator(min_value, max_value, seed);
   for (int y = 0; y < rows; ++y) {
     for (int x = 0; x < cols; ++x) {
       (*result)(y, x) = generator.get();
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 0853feeebd6..183bcf1dd33 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -54,8 +54,8 @@ TEST_F(ClientTest, ExecuteWithLayout) {
               .ConsumeValueOrDie();
 
       std::unique_ptr<Literal> expected_literal =
-          test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
-                                                       transfer_layout);
+          Literal::CreateR2WithLayout<int32>(
+              {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout));
 
       auto computed = client_->Transfer(*data, &expected_literal->shape());
 
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
index 707e439245c..0f780fa87ef 100644
--- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
   // layouts. Use these arrays as parameters to a simple computation. If the
   // layout of the array changes then computation should be recompiled (cache
   // miss).
-  auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+  auto rowmaj_array = Literal::CreateR2WithLayout(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0}));
   auto rowmaj_handle =
       client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
 
-  auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
-      {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+  auto colmaj_array = Literal::CreateR2WithLayout(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}));
   auto colmaj_handle =
       client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
 
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
index d423c78476d..5226a783868 100644
--- a/tensorflow/compiler/xla/tests/compute_constant_test.cc
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -264,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
       ASSERT_TRUE(computed.ok()) << computed.status();
 
       std::unique_ptr<Literal> expected_literal =
-          test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
-                                                       layout);
+          Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
+                                             LayoutUtil::MakeLayout(layout));
       LiteralTestUtil::AssertEqualShapesAndLayouts(
           expected_literal->shape(), computed.ValueOrDie()->shape());
       LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index c4e422b506b..b72dd2707c2 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
                                            bool rhs_row_major) {
   auto lhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
               {{1.0, 2.0}, {3.0, -4.0}},
-              MinorToMajorForIsRowMajor(lhs_row_major)))
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
           .ConsumeValueOrDie();
   auto rhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
               {{1.0, 6.0}, {7.0, -4.0}},
-              MinorToMajorForIsRowMajor(rhs_row_major)))
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
           .ConsumeValueOrDie();
 
   ComputationBuilder builder(client_, TestName());
@@ -362,15 +362,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
                                               bool rhs_row_major) {
   auto lhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
               {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
-              MinorToMajorForIsRowMajor(lhs_row_major)))
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))))
           .ConsumeValueOrDie();
   auto rhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+          ->TransferToServer(*Literal::CreateR2WithLayout<Element>(
               {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
-              MinorToMajorForIsRowMajor(rhs_row_major)))
+              LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))))
           .ConsumeValueOrDie();
 
   ComputationBuilder builder(client_, TestName());
@@ -420,13 +420,14 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) {
 XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
   auto lhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
-              {{1.0, 2.0, 3.0, -4.0}}, {1, 0}))
+          ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+              {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
           .ConsumeValueOrDie();
   auto rhs_handle =
       client_
-          ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<complex64>(
-              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, {1, 0}))
+          ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+              {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
+              LayoutUtil::MakeLayout({1, 0})))
           .ConsumeValueOrDie();
 
   ComputationBuilder builder(client_, TestName());
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 329b53012f5..a196e250d11 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
   auto computation = builder.Build().ConsumeValueOrDie();
 
   // Create x as a col-major array.
-  auto x_array = LiteralToShapedBuffer(
-      *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}},
-                                             /*minor_to_major=*/{0, 1}));
+  auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+      {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})));
   EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(),
                                 LayoutUtil::MakeLayout({0, 1})));
 
   // Create y as a row-major array.
-  auto y_array = LiteralToShapedBuffer(
-      *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}},
-                                             /*minor_to_major=*/{1, 0}));
+  auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout(
+      {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0})));
   EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(),
                                 LayoutUtil::MakeLayout({1, 0})));
 
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 2ef392508d1..2b0f7e6e80c 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) {
 // for Map that used to fail in shape inference (b/28989438).
 XLA_TEST_F(MapTest, AddWithMixedLayouts) {
   ComputationBuilder builder(client_, TestName());
-  std::unique_ptr<Literal> param0_literal =
-      test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
+  std::unique_ptr<Literal> param0_literal = Literal::CreateR2WithLayout(
+      {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0}));
   std::unique_ptr<GlobalData> param0_data =
       client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
 
-  std::unique_ptr<Literal> param1_literal =
-      test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
+  std::unique_ptr<Literal> param1_literal = Literal::CreateR2WithLayout(
+      {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1}));
   std::unique_ptr<GlobalData> param1_data =
       client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
 
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
new file mode 100644
index 00000000000..cdd3d66bbba
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+#include "tensorflow/compiler/xla/primitive_util.h"
+
+namespace xla {
+
+namespace {
+
+template <typename FloatT>
+void PopulateWithRandomFloatingPointData(Literal* literal) {
+  CHECK_EQ(literal->shape().element_type(),
+           primitive_util::NativeToPrimitiveType<FloatT>());
+  std::minstd_rand0 engine;
+  std::uniform_real_distribution<FloatT> generator(0.0f, 1.0f);
+  TF_CHECK_OK(literal->Populate<FloatT>(
+      [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+        return generator(engine);
+      }));
+}
+
+template <typename IntT>
+void PopulateWithRandomIntegralData(Literal* literal) {
+  CHECK_EQ(literal->shape().element_type(),
+           primitive_util::NativeToPrimitiveType<IntT>());
+  std::minstd_rand0 engine;
+  std::uniform_int_distribution<IntT> generator(
+      std::numeric_limits<IntT>::lowest(), std::numeric_limits<IntT>::max());
+  TF_CHECK_OK(literal->Populate<IntT>(
+      [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+        return generator(engine);
+      }));
+}
+
+}  // namespace
+
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
+  if (ShapeUtil::IsTuple(shape)) {
+    std::vector<std::unique_ptr<Literal>> elements;
+    for (const Shape& element_shape : shape.tuple_shapes()) {
+      TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> element,
+                          MakeFakeLiteral(element_shape));
+      elements.push_back(std::move(element));
+    }
+    return Literal::MakeTupleOwned(std::move(elements));
+  }
+  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+  switch (shape.element_type()) {
+    case F32:
+      PopulateWithRandomFloatingPointData<float>(literal.get());
+      break;
+    case F64:
+      PopulateWithRandomFloatingPointData<double>(literal.get());
+      break;
+    case S8:
+      PopulateWithRandomIntegralData<int8>(literal.get());
+      break;
+    case U8:
+      PopulateWithRandomIntegralData<uint8>(literal.get());
+      break;
+    case S16:
+      PopulateWithRandomIntegralData<int16>(literal.get());
+      break;
+    case U16:
+      PopulateWithRandomIntegralData<uint16>(literal.get());
+      break;
+    case S32:
+      PopulateWithRandomIntegralData<int32>(literal.get());
+      break;
+    case U32:
+      PopulateWithRandomIntegralData<uint32>(literal.get());
+      break;
+    case S64:
+      PopulateWithRandomIntegralData<int64>(literal.get());
+      break;
+    case U64:
+      PopulateWithRandomIntegralData<uint64>(literal.get());
+      break;
+    case PRED: {
+      std::uniform_int_distribution<int> generator(0, 1);
+      std::minstd_rand0 engine;
+      TF_CHECK_OK(literal->Populate<bool>(
+          [&](tensorflow::gtl::ArraySlice<int64> /*indices*/) {
+            return generator(engine);
+          }));
+      break;
+    }
+    default:
+      return Unimplemented("Unsupported type for fake literal generation: %s",
+                           ShapeUtil::HumanString(shape).c_str());
+  }
+  return std::move(literal);
+}
+
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+    const HloModule& module) {
+  std::vector<std::unique_ptr<Literal>> arguments;
+  for (const ShapeLayout& shape_layout :
+       module.config().entry_computation_layout().parameter_layouts()) {
+    TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape()));
+    arguments.push_back(std::move(literal));
+  }
+  return std::move(arguments);
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
index f3a522b05eb..12d5255fce5 100644
--- a/tensorflow/compiler/xla/tests/test_utils.h
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -23,12 +23,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace xla {
-namespace test_utils {
 
 // A class which generates pseudorandom numbers of a given type within a given
 // range. Not cryptographically secure and likely not perfectly evenly
@@ -53,63 +53,15 @@ class PseudorandomGenerator {
   std::mt19937 generator_;
 };
 
-// Convenience function for creating a rank-2 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR2LiteralWithLayout(
-    std::initializer_list<std::initializer_list<NativeT>> values,
-    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-  auto literal = MakeUnique<Literal>();
-  const int64 d0 = values.size();
-  const int64 d1 = values.begin()->size();
-  literal.get()->PopulateWithValue<NativeT>(0, {d0, d1});
-  *literal->mutable_shape()->mutable_layout() =
-      LayoutUtil::MakeLayout(minor_to_major);
-  TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
+// Generates fake data in a literal of the given shape, or returns an error
+// status if the element type is currently unhandled for fake data generation.
+StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape);
 
-  int64 dim0 = 0;
-  for (auto inner_list : values) {
-    int64 dim1 = 0;
-    for (auto value : inner_list) {
-      literal.get()->Set({dim0, dim1}, value);
-      ++dim1;
-    }
-    ++dim0;
-  }
-  return literal;
-}
+// Generates a vector of arguments containing fake data. The number, shape,
+// and layout of the arguments are appropriate for the given HLO module.
+StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
+    const HloModule& module);
 
-// Convenience function for creating a rank-3 array with arbitrary layout.
-template <typename NativeT>
-std::unique_ptr<Literal> CreateR3LiteralWithLayout(
-    std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
-        values,
-    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-  auto literal = MakeUnique<Literal>();
-  const int64 d0 = values.size();
-  const int64 d1 = values.begin()->size();
-  const int64 d2 = values.begin()->begin()->size();
-  literal.get()->PopulateWithValue<NativeT>(0, {d0, d1, d2});
-  *literal->mutable_shape()->mutable_layout() =
-      LayoutUtil::MakeLayout(minor_to_major);
-  TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
-
-  int64 dim0 = 0;
-  for (auto inner_list : values) {
-    int64 dim1 = 0;
-    for (auto inner_inner_list : inner_list) {
-      int64 dim2 = 0;
-      for (auto value : inner_inner_list) {
-        literal.get()->Set({dim0, dim1, dim2}, value);
-        ++dim2;
-      }
-      ++dim1;
-    }
-    ++dim0;
-  }
-  return literal;
-}
-
-}  // namespace test_utils
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 759921dce5a..091fa0c3ec8 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -88,6 +88,7 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client/lib:testing",
         "//tensorflow/compiler/xla/service:session_proto",
+        "//tensorflow/compiler/xla/tests:test_utils",
         "//tensorflow/core:framework_internal",
         "//tensorflow/core:lib",
     ],
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index 89b26b8916b..503e7d456e1 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -45,6 +45,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/core/threadpool.h"

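MakeFakeArguments simply builds one fake literal per entry-computation
parameter. A NumPy sketch of the same per-element-type dispatch (an analogy
under assumed shapes and dtypes, not the XLA API; make_fake_argument is an
illustrative name):

import numpy as np

def make_fake_argument(shape, dtype, seed=0):
  # Mirrors the dispatch in test_utils.cc: uniform floats in [0, 1),
  # full-range integers, and fair coin flips for predicates.
  rng = np.random.RandomState(seed)
  if np.issubdtype(dtype, np.floating):
    return rng.uniform(0.0, 1.0, size=shape).astype(dtype)
  if np.issubdtype(dtype, np.integer):
    info = np.iinfo(dtype)
    return rng.randint(info.min, info.max, size=shape, dtype=dtype)
  if dtype == np.bool_:
    return rng.randint(0, 2, size=shape).astype(np.bool_)
  raise NotImplementedError(dtype)

# One argument per parameter of a hypothetical (f32[2,3], s32[4]) signature.
args = [make_fake_argument((2, 3), np.float32),
        make_fake_argument((4,), np.int32, seed=1)]
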
From 481739daad1bc92225da29bb7a65ced6a9a52303 Mon Sep 17 00:00:00 2001
From: Olivia Nordquist <nolivia@google.com>
Date: Wed, 8 Nov 2017 15:57:27 -0800
Subject: [PATCH 056/115] Allow tf.Print to print an empty data list, and
 change a no-op test in function_test.py to verify that an empty list no
 longer raises a ValueError as it would have previously

PiperOrigin-RevId: 175079527
---
 tensorflow/core/ops/logging_ops.cc           | 2 +-
 tensorflow/python/framework/function_test.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index 11cb9861a39..e6995821df7 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -43,7 +43,7 @@ REGISTER_OP("Print")
     .Output("output: T")
     .SetIsStateful()
     .Attr("T: type")
-    .Attr("U: list(type)")
+    .Attr("U: list(type) >= 0")
     .Attr("message: string = ''")
     .Attr("first_n: int = -1")
     .Attr("summarize: int = 3")
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 36b0737cfca..ba43e9199b4 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -370,7 +370,7 @@ class FunctionTest(test.TestCase):
 
     @function.Defun(dtypes.float32)
     def Foo(x):
-      y = logging_ops.Print(x, [x], "Hello")
+      y = logging_ops.Print(x, [], "Hello")
       with ops.control_dependencies([y]):
         z = control_flow_ops.no_op()
       with ops.control_dependencies([z]):

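With the attr relaxed to list(type) >= 0, Print can now be used purely for its
side effect and control-dependency ordering, with nothing extra to dump. A
minimal TF 1.x graph-mode sketch:

import tensorflow as tf

x = tf.constant(3.0)
y = tf.Print(x, [], message="Hello")  # empty data list is now legal
with tf.control_dependencies([y]):
  z = tf.no_op()

with tf.Session() as sess:
  sess.run(z)  # logs "Hello" with no tensor values attached
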
From 1ff2d1377753c1ae74eca7b0705fce2775195cbe Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 15:59:24 -0800
Subject: [PATCH 057/115] Add Baseline Estimators to core TensorFlow

PiperOrigin-RevId: 175079784
---
 tensorflow/python/estimator/BUILD             |   65 +
 .../python/estimator/canned/baseline.py       |  349 ++++
 .../python/estimator/canned/baseline_test.py  | 1545 +++++++++++++++++
 tensorflow/python/estimator/estimator_lib.py  |    4 +
 ...rflow.estimator.-baseline-classifier.pbtxt |   54 +
 ...orflow.estimator.-baseline-regressor.pbtxt |   54 +
 .../api/golden/tensorflow.estimator.pbtxt     |    8 +
 7 files changed, 2079 insertions(+)
 create mode 100644 tensorflow/python/estimator/canned/baseline.py
 create mode 100644 tensorflow/python/estimator/canned/baseline_test.py
 create mode 100644 tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
 create mode 100644 tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt

diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 26f1fd888a0..dba77617008 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -25,6 +25,7 @@ py_library(
     srcs = ["estimator_lib.py"],
     srcs_version = "PY2AND3",
     deps = [
+        ":baseline",
         ":dnn",
         ":dnn_linear_combined",
         ":estimator",
@@ -186,6 +187,70 @@ py_test(
     ],
 )
 
+py_library(
+    name = "baseline",
+    srcs = ["canned/baseline.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":estimator",
+        ":head",
+        ":model_fn",
+        ":optimizers",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
+        "//tensorflow/python:nn",
+        "//tensorflow/python:partitioned_variables",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/feature_column",
+        "@six_archive//:six",
+    ],
+)
+
+py_test(
+    name = "baseline_test",
+    size = "medium",
+    srcs = ["canned/baseline_test.py"],
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_pip",
+        "notsan",  # b/67510291
+    ],
+    deps = [
+        ":baseline",
+        ":estimator",
+        ":export_export",
+        ":metric_keys",
+        ":numpy_io",
+        ":pandas_io",
+        ":run_config",
+        "//tensorflow/core:protos_all_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:data_flow_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:platform",
+        "//tensorflow/python:sparse_tensor",
+        "//tensorflow/python:state_ops",
+        "//tensorflow/python:summary",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/feature_column",
+        "//third_party/py/numpy",
+        "//third_party/py/pandas",
+        "@six_archive//:six",
+    ],
+)
+
 py_library(
     name = "dnn",
     srcs = ["canned/dnn.py"],
diff --git a/tensorflow/python/estimator/canned/baseline.py b/tensorflow/python/estimator/canned/baseline.py
new file mode 100644
index 00000000000..96e4ecd29fb
--- /dev/null
+++ b/tensorflow/python/estimator/canned/baseline.py
@@ -0,0 +1,349 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Baseline estimators.
+
+Baseline estimators are bias-only estimators that can be used for debugging
+and as simple baselines.
+
+Example:
+
+```
+# Build BaselineClassifier
+classifier = BaselineClassifier(n_classes=3)
+
+# Input builders
+def input_fn_train():  # returns x, y (where y represents the label's class index).
+  pass
+
+def input_fn_eval():  # returns x, y (where y represents the label's class index).
+  pass
+
+# Fit model.
+classifier.train(input_fn=input_fn_train)
+
+# Evaluate cross entropy between the test and train labels.
+loss = classifier.evaluate(input_fn=input_fn_eval)["loss"]
+
+# predict outputs the probability distribution of the classes as seen in
+# training.
+predictions = classifier.predict(input_fn=input_fn_eval)
+```
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import training_util
+
+# The default learning rate of 0.3 is a historical artifact of the initial
+# implementation, but seems a reasonable choice.
+_LEARNING_RATE = 0.3
+
+
+def _get_weight_column_key(weight_column):
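+  # Returns the string key naming the weight feature: either the string
+  # itself or, for a `_NumericColumn`, its `key` attribute (e.g.
+  # numeric_column('w') -> 'w').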
+  if weight_column is None:
+    return None
+  if isinstance(weight_column, six.string_types):
+    return weight_column
+  if not isinstance(weight_column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
+    raise TypeError('Weight column must be either a string or _NumericColumn.'
+                    ' Given type: {}.'.format(type(weight_column)))
+  return weight_column.key
+
+
+def _baseline_logit_fn_builder(num_outputs, weight_column=None):
+  """Function builder for a baseline logit_fn.
+
+  Args:
+    num_outputs: Number of outputs for the model.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining the feature column
+      representing weights. It will be multiplied by the loss of the example.
+  Returns:
+    A logit_fn (see below).
+  """
+
+  def baseline_logit_fn(features):
+    """Baseline model logit_fn.
+
+    The baseline model simply learns a bias, so the output logits are a
+    `Variable` with one weight for each output that learns the bias for the
+    corresponding output.
+
+    Args:
+      features: The first item returned from the `input_fn` passed to `train`,
+        `evaluate`, and `predict`. This should be a single `Tensor` or dict with
+        `Tensor` values.
+    Returns:
+      A `Tensor` representing the logits.
+    """
+    size_checks = []
+    batch_size = None
+
+    weight_column_key = _get_weight_column_key(weight_column)
+
+    # The first dimension is assumed to be a batch size and must be consistent
+    # among all of the features.
+    for key, feature in features.items():
+      # Skip weight_column to ensure we don't add size checks to it.
+      # These would introduce a dependency on the weight at serving time.
+      if key == weight_column_key:
+        continue
+      first_dim = array_ops.shape(feature)[0]
+      if batch_size is None:
+        batch_size = first_dim
+      else:
+        size_checks.append(check_ops.assert_equal(batch_size, first_dim))
+
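+    # Illustrative note: for num_outputs=3 and batch_size=2, `bias` has shape
+    # [3] and multiplying by ones([2, 3]) broadcasts it into two identical
+    # rows, so every example receives the same logits.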
+    with ops.control_dependencies(size_checks):
+      with variable_scope.variable_scope('baseline'):
+        bias = variable_scope.get_variable('bias', shape=[num_outputs],
+                                           initializer=init_ops.Zeros)
+        return math_ops.multiply(bias, array_ops.ones([batch_size,
+                                                       num_outputs]))
+
+  return baseline_logit_fn
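+
+
+# Illustrative sketch (not executed; `age` stands for any [batch_size, 1]
+# feature tensor):
+#   logit_fn = _baseline_logit_fn_builder(num_outputs=3)
+#   logits = logit_fn({'age': age})  # shape [batch_size, 3]
+# Every row of `logits` equals the learned bias.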
+
+
+def _baseline_model_fn(features, labels, mode, head, optimizer,
+                       weight_column=None, config=None):
+  """Model_fn for baseline models.
+
+  Args:
+    features: `Tensor` or dict of `Tensor` (depends on data passed to `train`).
+    labels: `Tensor` of labels that are compatible with the `Head` instance.
+    mode: Defines whether this is training, evaluation or prediction.
+      See `ModeKeys`.
+    head: A `Head` instance.
+    optimizer: String, `tf.Optimizer` object, or callable that creates the
+      optimizer to use for training. If not specified, will use `FtrlOptimizer`
+      with a default learning rate of 0.3.
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining the feature column
+      representing weights. It will be multiplied by the loss of the example.
+    config: `RunConfig` object to configure the runtime settings.
+
+  Raises:
+    KeyError: If weight column is specified but not present.
+    ValueError: If features is an empty dictionary.
+
+  Returns:
+    An `EstimatorSpec` instance.
+  """
+  del config  # Unused.
+
+  logit_fn = _baseline_logit_fn_builder(head.logits_dimension, weight_column)
+  logits = logit_fn(features)
+
+  def train_op_fn(loss):
+    opt = optimizers.get_optimizer_instance(
+        optimizer, learning_rate=_LEARNING_RATE)
+    return opt.minimize(loss, global_step=training_util.get_global_step())
+
+  return head.create_estimator_spec(
+      features=features,
+      mode=mode,
+      logits=logits,
+      labels=labels,
+      train_op_fn=train_op_fn)
+
+
+class BaselineClassifier(estimator.Estimator):
+  """A classifier that can establish a simple baseline.
+
+  This classifier ignores feature values and will learn to predict the average
+  value of each label. For single-label problems, this will predict the
+  probability distribution of the classes as seen in the labels. For multi-label
+  problems, this will predict the fraction of examples that are positive for
+  each class.
+
+  Example:
+
+  ```python
+
+  # Build BaselineClassifier
+  classifier = BaselineClassifier(n_classes=3)
+
+  # Input builders
+  def input_fn_train():  # returns x, y (where y represents label's class index).
+    pass
+
+  def input_fn_eval():  # returns x, y (where y represents label's class index).
+    pass
+
+  # Fit model.
+  classifier.train(input_fn=input_fn_train)
+
+  # Evaluate cross entropy between the test and train labels.
+  loss = classifier.evaluate(input_fn=input_fn_eval)["loss"]
+
+  # predict outputs the probability distribution of the classes as seen in
+  # training.
+  predictions = classifier.predict(new_samples)
+
+  ```
+
+  Input of `train` and `evaluate` should have the following features,
+    otherwise there will be a `KeyError`:
+
+  * if `weight_column` is not `None`, a feature with
+     `key=weight_column` whose value is a `Tensor`.
+  """
+
+  def __init__(self,
+               model_dir=None,
+               n_classes=2,
+               weight_column=None,
+               label_vocabulary=None,
+               optimizer='Ftrl',
+               config=None):
+    """Initializes a BaselineClassifier instance.
+
+    Args:
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      n_classes: number of label classes. Default is binary classification.
+        It must be greater than 1. Note: Class labels are integers representing
+        the class index (i.e. values from 0 to n_classes-1). For arbitrary
+        label values (e.g. string labels), convert to class indices first.
+      weight_column: A string or a `_NumericColumn` created by
+        `tf.feature_column.numeric_column` defining the feature column
+        representing weights. It will be multiplied by the loss of the example.
+      label_vocabulary: Optional list of strings with size `[n_classes]`
+        defining the label vocabulary. If given, labels must be of string type
+        and have any value in `label_vocabulary`.
+      optimizer: String, `tf.Optimizer` object, or callable that creates the
+        optimizer to use for training. If not specified, will use
+        `FtrlOptimizer` with a default learning rate of 0.3.
+      config: `RunConfig` object to configure the runtime settings.
+    Returns:
+      A `BaselineClassifier` estimator.
+
+    Raises:
+      ValueError: If `n_classes` < 2.
+    """
+    if n_classes == 2:
+      head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint: disable=protected-access
+          weight_column=weight_column,
+          label_vocabulary=label_vocabulary)
+    else:
+      head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint: disable=protected-access
+          n_classes, weight_column=weight_column,
+          label_vocabulary=label_vocabulary)
+    def _model_fn(features, labels, mode, config):
+      return _baseline_model_fn(
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          optimizer=optimizer,
+          weight_column=weight_column,
+          config=config)
+    super(BaselineClassifier, self).__init__(
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config)
+
+
+class BaselineRegressor(estimator.Estimator):
+  """A regressor that can establish a simple baseline.
+
+  This regressor ignores feature values and will learn to predict the average
+  value of each label.
+
+  Example:
+
+  ```python
+
+  # Build BaselineRegressor
+  regressor = BaselineRegressor()
+
+  # Input builders
+  def input_fn_train():  # returns x, y (where y is the label).
+    pass
+
+  def input_fn_eval():  # returns x, y (where y is the label).
+    pass
+
+  # Fit model.
+  regressor.train(input_fn=input_fn_train)
+
+  # Evaluate squared-loss between the test and train targets.
+  loss = regressor.evaluate(input_fn=input_fn_eval)["loss"]
+
+  # predict outputs the mean value seen during training.
+  predictions = regressor.predict(new_samples)
+  ```
+
+  Input of `train` and `evaluate` should have the following features,
+    otherwise there will be a `KeyError`:
+
+  * if `weight_column` is not `None`, a feature with
+     `key=weight_column` whose value is a `Tensor`.
+  """
+
+  def __init__(self,
+               model_dir=None,
+               label_dimension=1,
+               weight_column=None,
+               optimizer='Ftrl',
+               config=None):
+    """Initializes a BaselineRegressor instance.
+
+    Args:
+      model_dir: Directory to save model parameters, graph, etc. This can
+        also be used to load checkpoints from the directory into an estimator
+        to continue training a previously saved model.
+      label_dimension: Number of regression targets per example. This is the
+        size of the last dimension of the labels and logits `Tensor` objects
+        (typically, these have shape `[batch_size, label_dimension]`).
+      weight_column: A string or a `_NumericColumn` created by
+        `tf.feature_column.numeric_column` defining the feature column
+        representing weights. It will be multiplied by the loss of the example.
+      optimizer: String, `tf.Optimizer` object, or callable that creates the
+        optimizer to use for training. If not specified, will use
+        `FtrlOptimizer` with a default learning rate of 0.3.
+      config: `RunConfig` object to configure the runtime settings.
+    Returns:
+      A `BaselineRegressor` estimator.
+    """
+
+    head = head_lib._regression_head_with_mean_squared_error_loss(  # pylint: disable=protected-access
+        label_dimension=label_dimension,
+        weight_column=weight_column)
+    def _model_fn(features, labels, mode, config):
+      return _baseline_model_fn(
+          features=features,
+          labels=labels,
+          mode=mode,
+          head=head,
+          optimizer=optimizer,
+          config=config)
+    super(BaselineRegressor, self).__init__(
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config)
diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py
new file mode 100644
index 00000000000..96639e88ea4
--- /dev/null
+++ b/tensorflow/python/estimator/canned/baseline_test.py
@@ -0,0 +1,1557 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for baseline.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+import shutil
+import tempfile
+
+import numpy as np
+import six
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator.canned import baseline
+from tensorflow.python.estimator.canned import metric_keys
+from tensorflow.python.estimator.export import export
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.estimator.inputs import pandas_io
+from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import input as input_lib
+from tensorflow.python.training import optimizer
+from tensorflow.python.training import queue_runner
+from tensorflow.python.training import saver
+
+
+try:
+  # pylint: disable=g-import-not-at-top
+  import pandas as pd
+  HAS_PANDAS = True
+except IOError:
+  # Pandas writes a temporary file during import. If it fails, don't use pandas.
+  HAS_PANDAS = False
+except ImportError:
+  HAS_PANDAS = False
+
+# pylint rules which are disabled by default for test files.
+# pylint: disable=invalid-name,protected-access,missing-docstring
+
+# Names of variables created by model.
+BIAS_NAME = 'baseline/bias'
+
+
+def assert_close(expected, actual, rtol=1e-04, name='assert_close'):
+  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:
+    expected = ops.convert_to_tensor(expected, name='expected')
+    actual = ops.convert_to_tensor(actual, name='actual')
+    rdiff = math_ops.abs(expected - actual, 'diff') / math_ops.abs(expected)
+    rtol = ops.convert_to_tensor(rtol, name='rtol')
+    return check_ops.assert_less(
+        rdiff,
+        rtol,
+        data=('Condition expected =~ actual did not hold element-wise: '
+              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,
+              'rtol = ', rtol,),
+        name=scope)
+
+
+def save_variables_to_ckpt(model_dir):
+  init_all_op = [variables.global_variables_initializer()]
+  with tf_session.Session() as sess:
+    sess.run(init_all_op)
+    saver.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))
+
+
+def queue_parsed_features(feature_map):
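+  # Feeds the parsed feature tensors through a FIFOQueue via a QueueRunner and
+  # returns the dequeued tensors under their original keys, so an input_fn
+  # built from parse_example behaves like a normal queued input pipeline.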
+  tensors_to_enqueue = []
+  keys = []
+  for key, tensor in six.iteritems(feature_map):
+    keys.append(key)
+    tensors_to_enqueue.append(tensor)
+  queue_dtypes = [x.dtype for x in tensors_to_enqueue]
+  input_queue = data_flow_ops.FIFOQueue(capacity=100, dtypes=queue_dtypes)
+  queue_runner.add_queue_runner(
+      queue_runner.QueueRunner(input_queue,
+                               [input_queue.enqueue(tensors_to_enqueue)]))
+  dequeued_tensors = input_queue.dequeue()
+  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
+
+
+def sorted_key_dict(unsorted_dict):
+  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}
+
+
+def sigmoid(x):
+  return 1 / (1 + np.exp(-1.0 * x))
+
+
+def _baseline_regressor_fn(*args, **kwargs):
+  return baseline.BaselineRegressor(*args, **kwargs)
+
+
+def _baseline_classifier_fn(*args, **kwargs):
+  return baseline.BaselineClassifier(*args, **kwargs)
+
+
+# Tests for Baseline Regressor.
+
+
+# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.
+class BaselineRegressorEvaluationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def test_evaluation_for_simple_data(self):
+    with ops.Graph().as_default():
+      variables.Variable([13.0], name=BIAS_NAME)
+      variables.Variable(
+          100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+    eval_metrics = baseline_regressor.evaluate(
+        input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1)
+
+    # Logit is bias = 13, while label is 10. Loss is 3**2 = 9.
+    self.assertDictEqual({
+        metric_keys.MetricKeys.LOSS: 9.,
+        metric_keys.MetricKeys.LOSS_MEAN: 9.,
+        ops.GraphKeys.GLOBAL_STEP: 100
+    }, eval_metrics)
+
+  def test_evaluation_batch(self):
+    """Tests evaluation for batch_size==2."""
+    with ops.Graph().as_default():
+      variables.Variable([13.0], name=BIAS_NAME)
+      variables.Variable(
+          100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+    eval_metrics = baseline_regressor.evaluate(
+        input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
+
+    # Logit is bias = 13, while label is 10.
+    # Loss per example is 3**2 = 9.
+    # Training loss is the sum over batch = 9 + 9 = 18
+    # Average loss is the average over batch = 9
+    self.assertDictEqual({
+        metric_keys.MetricKeys.LOSS: 18.,
+        metric_keys.MetricKeys.LOSS_MEAN: 9.,
+        ops.GraphKeys.GLOBAL_STEP: 100
+    }, eval_metrics)
+
+  def test_evaluation_weights(self):
+    """Tests evaluation with weights."""
+    with ops.Graph().as_default():
+      variables.Variable([13.0], name=BIAS_NAME)
+      variables.Variable(
+          100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    def _input_fn():
+      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}
+      labels = ((10.,), (10.,))
+      return features, labels
+
+    baseline_regressor = _baseline_regressor_fn(
+        weight_column='weights',
+        model_dir=self._model_dir)
+    eval_metrics = baseline_regressor.evaluate(input_fn=_input_fn, steps=1)
+
+    # Logit is bias = 13, while label is 10.
+    # Loss per example is 3**2 = 9.
+    # Training loss is the weighted sum over batch = 9 + 2*9 = 27
+    # Average loss is the weighted average = (9 + 2*9) / (1 + 2) = 9
+    self.assertDictEqual({
+        metric_keys.MetricKeys.LOSS: 27.,
+        metric_keys.MetricKeys.LOSS_MEAN: 9.,
+        ops.GraphKeys.GLOBAL_STEP: 100
+    }, eval_metrics)
+
+  def test_evaluation_for_multi_dimensions(self):
+    label_dim = 2
+    with ops.Graph().as_default():
+      variables.Variable([46.0, 58.0], name=BIAS_NAME)
+      variables.Variable(100, name='global_step', dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    baseline_regressor = _baseline_regressor_fn(
+        label_dimension=label_dim,
+        model_dir=self._model_dir)
+    input_fn = numpy_io.numpy_input_fn(
+        x={
+            'age': np.array([[2., 4., 5.]]),
+        },
+        y=np.array([[46., 58.]]),
+        batch_size=1,
+        num_epochs=None,
+        shuffle=False)
+    eval_metrics = baseline_regressor.evaluate(input_fn=input_fn, steps=1)
+
+    self.assertItemsEqual(
+        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,
+         ops.GraphKeys.GLOBAL_STEP), eval_metrics.keys())
+
+    # Logit is bias which is [46, 58]
+    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])
+
+
+class BaselineRegressorPredictTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def test_1d(self):
+    """Tests predict when all variables are one-dimensional."""
+    with ops.Graph().as_default():
+      variables.Variable([.2], name=BIAS_NAME)
+      variables.Variable(100, name='global_step', dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': np.array([[2.]])},
+        y=None,
+        batch_size=1,
+        num_epochs=1,
+        shuffle=False)
+    predictions = baseline_regressor.predict(input_fn=predict_input_fn)
+    predicted_scores = list([x['predictions'] for x in predictions])
+    # The baseline model ignores x; score = bias = .2
+    self.assertAllClose([[.2]], predicted_scores)
+
+  def testMultiDim(self):
+    """Tests predict when all variables are multi-dimenstional."""
+    batch_size = 2
+    label_dimension = 3
+    with ops.Graph().as_default():
+      variables.Variable(  # shape=[label_dimension]
+          [.2, .4, .6], name=BIAS_NAME)
+      variables.Variable(100, name='global_step', dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    baseline_regressor = _baseline_regressor_fn(
+        label_dimension=label_dimension,
+        model_dir=self._model_dir)
+
+    predict_input_fn = numpy_io.numpy_input_fn(
+        # x shape=[batch_size, x_dim]
+        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},
+        y=None,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+    predictions = baseline_regressor.predict(input_fn=predict_input_fn)
+    predicted_scores = list([x['predictions'] for x in predictions])
+    # score = bias, shape=[batch_size, label_dimension]
+    self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]],
+                        predicted_scores)
+
+
+class BaselineRegressorIntegrationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+                          input_dimension, label_dimension, prediction_length):
+    feature_columns = [
+        feature_column_lib.numeric_column('x', shape=(input_dimension,))
+    ]
+    est = _baseline_regressor_fn(
+        label_dimension=label_dimension,
+        model_dir=self._model_dir)
+
+    # TRAIN
+    # The baseline ignores x and can only learn the mean of the labels.
+    est.train(train_input_fn, steps=200)
+
+    # EVALUATE
+    scores = est.evaluate(eval_input_fn)
+    self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP])
+    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))
+
+    # PREDICT
+    predictions = np.array(
+        [x['predictions'] for x in est.predict(predict_input_fn)])
+    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
+
+    # EXPORT
+    feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+                                       serving_input_receiver_fn)
+    self.assertTrue(gfile.Exists(export_dir))
+
+  def test_numpy_input_fn(self):
+    """Tests complete flow with numpy_input_fn."""
+    label_dimension = 2
+    input_dimension = label_dimension
+    batch_size = 10
+    prediction_length = batch_size
+    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, label_dimension)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=data,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    eval_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=data,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=None,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        label_dimension=label_dimension,
+        prediction_length=prediction_length)
+
+  def test_pandas_input_fn(self):
+    """Tests complete flow with pandas_input_fn."""
+    if not HAS_PANDAS:
+      return
+
+    # Pandas DataFrame naturally supports 1-dim data only.
+    label_dimension = 1
+    input_dimension = label_dimension
+    batch_size = 10
+    data = np.array([1., 2., 3., 4.], dtype=np.float32)
+    x = pd.DataFrame({'x': data})
+    y = pd.Series(data)
+    prediction_length = 4
+
+    train_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+    eval_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, shuffle=False)
+    predict_input_fn = pandas_io.pandas_input_fn(
+        x=x, batch_size=batch_size, shuffle=False)
+
+    self._test_complete_flow(
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        label_dimension=label_dimension,
+        prediction_length=prediction_length)
+
+  def test_input_fn_from_parse_example(self):
+    """Tests complete flow with input_fn constructed from parse_example."""
+    label_dimension = 2
+    input_dimension = label_dimension
+    batch_size = 10
+    prediction_length = batch_size
+    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, label_dimension)
+
+    serialized_examples = []
+    for datum in data:
+      example = example_pb2.Example(features=feature_pb2.Features(
+          feature={
+              'x':
+                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
+                      value=datum)),
+              'y':
+                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
+                      value=datum[:label_dimension])),
+          }))
+      serialized_examples.append(example.SerializeToString())
+
+    feature_spec = {
+        'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+        'y': parsing_ops.FixedLenFeature([label_dimension], dtypes.float32),
+    }
+
+    def _train_input_fn():
+      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+      features = queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _eval_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _predict_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = queue_parsed_features(feature_map)
+      features.pop('y')
+      return features, None
+
+    self._test_complete_flow(
+        train_input_fn=_train_input_fn,
+        eval_input_fn=_eval_input_fn,
+        predict_input_fn=_predict_input_fn,
+        input_dimension=input_dimension,
+        label_dimension=label_dimension,
+        prediction_length=prediction_length)
+
+
+class BaselineRegressorTrainingTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      writer_cache.FileWriterCache.clear()
+      shutil.rmtree(self._model_dir)
+
+  def _mock_optimizer(self, expected_loss=None):
+    expected_var_names = [
+        '%s:0' % BIAS_NAME
+    ]
+
+    def _minimize(loss, global_step=None, var_list=None):
+      trainable_vars = var_list or ops.get_collection(
+          ops.GraphKeys.TRAINABLE_VARIABLES)
+      self.assertItemsEqual(expected_var_names,
+                            [var.name for var in trainable_vars])
+
+      # Verify loss. We can't check the value directly, so we add an assert op.
+      self.assertEqual(0, loss.shape.ndims)
+      if expected_loss is None:
+        if global_step is not None:
+          return state_ops.assign_add(global_step, 1).op
+        return control_flow_ops.no_op()
+      assert_loss = assert_close(
+          math_ops.to_float(expected_loss, name='expected'),
+          loss,
+          name='assert_loss')
+      with ops.control_dependencies((assert_loss,)):
+        if global_step is not None:
+          return state_ops.assign_add(global_step, 1).op
+        return control_flow_ops.no_op()
+
+    mock_optimizer = test.mock.NonCallableMock(
+        spec=optimizer.Optimizer,
+        wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+    mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+    # So, return mock_optimizer itself for deepcopy.
+    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
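+    # When the estimator trains with this mock, it calls
+    # minimize(loss, global_step); _minimize asserts the expected loss and
+    # only increments the global step, leaving the bias variable untouched.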
+    return mock_optimizer
+
+  def _assert_checkpoint(self,
+                         label_dimension,
+                         expected_global_step,
+                         expected_bias=None):
+    shapes = {
+        name: shape
+        for (name, shape) in checkpoint_utils.list_variables(self._model_dir)
+    }
+
+    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+    self.assertEqual(expected_global_step,
+                     checkpoint_utils.load_variable(self._model_dir,
+                                                    ops.GraphKeys.GLOBAL_STEP))
+
+    self.assertEqual([label_dimension], shapes[BIAS_NAME])
+    if expected_bias is not None:
+      self.assertEqual(expected_bias,
+                       checkpoint_utils.load_variable(self._model_dir,
+                                                      BIAS_NAME))
+
+  def testFromScratchWithDefaultOptimizer(self):
+    # Create BaselineRegressor.
+    label = 5.
+    age = 17
+    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)
+
+    # Train for a few steps, and validate final checkpoint.
+    num_steps = 10
+    baseline_regressor.train(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+    self._assert_checkpoint(label_dimension=1, expected_global_step=num_steps)
+
+  def testTrainWithOneDimLabel(self):
+    label_dimension = 1
+    batch_size = 20
+    est = _baseline_regressor_fn(
+        label_dimension=label_dimension,
+        model_dir=self._model_dir)
+    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)
+    self.assertEqual((batch_size,), data_rank_1.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1},
+        y=data_rank_1,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assert_checkpoint(label_dimension=1, expected_global_step=200)
+
+  def testTrainWithOneDimWeight(self):
+    label_dimension = 1
+    batch_size = 20
+    est = _baseline_regressor_fn(
+        label_dimension=label_dimension,
+        weight_column='w',
+        model_dir=self._model_dir)
+
+    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)
+    self.assertEqual((batch_size,), data_rank_1.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1,
+           'w': data_rank_1},
+        y=data_rank_1,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assert_checkpoint(label_dimension=1, expected_global_step=200)
+
+  def testFromScratch(self):
+    # Create BaselineRegressor.
+    label = 5.
+    age = 17
+    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.
+    mock_optimizer = self._mock_optimizer(expected_loss=25.)
+    baseline_regressor = _baseline_regressor_fn(
+        model_dir=self._model_dir,
+        optimizer=mock_optimizer)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    baseline_regressor.train(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+    self._assert_checkpoint(
+        label_dimension=1,
+        expected_global_step=num_steps,
+        expected_bias=[0.])
+
+  def testFromCheckpoint(self):
+    # Create initial checkpoint.
+    bias = 7.0
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable([bias], name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step,
+          name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    # logits = bias = 7.
+    # loss = (logits - label)^2 = (7 - 5)^2 = 4
+    mock_optimizer = self._mock_optimizer(expected_loss=4.)
+    baseline_regressor = _baseline_regressor_fn(
+        model_dir=self._model_dir,
+        optimizer=mock_optimizer)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    baseline_regressor.train(
+        input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+    self._assert_checkpoint(
+        label_dimension=1,
+        expected_global_step=initial_global_step + num_steps,
+        expected_bias=[bias])
+
+  def testFromCheckpointMultiBatch(self):
+    # Create initial checkpoint.
+    bias = 5.0
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable([bias], name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step,
+          name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    # logits = bias
+    # logits[0] = 5.
+    # logits[1] = 5.
+    # loss = sum((logits - label)^2) = (5 - 5)^2 + (5 - 3)^2 = 4
+    mock_optimizer = self._mock_optimizer(expected_loss=4.)
+    baseline_regressor = _baseline_regressor_fn(
+        model_dir=self._model_dir,
+        optimizer=mock_optimizer)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    baseline_regressor.train(
+        input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))),
+        steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+    self._assert_checkpoint(
+        label_dimension=1,
+        expected_global_step=initial_global_step + num_steps,
+        expected_bias=bias)
+
+
+# Tests for Baseline Classifier.
+
+
+class BaselineClassifierTrainingTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      shutil.rmtree(self._model_dir)
+
+  def _mock_optimizer(self, expected_loss=None):
+    expected_var_names = [
+        '%s:0' % BIAS_NAME
+    ]
+
+    def _minimize(loss, global_step):
+      trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+      self.assertItemsEqual(
+          expected_var_names,
+          [var.name for var in trainable_vars])
+
+      # Verify loss. We can't check the value directly, so we add an assert op.
+      self.assertEqual(0, loss.shape.ndims)
+      if expected_loss is None:
+        return state_ops.assign_add(global_step, 1).op
+      assert_loss = assert_close(
+          math_ops.to_float(expected_loss, name='expected'),
+          loss,
+          name='assert_loss')
+      with ops.control_dependencies((assert_loss,)):
+        return state_ops.assign_add(global_step, 1).op
+
+    mock_optimizer = test.mock.NonCallableMock(
+        spec=optimizer.Optimizer,
+        wraps=optimizer.Optimizer(use_locking=False, name='my_optimizer'))
+    mock_optimizer.minimize = test.mock.MagicMock(wraps=_minimize)
+
+    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.
+    # So, return mock_optimizer itself for deepcopy.
+    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer
+    return mock_optimizer
+
+  def _assert_checkpoint(
+      self, n_classes, expected_global_step, expected_bias=None):
+    logits_dimension = n_classes if n_classes > 2 else 1
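+    # Binary heads use a single sigmoid logit, so the bias checkpoint has
+    # shape [1]; multi-class heads store one logit per class.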
+
+    shapes = {
+        name: shape for (name, shape) in
+        checkpoint_utils.list_variables(self._model_dir)
+    }
+
+    self.assertEqual([], shapes[ops.GraphKeys.GLOBAL_STEP])
+    self.assertEqual(
+        expected_global_step,
+        checkpoint_utils.load_variable(
+            self._model_dir, ops.GraphKeys.GLOBAL_STEP))
+
+    self.assertEqual([logits_dimension], shapes[BIAS_NAME])
+    if expected_bias is not None:
+      self.assertAllEqual(expected_bias,
+                          checkpoint_utils.load_variable(
+                              self._model_dir, BIAS_NAME))
+
+  def _testFromScratchWithDefaultOptimizer(self, n_classes):
+    label = 0
+    age = 17
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    # Train for a few steps, and validate final checkpoint.
+    num_steps = 10
+    est.train(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+    self._assert_checkpoint(n_classes, num_steps)
+
+  def testBinaryClassesFromScratchWithDefaultOptimizer(self):
+    self._testFromScratchWithDefaultOptimizer(n_classes=2)
+
+  def testMultiClassesFromScratchWithDefaultOptimizer(self):
+    self._testFromScratchWithDefaultOptimizer(n_classes=4)
+
+  def _testTrainWithTwoDimsLabel(self, n_classes):
+    batch_size = 20
+
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+    data_rank_1 = np.array([0, 1])
+    data_rank_2 = np.array([[0], [1]])
+    self.assertEqual((2,), data_rank_1.shape)
+    self.assertEqual((2, 1), data_rank_2.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1},
+        y=data_rank_2,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assert_checkpoint(n_classes, 200)
+
+  def testBinaryClassesTrainWithTwoDimsLabel(self):
+    self._testTrainWithTwoDimsLabel(n_classes=2)
+
+  def testMultiClassesTrainWithTwoDimsLabel(self):
+    self._testTrainWithTwoDimsLabel(n_classes=4)
+
+  def _testTrainWithOneDimLabel(self, n_classes):
+    batch_size = 20
+
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+    data_rank_1 = np.array([0, 1])
+    self.assertEqual((2,), data_rank_1.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1},
+        y=data_rank_1,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assert_checkpoint(n_classes, 200)
+
+  def testBinaryClassesTrainWithOneDimLabel(self):
+    self._testTrainWithOneDimLabel(n_classes=2)
+
+  def testMultiClassesTrainWithOneDimLabel(self):
+    self._testTrainWithOneDimLabel(n_classes=4)
+
+  def _testTrainWithTwoDimsWeight(self, n_classes):
+    batch_size = 20
+
+    est = baseline.BaselineClassifier(
+        weight_column='w',
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+    data_rank_1 = np.array([0, 1])
+    data_rank_2 = np.array([[0], [1]])
+    self.assertEqual((2,), data_rank_1.shape)
+    self.assertEqual((2, 1), data_rank_2.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1, 'w': data_rank_2}, y=data_rank_1,
+        batch_size=batch_size, num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assert_checkpoint(n_classes, 200)
+
+  def testBinaryClassesTrainWithTwoDimsWeight(self):
+    self._testTrainWithTwoDimsWeight(n_classes=2)
+
+  def testMultiClassesTrainWithTwoDimsWeight(self):
+    self._testTrainWithTwoDimsWeight(n_classes=4)
+
+  def _testTrainWithOneDimWeight(self, n_classes):
+    batch_size = 20
+
+    est = baseline.BaselineClassifier(
+        weight_column='w',
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+    data_rank_1 = np.array([0, 1])
+    self.assertEqual((2,), data_rank_1.shape)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'age': data_rank_1, 'w': data_rank_1}, y=data_rank_1,
+        batch_size=batch_size, num_epochs=None,
+        shuffle=True)
+    est.train(train_input_fn, steps=200)
+    self._assert_checkpoint(n_classes, 200)
+
+  def testBinaryClassesTrainWithOneDimWeight(self):
+    self._testTrainWithOneDimWeight(n_classes=2)
+
+  def testMultiClassesTrainWithOneDimWeight(self):
+    self._testTrainWithOneDimWeight(n_classes=4)
+
+  def _testFromScratch(self, n_classes):
+    label = 1
+    age = 17
+    # For binary classifier:
+    #   loss = sigmoid_cross_entropy(logits, label) where logits=0 (the bias
+    #   is all zeros initially) and label = 1 so,
+    #      loss = 1 * -log ( sigmoid(logits) ) = 0.69315
+    # For multi class classifier:
+    #   loss = cross_entropy(logits, label) where logits are all 0s (the bias
+    #   is all zeros initially) and label = 1 so,
+    #      loss = 1 * -log ( 1.0 / n_classes )
+    # For this particular test case, as the logits are identical, the formula
+    # 1 * -log ( 1.0 / n_classes ) covers both the binary and multi class cases.
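+    # For example, n_classes=2 gives -log(1/2) = 0.69315 and n_classes=4
+    # gives -log(1/4) = log(4) ~= 1.3863.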
+    mock_optimizer = self._mock_optimizer(
+        expected_loss=-1 * math.log(1.0/n_classes))
+
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        optimizer=mock_optimizer,
+        model_dir=self._model_dir)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    est.train(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+    self._assert_checkpoint(
+        n_classes,
+        expected_global_step=num_steps,
+        expected_bias=[0.] if n_classes == 2 else [0.] * n_classes)
+
+  def testBinaryClassesFromScratch(self):
+    self._testFromScratch(n_classes=2)
+
+  def testMultiClassesFromScratch(self):
+    self._testFromScratch(n_classes=4)
+
+  def _testFromCheckpoint(self, n_classes):
+    # Create initial checkpoint.
+    label = 1
+    age = 17
+    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    # For binary classifier:
+    #   logits = bias = -1.
+    #   loss = sigmoid_cross_entropy(logits, label)
+    #   so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133
+    # For multi class classifier:
+    #   loss = cross_entropy(logits, label)
+    #   where logits = bias and label = 1
+    #   so, loss = 1 * -log ( softmax(logits)[1] )
+    if n_classes == 2:
+      expected_loss = 1.3133
+    else:
+      logits = bias
+      logits_exp = np.exp(logits)
+      softmax = logits_exp / logits_exp.sum()
+      expected_loss = -1 * math.log(softmax[label])
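+      # Since every logit equals -1, the softmax is uniform, so this
+      # simplifies to log(n_classes) ~= 1.3863 for n_classes=4.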
+
+    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
+
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        optimizer=mock_optimizer,
+        model_dir=self._model_dir)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    est.train(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+    self._assert_checkpoint(
+        n_classes,
+        expected_global_step=initial_global_step + num_steps,
+        expected_bias=bias)
+
+  def testBinaryClassesFromCheckpoint(self):
+    self._testFromCheckpoint(n_classes=2)
+
+  def testMultiClassesFromCheckpoint(self):
+    self._testFromCheckpoint(n_classes=4)
+
+  def _testFromCheckpointFloatLabels(self, n_classes):
+    """Tests float labels for binary classification."""
+    # Create initial checkpoint.
+    if n_classes > 2:
+      return
+    label = 0.8
+    age = 17
+    bias = [-1.0]
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    # logits = bias = -1.
+    # loss = sigmoid_cross_entropy(logits, label)
+    # => loss = -0.8 * log(sigmoid(-1)) - 0.2 * log(sigmoid(+1)) = 1.1132617
+    mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)
+
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        optimizer=mock_optimizer,
+        model_dir=self._model_dir)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    est.train(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+
+  def testBinaryClassesFromCheckpointFloatLabels(self):
+    self._testFromCheckpointFloatLabels(n_classes=2)
+
+  def testMultiClassesFromCheckpointFloatLabels(self):
+    self._testFromCheckpointFloatLabels(n_classes=4)
+
+  def _testFromCheckpointMultiBatch(self, n_classes):
+    # Create initial checkpoint.
+    label = [1, 0]
+    age = [17, 18.5]
+    # For the binary case, the bias has shape (1,); for the multi class case,
+    # it has shape (n_classes,). Initialize every entry to -1.0 so that
+    # logits = bias = -1 for every example below.
+    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    # For binary classifier:
+    #   logits = bias
+    #   logits[0] = -1.
+    #   logits[1] = -1.
+    #   loss = sigmoid_cross_entropy(logits, label)
+    #   so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133
+    #       loss[1] = (1 - 0) * -log ( 1 - sigmoid(-1) ) = 0.3132
+    # For multi class classifier:
+    #   loss = cross_entropy(logits, label)
+    #   where logits = bias and label = [1, 0]
+    #   so, loss = 1 * -log ( softmax(logits)[label] )
+    if n_classes == 2:
+      expected_loss = (1.3133 + 0.3132)
+    else:
+      # Expand logits since batch_size=2
+      logits = bias * np.ones(shape=(2, 1))
+      logits_exp = np.exp(logits)
+      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+      expected_loss = expected_loss_0 + expected_loss_1
+
+    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
+
+    est = baseline.BaselineClassifier(
+        n_classes=n_classes,
+        optimizer=mock_optimizer,
+        model_dir=self._model_dir)
+    self.assertEqual(0, mock_optimizer.minimize.call_count)
+
+    # Train for a few steps, and validate optimizer and final checkpoint.
+    num_steps = 10
+    est.train(
+        input_fn=lambda: ({'age': (age)}, (label)),
+        steps=num_steps)
+    self.assertEqual(1, mock_optimizer.minimize.call_count)
+    self._assert_checkpoint(
+        n_classes,
+        expected_global_step=initial_global_step + num_steps,
+        expected_bias=bias)
+
+  def testBinaryClassesFromCheckpointMultiBatch(self):
+    self._testFromCheckpointMultiBatch(n_classes=2)
+
+  def testMultiClassesFromCheckpointMultiBatch(self):
+    self._testFromCheckpointMultiBatch(n_classes=4)
+
+
+class BaselineClassifierEvaluationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      shutil.rmtree(self._model_dir)
+
+  def _test_evaluation_for_simple_data(self, n_classes):
+    label = 1
+    age = 1.
+
+    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(
+          100, name=ops.GraphKeys.GLOBAL_STEP, dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    est = _baseline_classifier_fn(
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+    eval_metrics = est.evaluate(
+        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=1)
+
+    if n_classes == 2:
+      # Binary classes: loss = -log(sigmoid(-1)) = 1.3133
+      # Prediction = sigmoid(-1) = 0.2689
+      expected_metrics = {
+          metric_keys.MetricKeys.LOSS: 1.3133,
+          ops.GraphKeys.GLOBAL_STEP: 100,
+          metric_keys.MetricKeys.LOSS_MEAN: 1.3133,
+          metric_keys.MetricKeys.ACCURACY: 0.,
+          metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,
+          metric_keys.MetricKeys.LABEL_MEAN: 1.,
+          metric_keys.MetricKeys.ACCURACY_BASELINE: 1,
+          metric_keys.MetricKeys.AUC: 0.,
+          metric_keys.MetricKeys.AUC_PR: 1.,
+      }
+    else:
+      # Multi classes: loss = 1 * -log ( softmax(logits)[label] )
+      logits = bias
+      logits_exp = np.exp(logits)
+      softmax = logits_exp / logits_exp.sum()
+      expected_loss = -1 * math.log(softmax[label])
+
+      expected_metrics = {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          ops.GraphKeys.GLOBAL_STEP: 100,
+          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,
+          metric_keys.MetricKeys.ACCURACY: 0.,
+      }
+
+    self.assertAllClose(sorted_key_dict(expected_metrics),
+                        sorted_key_dict(eval_metrics), rtol=1e-3)
+
+  def test_binary_classes_evaluation_for_simple_data(self):
+    self._test_evaluation_for_simple_data(n_classes=2)
+
+  def test_multi_classes_evaluation_for_simple_data(self):
+    self._test_evaluation_for_simple_data(n_classes=4)
+
+  def _test_evaluation_batch(self, n_classes):
+    """Tests evaluation for batch_size==2."""
+    label = [1, 0]
+    age = [17., 18.]
+    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    est = _baseline_classifier_fn(
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+    eval_metrics = est.evaluate(
+        input_fn=lambda: ({'age': (age)}, (label)), steps=1)
+
+    if n_classes == 2:
+      # Logits are (-1., -1.), labels are (1, 0).
+      # Loss is
+      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133
+      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132
+      # Prediction = sigmoid(-1) = 0.2689
+      expected_loss = 1.3133 + 0.3132
+
+      expected_metrics = {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          ops.GraphKeys.GLOBAL_STEP: 100,
+          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+          metric_keys.MetricKeys.ACCURACY: 0.5,
+          metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,
+          metric_keys.MetricKeys.LABEL_MEAN: 0.5,
+          metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+          metric_keys.MetricKeys.AUC: 0.5,
+          metric_keys.MetricKeys.AUC_PR: 0.75,
+      }
+    else:
+      # Expand logits since batch_size=2
+      logits = bias * np.ones(shape=(2, 1))
+      logits_exp = np.exp(logits)
+      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+      expected_loss = expected_loss_0 + expected_loss_1
+
+      expected_metrics = {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          ops.GraphKeys.GLOBAL_STEP: 100,
+          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
+          metric_keys.MetricKeys.ACCURACY: 0.5,
+      }
+
+    self.assertAllClose(sorted_key_dict(expected_metrics),
+                        sorted_key_dict(eval_metrics), rtol=1e-3)
+
+  def test_binary_classes_evaluation_batch(self):
+    self._test_evaluation_batch(n_classes=2)
+
+  def test_multi_classes_evaluation_batch(self):
+    self._test_evaluation_batch(n_classes=4)
+
+  def _test_evaluation_weights(self, n_classes):
+    """Tests evaluation with weights."""
+
+    label = [1, 0]
+    age = [17., 18.]
+    weights = [1., 2.]
+    # For the binary case, the bias has shape (1,); for the multi class case,
+    # it has shape (n_classes,). Initialize every entry to -1.0 so that
+    # logits = bias = -1 for every example.
+    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes
+    initial_global_step = 100
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(
+          initial_global_step, name=ops.GraphKeys.GLOBAL_STEP,
+          dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    est = _baseline_classifier_fn(
+        n_classes=n_classes,
+        weight_column='w',
+        model_dir=self._model_dir)
+    eval_metrics = est.evaluate(
+        input_fn=lambda: ({'age': (age), 'w': (weights)}, (label)), steps=1)
+
+    if n_classes == 2:
+      # Logits are (-1., -1.), labels are (1, 0).
+      # Loss is
+      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133
+      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132
+      #   weights = [1., 2.]
+      expected_loss = 1.3133 * 1. + 0.3132 * 2.
+      loss_mean = expected_loss / (1.0 + 2.0)
+      label_mean = np.average(label, weights=weights)
+      logits = [-1, -1]
+      logistics = sigmoid(np.array(logits))
+      predictions_mean = np.average(logistics, weights=weights)
+
+      expected_metrics = {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          ops.GraphKeys.GLOBAL_STEP: 100,
+          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
+          metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.),
+          metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,
+          metric_keys.MetricKeys.LABEL_MEAN: label_mean,
+          metric_keys.MetricKeys.ACCURACY_BASELINE: (
+              max(label_mean, 1-label_mean)),
+          metric_keys.MetricKeys.AUC: 0.5,
+          metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.),
+      }
+    else:
+      # Multi classes: unweighted_loss = 1 * -log ( softmax(logits)[label] )
+      # Expand logits since batch_size=2
+      logits = bias * np.ones(shape=(2, 1))
+      logits_exp = np.exp(logits)
+      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()
+      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()
+      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])
+      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])
+      loss_mean = np.average([expected_loss_0, expected_loss_1],
+                             weights=weights)
+      expected_loss = loss_mean * np.sum(weights)
+
+      expected_metrics = {
+          metric_keys.MetricKeys.LOSS: expected_loss,
+          ops.GraphKeys.GLOBAL_STEP: 100,
+          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,
+          metric_keys.MetricKeys.ACCURACY: 2. / (1. + 2.),
+      }
+
+    self.assertAllClose(sorted_key_dict(expected_metrics),
+                        sorted_key_dict(eval_metrics), rtol=1e-3)
+
+  def test_binary_classes_evaluation_weights(self):
+    self._test_evaluation_weights(n_classes=2)
+
+  def test_multi_classes_evaluation_weights(self):
+    self._test_evaluation_weights(n_classes=4)
+
+
+class BaselineClassifierPredictTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      shutil.rmtree(self._model_dir)
+
+  def _testPredictions(self, n_classes, label_vocabulary, label_output_fn):
+    """Tests predict when all variables are one-dimensional."""
+    age = 1.
+
+    bias = [10.0] if n_classes == 2 else [10.0] * n_classes
+
+    with ops.Graph().as_default():
+      variables.Variable(bias, name=BIAS_NAME)
+      variables.Variable(100, name='global_step', dtype=dtypes.int64)
+      save_variables_to_ckpt(self._model_dir)
+
+    est = _baseline_classifier_fn(
+        label_vocabulary=label_vocabulary,
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'age': np.array([[age]])},
+        y=None,
+        batch_size=1,
+        num_epochs=1,
+        shuffle=False)
+    predictions = list(est.predict(input_fn=predict_input_fn))
+
+    if n_classes == 2:
+      scalar_logits = bias[0]
+      two_classes_logits = [0, scalar_logits]
+      two_classes_logits_exp = np.exp(two_classes_logits)
+      softmax = two_classes_logits_exp / two_classes_logits_exp.sum()
+
+      expected_predictions = {
+          'class_ids': [1],
+          'classes': [label_output_fn(1)],
+          'logistic': [sigmoid(np.array(scalar_logits))],
+          'logits': [scalar_logits],
+          'probabilities': softmax,
+      }
+    else:
+      onedim_logits = np.array(bias)
+      class_ids = onedim_logits.argmax()
+      logits_exp = np.exp(onedim_logits)
+      softmax = logits_exp / logits_exp.sum()
+      expected_predictions = {
+          'class_ids': [class_ids],
+          'classes': [label_output_fn(class_ids)],
+          'logits': onedim_logits,
+          'probabilities': softmax,
+      }
+
+    self.assertEqual(1, len(predictions))
+    # assertAllClose cannot handle byte type.
+    self.assertEqual(expected_predictions['classes'], predictions[0]['classes'])
+    expected_predictions.pop('classes')
+    predictions[0].pop('classes')
+    self.assertAllClose(sorted_key_dict(expected_predictions),
+                        sorted_key_dict(predictions[0]))
+
+  def testBinaryClassesWithoutLabelVocabulary(self):
+    n_classes = 2
+    self._testPredictions(n_classes,
+                          label_vocabulary=None,
+                          label_output_fn=lambda x: ('%s' % x).encode())
+
+  def testBinaryClassesWithLabelVocabulary(self):
+    n_classes = 2
+    self._testPredictions(
+        n_classes,
+        label_vocabulary=['class_vocab_{}'.format(i)
+                          for i in range(n_classes)],
+        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
+
+  def testMultiClassesWithoutLabelVocabulary(self):
+    n_classes = 4
+    self._testPredictions(
+        n_classes,
+        label_vocabulary=None,
+        label_output_fn=lambda x: ('%s' % x).encode())
+
+  def testMultiClassesWithLabelVocabulary(self):
+    n_classes = 4
+    self._testPredictions(
+        n_classes,
+        label_vocabulary=['class_vocab_{}'.format(i)
+                          for i in range(n_classes)],
+        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())
+
+
+class BaselineClassifierIntegrationTest(test.TestCase):
+
+  def setUp(self):
+    self._model_dir = tempfile.mkdtemp()
+
+  def tearDown(self):
+    if self._model_dir:
+      shutil.rmtree(self._model_dir)
+
+  def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,
+                          predict_input_fn, input_dimension, prediction_length):
+    feature_columns = [
+        feature_column_lib.numeric_column('x', shape=(input_dimension,))
+    ]
+    est = _baseline_classifier_fn(
+        n_classes=n_classes,
+        model_dir=self._model_dir)
+
+    # TRAIN
+    # learn y = x
+    est.train(train_input_fn, steps=200)
+
+    # EVALUATE
+    scores = est.evaluate(eval_input_fn)
+    self.assertEqual(200, scores[ops.GraphKeys.GLOBAL_STEP])
+    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))
+
+    # PREDICT
+    predictions = np.array(
+        [x['classes'] for x in est.predict(predict_input_fn)])
+    self.assertAllEqual((prediction_length, 1), predictions.shape)
+
+    # EXPORT
+    feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+        feature_spec)
+    export_dir = est.export_savedmodel(tempfile.mkdtemp(),
+                                       serving_input_receiver_fn)
+    self.assertTrue(gfile.Exists(export_dir))
+
+  def _test_numpy_input_fn(self, n_classes):
+    """Tests complete flow with numpy_input_fn."""
+    input_dimension = 4
+    batch_size = 10
+    prediction_length = batch_size
+    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, input_dimension)
+    target = np.array([1] * batch_size)
+
+    train_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=target,
+        batch_size=batch_size,
+        num_epochs=None,
+        shuffle=True)
+    eval_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=target,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x={'x': data},
+        y=None,
+        batch_size=batch_size,
+        num_epochs=1,
+        shuffle=False)
+
+    self._test_complete_flow(
+        n_classes=n_classes,
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        prediction_length=prediction_length)
+
+  def test_binary_classes_numpy_input_fn(self):
+    self._test_numpy_input_fn(n_classes=2)
+
+  def test_multi_classes_numpy_input_fn(self):
+    self._test_numpy_input_fn(n_classes=4)
+
+  def _test_pandas_input_fn(self, n_classes):
+    """Tests complete flow with pandas_input_fn."""
+    if not HAS_PANDAS:
+      return
+
+    # Pandas DataFrame naturally supports only 1-dim data.
+    input_dimension = 1
+    batch_size = 10
+    data = np.array([1., 2., 3., 4.], dtype=np.float32)
+    target = np.array([1, 0, 1, 0], dtype=np.int32)
+    x = pd.DataFrame({'x': data})
+    y = pd.Series(target)
+    prediction_length = 4
+
+    train_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
+    eval_input_fn = pandas_io.pandas_input_fn(
+        x=x, y=y, batch_size=batch_size, shuffle=False)
+    predict_input_fn = pandas_io.pandas_input_fn(
+        x=x, batch_size=batch_size, shuffle=False)
+
+    self._test_complete_flow(
+        n_classes=n_classes,
+        train_input_fn=train_input_fn,
+        eval_input_fn=eval_input_fn,
+        predict_input_fn=predict_input_fn,
+        input_dimension=input_dimension,
+        prediction_length=prediction_length)
+
+  def test_binary_classes_pandas_input_fn(self):
+    self._test_pandas_input_fn(n_classes=2)
+
+  def test_multi_classes_pandas_input_fn(self):
+    self._test_pandas_input_fn(n_classes=4)
+
+  def _test_input_fn_from_parse_example(self, n_classes):
+    """Tests complete flow with input_fn constructed from parse_example."""
+    input_dimension = 2
+    batch_size = 10
+    prediction_length = batch_size
+    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)
+    data = data.reshape(batch_size, input_dimension)
+    target = np.array([1] * batch_size, dtype=np.int64)
+
+    serialized_examples = []
+    for x, y in zip(data, target):
+      example = example_pb2.Example(features=feature_pb2.Features(
+          feature={
+              'x':
+                  feature_pb2.Feature(float_list=feature_pb2.FloatList(
+                      value=x)),
+              'y':
+                  feature_pb2.Feature(int64_list=feature_pb2.Int64List(
+                      value=[y])),
+          }))
+      serialized_examples.append(example.SerializeToString())
+
+    feature_spec = {
+        'x': parsing_ops.FixedLenFeature([input_dimension], dtypes.float32),
+        'y': parsing_ops.FixedLenFeature([1], dtypes.int64),
+    }
+
+    def _train_input_fn():
+      feature_map = parsing_ops.parse_example(serialized_examples, feature_spec)
+      features = queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _eval_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = queue_parsed_features(feature_map)
+      labels = features.pop('y')
+      return features, labels
+
+    def _predict_input_fn():
+      feature_map = parsing_ops.parse_example(
+          input_lib.limit_epochs(serialized_examples, num_epochs=1),
+          feature_spec)
+      features = queue_parsed_features(feature_map)
+      features.pop('y')
+      return features, None
+
+    self._test_complete_flow(
+        n_classes=n_classes,
+        train_input_fn=_train_input_fn,
+        eval_input_fn=_eval_input_fn,
+        predict_input_fn=_predict_input_fn,
+        input_dimension=input_dimension,
+        prediction_length=prediction_length)
+
+  def test_binary_classes_input_fn_from_parse_example(self):
+    self._test_input_fn_from_parse_example(n_classes=2)
+
+  def test_multi_classes_input_fn_from_parse_example(self):
+    self._test_input_fn_from_parse_example(n_classes=4)
+
+
+# Tests for Baseline logit_fn.
+
+
+class BaselineLogitFnTest(test.TestCase):
+
+  def test_basic_logit_correctness(self):
+    """baseline_logit_fn simply returns the bias variable."""
+    with ops.Graph().as_default():
+      logit_fn = baseline._baseline_logit_fn_builder(num_outputs=2)
+      logits = logit_fn(features={'age': [[23.], [31.]]})
+      with variable_scope.variable_scope('baseline', reuse=True):
+        bias_var = variable_scope.get_variable('bias')
+      with tf_session.Session() as sess:
+        sess.run([variables.global_variables_initializer()])
+        self.assertAllClose([[0., 0.], [0., 0.]], logits.eval())
+        sess.run(bias_var.assign([10., 5.]))
+        self.assertAllClose([[10., 5.], [10., 5.]], logits.eval())
+
+
+if __name__ == '__main__':
+  test.main()
+
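For orientation, a minimal end-to-end sketch of the estimator these tests cover. Baseline ignores feature values and learns only a per-class bias, so it converges to the majority class; the data values here are illustrative:

```python
import numpy as np
import tensorflow as tf

# Features are ignored by BaselineClassifier; only the labels matter.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'x': np.zeros((4, 1), dtype=np.float32)},
    y=np.array([1, 1, 1, 0]),
    batch_size=4, num_epochs=None, shuffle=True)

classifier = tf.estimator.BaselineClassifier(n_classes=2)
classifier.train(train_input_fn, steps=100)
metrics = classifier.evaluate(train_input_fn, steps=1)
print(metrics['accuracy'])  # expected to approach 0.75 (the majority class)
```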
diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py
index 5b82fd75ff3..bed2b674192 100644
--- a/tensorflow/python/estimator/estimator_lib.py
+++ b/tensorflow/python/estimator/estimator_lib.py
@@ -19,6 +19,8 @@ from __future__ import division
 from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long,wildcard-import
+from tensorflow.python.estimator.canned.baseline import BaselineClassifier
+from tensorflow.python.estimator.canned.baseline import BaselineRegressor
 from tensorflow.python.estimator.canned.dnn import DNNClassifier
 from tensorflow.python.estimator.canned.dnn import DNNRegressor
 from tensorflow.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier
@@ -46,6 +48,8 @@ from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = [
     # Canned Estimators
+    'BaselineClassifier',
+    'BaselineRegressor',
     'DNNClassifier',
     'DNNRegressor',
     'DNNLinearCombinedClassifier',
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
new file mode 100644
index 00000000000..f5ed263f0e2
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BaselineClassifier"
+tf_class {
+  is_instance: "<class \'tensorflow.python.estimator.canned.baseline.BaselineClassifier\'>"
+  is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "config"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_dir"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_fn"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "params"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'2\', \'None\', \'None\', \'Ftrl\', \'None\'], "
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "export_savedmodel"
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "get_variable_names"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_variable_value"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "latest_checkpoint"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "train"
+    argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
new file mode 100644
index 00000000000..61a29942c57
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.estimator.BaselineRegressor"
+tf_class {
+  is_instance: "<class \'tensorflow.python.estimator.canned.baseline.BaselineRegressor\'>"
+  is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "config"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_dir"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "model_fn"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "params"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'model_dir\', \'label_dimension\', \'weight_column\', \'optimizer\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'Ftrl\', \'None\'], "
+  }
+  member_method {
+    name: "evaluate"
+    argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "export_savedmodel"
+    argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+  }
+  member_method {
+    name: "get_variable_names"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_variable_value"
+    argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "latest_checkpoint"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "predict"
+    argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "train"
+    argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
index ef93a61bd84..cdc367b99e8 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.pbtxt
@@ -1,5 +1,13 @@
 path: "tensorflow.estimator"
 tf_module {
+  member {
+    name: "BaselineClassifier"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "BaselineRegressor"
+    mtype: "<type \'type\'>"
+  }
   member {
     name: "DNNClassifier"
     mtype: "<type \'type\'>"

From 7db94de969662cfc83b7152d57b23d6c57da0784 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 16:36:22 -0800
Subject: [PATCH 058/115] Update ops-related pbtxt files.

PiperOrigin-RevId: 175085154
---
 .../core/ops/compat/ops_history.v1.pbtxt      | 46 +++++++++++++++++++
 tensorflow/core/ops/ops.pbtxt                 |  1 -
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index a4b5ca16af7..60f67543f1f 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21670,6 +21670,52 @@ op {
   }
   is_stateful: true
 }
+op {
+  name: "Print"
+  input_arg {
+    name: "input"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "data"
+    type_list_attr: "U"
+  }
+  output_arg {
+    name: "output"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+  }
+  attr {
+    name: "U"
+    type: "list(type)"
+    has_minimum: true
+  }
+  attr {
+    name: "message"
+    type: "string"
+    default_value {
+      s: ""
+    }
+  }
+  attr {
+    name: "first_n"
+    type: "int"
+    default_value {
+      i: -1
+    }
+  }
+  attr {
+    name: "summarize"
+    type: "int"
+    default_value {
+      i: 3
+    }
+  }
+  is_stateful: true
+}
 op {
   name: "PriorityQueue"
   output_arg {
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 8353b45e225..2a74c207076 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -16977,7 +16977,6 @@ op {
     name: "U"
     type: "list(type)"
     has_minimum: true
-    minimum: 1
   }
   attr {
     name: "message"

From ecb3557621229deaebec209629d154c37da7f9d3 Mon Sep 17 00:00:00 2001
From: Igor Ganichev <iga@google.com>
Date: Wed, 8 Nov 2017 16:44:37 -0800
Subject: [PATCH 059/115] Make assert_equal/_none_equal/_less ops work in eager
 mode

Also, fix the documentation of the eager-mode execute() method and
make tf_should_use work with the empty list returned by execute().

RELNOTES: tf.assert_equal no longer raises ValueError. It now raises InvalidArgumentError, as documented.
PiperOrigin-RevId: 175086223
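For reference, a minimal sketch of the new eager-mode behavior; the imports mirror the updated tests and the tensor values are illustrative:

```python
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.ops import check_ops

with context.eager_mode():
  x = constant_op.constant([1, 2])
  y = constant_op.constant([1, 3])
  try:
    # The check runs immediately in eager mode and returns None on success.
    check_ops.assert_equal(x, y, message="fail", summarize=2)
  except errors.InvalidArgumentError as e:  # previously raised ValueError
    print(e.message)
```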
---
 tensorflow/python/eager/execute.py            |   3 +-
 tensorflow/python/kernel_tests/BUILD          |   1 +
 .../python/kernel_tests/check_ops_test.py     | 307 ++++++++++++------
 tensorflow/python/ops/check_ops.py            |  79 ++++-
 tensorflow/python/ops/control_flow_ops.py     |  41 ++-
 tensorflow/python/util/tf_should_use.py       |   2 +-
 6 files changed, 314 insertions(+), 119 deletions(-)

diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py
index 983c1ea73e5..c6457232e91 100644
--- a/tensorflow/python/eager/execute.py
+++ b/tensorflow/python/eager/execute.py
@@ -47,8 +47,7 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
     name: Customized name for the operation.
 
   Returns:
-    None if there are no outputs, a single Tensor object if there is one output
-    and a list of Tensor objects if there are multiple outputs.
+    List of output Tensor objects. The list is empty if there are no outputs.
 
   Raises:
     An exception on error.
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 7fa504e85ed..8d6f863a4c0 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1186,6 +1186,7 @@ cuda_py_test(
     srcs = ["check_ops_test.py"],
     additional_deps = [
         "//third_party/py/numpy",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:check_ops",
         "//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index ed859e37741..43785adceec 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -20,10 +20,13 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.platform import test
@@ -71,110 +74,178 @@ class AssertProperIterableTest(test.TestCase):
 
 class AssertEqualTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_equal(self):
-    with self.test_session():
-      small = constant_op.constant([1, 2], name="small")
-      with ops.control_dependencies([check_ops.assert_equal(small, small)]):
-        out = array_ops.identity(small)
-      out.eval()
+    small = constant_op.constant([1, 2], name="small")
+    with ops.control_dependencies([check_ops.assert_equal(small, small)]):
+      out = array_ops.identity(small)
+    self.evaluate(out)
 
+  def test_returns_none_with_eager(self):
+    with context.eager_mode():
+      small = constant_op.constant([1, 2], name="small")
+      x = check_ops.assert_equal(small, small)
+      assert x is None
+
+  @test_util.run_in_graph_and_eager_modes()
   def test_raises_when_greater(self):
-    with self.test_session():
-      # Static check
-      static_small = constant_op.constant([1, 2], name="small")
-      static_big = constant_op.constant([3, 4], name="big")
-      with self.assertRaisesRegexp(ValueError, "fail"):
-        check_ops.assert_equal(static_big, static_small, message="fail")
-      # Dynamic check
-      small = array_ops.placeholder(dtypes.int32, name="small")
-      big = array_ops.placeholder(dtypes.int32, name="big")
-      with ops.control_dependencies(
-          [check_ops.assert_equal(
-              big, small, message="fail")]):
-        out = array_ops.identity(small)
-      with self.assertRaisesOpError("fail.*big.*small"):
-        out.eval(feed_dict={small: [1, 2], big: [3, 4]})
+    # Static check
+    static_small = constant_op.constant([1, 2], name="small")
+    static_big = constant_op.constant([3, 4], name="big")
+    with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
+      check_ops.assert_equal(static_big, static_small, message="fail")
 
+    # Dynamic check
+    if context.in_graph_mode():
+      with self.test_session():
+        small = array_ops.placeholder(dtypes.int32, name="small")
+        big = array_ops.placeholder(dtypes.int32, name="big")
+        with ops.control_dependencies(
+            [check_ops.assert_equal(
+                big, small, message="fail")]):
+          out = array_ops.identity(small)
+        with self.assertRaisesOpError("fail.*big.*small"):
+          out.eval(feed_dict={small: [1, 2], big: [3, 4]})
+
+  def test_error_message_eager(self):
+    expected_error_msg_full = r"""big does not equal small
+Condition x == y did not hold.
+Indices of first 6 different values:
+\[\[0 0\]
+ \[1 1\]
+ \[2 0\]\]
+Corresponding x values:
+\[2 3 6\]
+Corresponding y values:
+\[20 30 60\]
+First 6 elements of x:
+\[2 2 3 3 6 6\]
+First 6 elements of y:
+\[20  2  3 30 60  6\]
+"""
+    expected_error_msg_short = r"""big does not equal small
+Condition x == y did not hold.
+Indices of first 2 different values:
+\[\[0 0\]
+ \[1 1\]\]
+Corresponding x values:
+\[2 3\]
+Corresponding y values:
+\[20 30\]
+First 2 elements of x:
+\[2 2\]
+First 2 elements of y:
+\[20  2\]
+"""
+    with context.eager_mode():
+      big = constant_op.constant([[2, 2], [3, 3], [6, 6]])
+      small = constant_op.constant([[20, 2], [3, 30], [60, 6]])
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   expected_error_msg_full):
+        check_ops.assert_equal(big, small, message="big does not equal small",
+                               summarize=10)
+      with self.assertRaisesRegexp(errors.InvalidArgumentError,
+                                   expected_error_msg_short):
+        check_ops.assert_equal(big, small, message="big does not equal small",
+                               summarize=2)
+
+  @test_util.run_in_graph_and_eager_modes()
   def test_raises_when_less(self):
-    with self.test_session():
-      # Static check
-      static_small = constant_op.constant([3, 1], name="small")
-      static_big = constant_op.constant([4, 2], name="big")
-      with self.assertRaisesRegexp(ValueError, "fail"):
-        check_ops.assert_equal(static_big, static_small, message="fail")
-      # Dynamic check
-      small = array_ops.placeholder(dtypes.int32, name="small")
-      big = array_ops.placeholder(dtypes.int32, name="big")
-      with ops.control_dependencies([check_ops.assert_equal(small, big)]):
-        out = array_ops.identity(small)
-      with self.assertRaisesOpError("small.*big"):
-        out.eval(feed_dict={small: [3, 1], big: [4, 2]})
+    # Static check
+    static_small = constant_op.constant([3, 1], name="small")
+    static_big = constant_op.constant([4, 2], name="big")
+    with self.assertRaisesRegexp(errors.InvalidArgumentError, "fail"):
+      check_ops.assert_equal(static_big, static_small, message="fail")
 
+    # Dynamic check
+    if context.in_graph_mode():
+      with self.test_session():
+        small = array_ops.placeholder(dtypes.int32, name="small")
+        big = array_ops.placeholder(dtypes.int32, name="big")
+        with ops.control_dependencies([check_ops.assert_equal(small, big)]):
+          out = array_ops.identity(small)
+        with self.assertRaisesOpError("small.*big"):
+          out.eval(feed_dict={small: [3, 1], big: [4, 2]})
+
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
-    with self.test_session():
-      small = constant_op.constant([1, 2], name="small")
-      small_2 = constant_op.constant([1, 2], name="small_2")
+    small = constant_op.constant([[1, 2], [1, 2]], name="small")
+    small_2 = constant_op.constant([1, 2], name="small_2")
+    with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
+      out = array_ops.identity(small)
+    self.evaluate(out)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_raises_when_equal_but_non_broadcastable_shapes(self):
+    small = constant_op.constant([1, 1, 1], name="small")
+    small_2 = constant_op.constant([1, 1], name="small_2")
+    # The exception in eager and non-eager mode is different because
+    # eager mode relies on the shape check done as part of the C++ op, while
+    # graph mode does shape checks when creating the `Operation` instance.
+    with self.assertRaisesRegexp(
+        (errors.InvalidArgumentError, ValueError),
+        (r"Incompatible shapes: \[3\] vs. \[2\]|"
+         r"Dimensions must be equal, but are 3 and 2")):
       with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
         out = array_ops.identity(small)
-      out.eval()
-
-  def test_raises_when_equal_but_non_broadcastable_shapes(self):
-    with self.test_session():
-      small = constant_op.constant([1, 1, 1], name="small")
-      small_2 = constant_op.constant([1, 1], name="small_2")
-      with self.assertRaisesRegexp(ValueError, "must be"):
-        with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
-          out = array_ops.identity(small)
-        out.eval()
+      self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_both_empty(self):
-    with self.test_session():
-      larry = constant_op.constant([])
-      curly = constant_op.constant([])
-      with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
-        out = array_ops.identity(larry)
-      out.eval()
+    larry = constant_op.constant([])
+    curly = constant_op.constant([])
+    with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
+      out = array_ops.identity(larry)
+    self.evaluate(out)
 
 
 class AssertNoneEqualTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_not_equal(self):
-    with self.test_session():
-      small = constant_op.constant([1, 2], name="small")
-      big = constant_op.constant([10, 20], name="small")
-      with ops.control_dependencies(
-          [check_ops.assert_none_equal(big, small)]):
-        out = array_ops.identity(small)
-      out.eval()
+    small = constant_op.constant([1, 2], name="small")
+    big = constant_op.constant([10, 20], name="small")
+    with ops.control_dependencies(
+        [check_ops.assert_none_equal(big, small)]):
+      out = array_ops.identity(small)
+    self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_raises_when_equal(self):
-    with self.test_session():
-      small = constant_op.constant([3, 1], name="small")
+    small = constant_op.constant([3, 1], name="small")
+    with self.assertRaisesOpError("x != y did not hold"):
       with ops.control_dependencies(
           [check_ops.assert_none_equal(small, small)]):
         out = array_ops.identity(small)
-      with self.assertRaisesOpError("x != y did not hold"):
-        out.eval()
+      self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
-    with self.test_session():
-      small = constant_op.constant([1, 2], name="small")
-      big = constant_op.constant([3], name="big")
-      with ops.control_dependencies(
-          [check_ops.assert_none_equal(small, big)]):
-        out = array_ops.identity(small)
-      out.eval()
+    small = constant_op.constant([1, 2], name="small")
+    big = constant_op.constant([3], name="big")
+    with ops.control_dependencies(
+        [check_ops.assert_none_equal(small, big)]):
+      out = array_ops.identity(small)
+    self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
     with self.test_session():
       small = constant_op.constant([1, 1, 1], name="small")
       big = constant_op.constant([10, 10], name="big")
-      with self.assertRaisesRegexp(ValueError, "must be"):
+      # The exception in eager and non-eager mode is different because
+      # eager mode relies on the shape check done as part of the C++ op, while
+      # graph mode does shape checks when creating the `Operation` instance.
+      with self.assertRaisesRegexp(
+          (ValueError, errors.InvalidArgumentError),
+          (r"Incompatible shapes: \[3\] vs. \[2\]|"
+           r"Dimensions must be equal, but are 3 and 2")):
         with ops.control_dependencies(
             [check_ops.assert_none_equal(small, big)]):
           out = array_ops.identity(small)
-        out.eval()
+        self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_both_empty(self):
     with self.test_session():
       larry = constant_op.constant([])
@@ -182,62 +253,82 @@ class AssertNoneEqualTest(test.TestCase):
       with ops.control_dependencies(
           [check_ops.assert_none_equal(larry, curly)]):
         out = array_ops.identity(larry)
-      out.eval()
+      self.evaluate(out)
+
+  def test_returns_none_with_eager(self):
+    with context.eager_mode():
+      t1 = constant_op.constant([1, 2])
+      t2 = constant_op.constant([3, 4])
+      x = check_ops.assert_none_equal(t1, t2)
+      assert x is None
 
 
 class AssertLessTest(test.TestCase):
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_raises_when_equal(self):
-    with self.test_session():
-      small = constant_op.constant([1, 2], name="small")
+    small = constant_op.constant([1, 2], name="small")
+    with self.assertRaisesOpError("failure message.*\n*.* x < y did not hold"):
       with ops.control_dependencies(
           [check_ops.assert_less(
-              small, small, message="fail")]):
+              small, small, message="failure message")]):
         out = array_ops.identity(small)
-      with self.assertRaisesOpError("fail.*small.*small"):
-        out.eval()
+      self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_raises_when_greater(self):
-    with self.test_session():
-      small = constant_op.constant([1, 2], name="small")
-      big = constant_op.constant([3, 4], name="big")
+    small = constant_op.constant([1, 2], name="small")
+    big = constant_op.constant([3, 4], name="big")
+    with self.assertRaisesOpError("x < y did not hold"):
       with ops.control_dependencies([check_ops.assert_less(big, small)]):
         out = array_ops.identity(small)
-      with self.assertRaisesOpError("big.*small"):
-        out.eval()
+      self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_less(self):
-    with self.test_session():
-      small = constant_op.constant([3, 1], name="small")
-      big = constant_op.constant([4, 2], name="big")
-      with ops.control_dependencies([check_ops.assert_less(small, big)]):
-        out = array_ops.identity(small)
-      out.eval()
+    small = constant_op.constant([3, 1], name="small")
+    big = constant_op.constant([4, 2], name="big")
+    with ops.control_dependencies([check_ops.assert_less(small, big)]):
+      out = array_ops.identity(small)
+    self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
-    with self.test_session():
-      small = constant_op.constant([1], name="small")
-      big = constant_op.constant([3, 2], name="big")
+    small = constant_op.constant([1], name="small")
+    big = constant_op.constant([3, 2], name="big")
+    with ops.control_dependencies([check_ops.assert_less(small, big)]):
+      out = array_ops.identity(small)
+    self.evaluate(out)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def test_raises_when_less_but_non_broadcastable_shapes(self):
+    small = constant_op.constant([1, 1, 1], name="small")
+    big = constant_op.constant([3, 2], name="big")
+    # The exception in eager and non-eager mode is different because
+    # eager mode relies on the shape check done as part of the C++ op, while
+    # graph mode does shape checks when creating the `Operation` instance.
+    with self.assertRaisesRegexp(
+        (ValueError, errors.InvalidArgumentError),
+        (r"Incompatible shapes: \[3\] vs. \[2\]|"
+         "Dimensions must be equal, but are 3 and 2")):
       with ops.control_dependencies([check_ops.assert_less(small, big)]):
         out = array_ops.identity(small)
-      out.eval()
-
-  def test_raises_when_less_but_non_broadcastable_shapes(self):
-    with self.test_session():
-      small = constant_op.constant([1, 1, 1], name="small")
-      big = constant_op.constant([3, 2], name="big")
-      with self.assertRaisesRegexp(ValueError, "must be"):
-        with ops.control_dependencies([check_ops.assert_less(small, big)]):
-          out = array_ops.identity(small)
-        out.eval()
+      self.evaluate(out)
 
+  @test_util.run_in_graph_and_eager_modes()
   def test_doesnt_raise_when_both_empty(self):
-    with self.test_session():
-      larry = constant_op.constant([])
-      curly = constant_op.constant([])
-      with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
-        out = array_ops.identity(larry)
-      out.eval()
+    larry = constant_op.constant([])
+    curly = constant_op.constant([])
+    with ops.control_dependencies([check_ops.assert_less(larry, curly)]):
+      out = array_ops.identity(larry)
+    self.evaluate(out)
+
+  def test_returns_none_with_eager(self):
+    with context.eager_mode():
+      t1 = constant_op.constant([1, 2])
+      t2 = constant_op.constant([3, 4])
+      x = check_ops.assert_less(t1, t2)
+      assert x is None
 
 
 class AssertLessEqualTest(test.TestCase):
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index ceee009104c..7e509f72c15 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -48,6 +48,7 @@ import numpy as np
 
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_util
@@ -96,10 +97,11 @@ def _maybe_constant_value_string(t):
 
 
 def _assert_static(condition, data):
-  """Raises a static ValueError with as much information as possible."""
+  """Raises a InvalidArgumentError with as much information as possible."""
   if not condition:
     data_static = [_maybe_constant_value_string(x) for x in data]
-    raise ValueError('\n'.join(data_static))
+    raise errors.InvalidArgumentError(node_def=None, op=None,
+                                      message='\n'.join(data_static))
 
 
 def assert_proper_iterable(values):
@@ -303,11 +305,60 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
 
   Returns:
     Op that raises `InvalidArgumentError` if `x == y` is False.
+    @compatibility{eager} returns None
+
+  Raises:
+    InvalidArgumentError if the check can be performed immediately and
+    `x == y` is False. The check can be performed immediately during
+    eager execution or if `x` and `y` are statically known.
   """
   message = message or ''
   with ops.name_scope(name, 'assert_equal', [x, y, data]):
     x = ops.convert_to_tensor(x, name='x')
     y = ops.convert_to_tensor(y, name='y')
+
+    if context.in_eager_mode():
+      eq = math_ops.equal(x, y)
+      condition = math_ops.reduce_all(eq)
+      if not condition:
+        # Prepare a message with first elements of x and y
+        summary_msg = ''
+        if summarize:
+          # reshape((-1,)) is the fastest way to get a flat array view.
+          x_np = x.numpy().reshape((-1,))
+          y_np = y.numpy().reshape((-1,))
+          x_sum = min(x_np.size, summarize)
+          y_sum = min(y_np.size, summarize)
+          summary_msg = ('First %d elements of x:\n%s\n'
+                         'First %d elements of y:\n%s\n' %
+                         (x_sum, x_np[:x_sum],
+                          y_sum, y_np[:y_sum]))
+
+        # Get the values that actually differed and their indices
+        mask = math_ops.logical_not(eq)
+        indices = array_ops.where(mask)
+        indices_np = indices.numpy()
+        x_vals = array_ops.boolean_mask(x, mask)
+        y_vals = array_ops.boolean_mask(y, mask)
+        diff_to_print = 0
+        if summarize:
+          diff_to_print = min(summarize, indices_np.size)
+
+        raise errors.InvalidArgumentError(
+            node_def=None, op=None,
+            message=('%s\nCondition x == y did not hold.\n'
+                     'Indices of first %s different values:\n%s\n'
+                     'Corresponding x values:\n%s\n'
+                     'Corresponding y values:\n%s\n'
+                     '%s'
+                     %
+                     (message or '',
+                      diff_to_print, indices_np[:diff_to_print],
+                      x_vals.numpy().reshape((-1,))[:diff_to_print],
+                      y_vals.numpy().reshape((-1,))[:diff_to_print],
+                      summary_msg)))
+      return
+
     if data is None:
       data = [
           message,
@@ -356,12 +407,19 @@ def assert_none_equal(
   with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
     x = ops.convert_to_tensor(x, name='x')
     y = ops.convert_to_tensor(y, name='y')
+    if context.in_eager_mode():
+      x_name = 'x'
+      y_name = 'y'
+    else:
+      x_name = x.name
+      y_name = y.name
+
     if data is None:
       data = [
           message,
-          'Condition x != y did not hold for every single element:'
-          'x (%s) = ' % x.name, x,
-          'y (%s) = ' % y.name, y
+          'Condition x != y did not hold for every single element:',
+          'x (%s) = ' % x_name, x,
+          'y (%s) = ' % y_name, y
       ]
     condition = math_ops.reduce_all(math_ops.not_equal(x, y))
     return control_flow_ops.Assert(condition, data, summarize=summarize)
@@ -397,11 +455,18 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
   with ops.name_scope(name, 'assert_less', [x, y, data]):
     x = ops.convert_to_tensor(x, name='x')
     y = ops.convert_to_tensor(y, name='y')
+    if context.in_eager_mode():
+      x_name = 'x'
+      y_name = 'y'
+    else:
+      x_name = x.name
+      y_name = y.name
+
     if data is None:
       data = [
           message,
-          'Condition x < y did not hold element-wise:'
-          'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y
+          'Condition x < y did not hold element-wise:',
+          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
       ]
     condition = math_ops.reduce_all(math_ops.less(x, y))
     return control_flow_ops.Assert(condition, data, summarize=summarize)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 10d8e013043..8afb079d20f 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,6 +60,7 @@ from tensorflow.core.protobuf import control_flow_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
@@ -86,6 +87,29 @@ from tensorflow.python.util import tf_should_use
 _basetuple = tuple
 
 
+def _summarize_eager(tensor, summarize=None):
+  """Returns a summarized string representation of eager `tensor`.
+
+  Args:
+    tensor: EagerTensor to summarize.
+    summarize: Number of leading elements of `tensor` to include.
+  """
+  # reshape((-1,)) is the fastest way to get a flat array view
+  if tensor._rank():  # pylint: disable=protected-access
+    flat = tensor.numpy().reshape((-1,))
+    lst = [str(x) for x in flat[:summarize]]
+    if len(lst) < flat.size:
+      lst.append("...")
+  else:
+    # tensor.numpy() returns a scalar for zero dimensional arrays
+    if summarize != 0:
+      lst = [str(tensor.numpy())]
+    else:
+      lst = []
+
+  return ", ".join(lst)
+
+
 # pylint: disable=protected-access
 
 
@@ -98,7 +122,8 @@ def Assert(condition, data, summarize=None, name=None):
   If `condition` evaluates to false, print the list of tensors in `data`.
   `summarize` determines how many entries of the tensors to print.
 
-  NOTE: To ensure that Assert executes, one usually attaches a dependency:
+  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
+  a dependency:
 
   ```python
   # Ensure maximum element of x is smaller or equal to 1
@@ -117,7 +142,21 @@ def Assert(condition, data, summarize=None, name=None):
     assert_op: An `Operation` that, when executed, raises a
     `tf.errors.InvalidArgumentError` if `condition` is not true.
     @compatibility{eager} returns None.
+
+  Raises:
+    @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
+    is not true.
   """
+  if context.in_eager_mode():
+    if not condition:
+      xs = ops.convert_n_to_tensor(data)
+      data_str = [_summarize_eager(x, summarize) for x in xs]
+      raise errors.InvalidArgumentError(
+          node_def=None, op=None,
+          message="Expected '%s' to be true. Summarized data: %s" % (
+              condition, "\n".join(data_str)))
+    return
+
   with ops.name_scope(name, "Assert", [condition, data]) as name:
     xs = ops.convert_n_to_tensor(data)
     if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):
diff --git a/tensorflow/python/util/tf_should_use.py b/tensorflow/python/util/tf_should_use.py
index a576547d5f2..37733152e8e 100644
--- a/tensorflow/python/util/tf_should_use.py
+++ b/tensorflow/python/util/tf_should_use.py
@@ -44,7 +44,7 @@ def _add_should_use_warning(x, fatal_error=False):
     and is a very shallow wrapper for `x` which logs access into `x`.
   """
   del fatal_error
-  if x is None:  # special corner case where x is None
+  if x is None or x == []:  # pylint: disable=g-explicit-bool-comparison
     return x
 
   if context.in_eager_mode():

From 04d3d4d3a70aed9a8a09c7c87765652fea38cbfd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 16:49:17 -0800
Subject: [PATCH 060/115] Changed gradient of GatherNd to use IndexedSlices
 when possible rather than producing a dense output.

PiperOrigin-RevId: 175086874
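A sketch of the case this change targets: rank-2 indices whose last dimension is 1, where the gradient can be expressed sparsely. Module imports follow the updated test; the values are illustrative:

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl

params = constant_op.constant([[1., 2.], [3., 4.]])
indices = constant_op.constant([[1], [0]])  # rank 2, last dimension == 1
outputs = array_ops.gather_nd(params, indices)
grads = gradients_impl.gradients([outputs], [params])[0]
# The gradient is now an IndexedSlices rather than a dense tensor
# produced by scatter_nd.
assert isinstance(grads, ops.IndexedSlices)
```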
---
 tensorflow/python/kernel_tests/gather_nd_op_test.py | 10 ++++++++--
 tensorflow/python/ops/array_grad.py                 |  6 +++++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index af5e23c926c..5109ed98c92 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -25,6 +25,7 @@ import numpy as np
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import variables
@@ -185,6 +186,9 @@ class GatherNdTest(test.TestCase):
     self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
     self.assertEqual([10, 10, 20], gather_nd_t.get_shape())
 
+  def assertIndexedSlices(self, t):
+    self.assertIsInstance(t, ops.IndexedSlices)
+
   def testUnknownIndices(self):
     params = constant_op.constant([[0, 1, 2]])
     indices = array_ops.placeholder(dtypes.int32)
@@ -233,7 +237,8 @@ class GatherNdTest(test.TestCase):
     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
     expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
     with self.test_session(use_gpu=True):
-      self.assertAllEqual(expected_grads, grads.eval())
+      self.assertIndexedSlices(grads)
+      self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
 
   def testGradientsRank3Elements(self):
     indices = constant_op.constant(
@@ -284,7 +289,8 @@ class GatherNdTest(test.TestCase):
          [0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
         dtype=np.float64)
     with self.test_session(use_gpu=True):
-      self.assertAllEqual(expected_grads, grads.eval())
+      self.assertIndexedSlices(grads)
+      self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
 
 
 class GatherNdOpBenchmark(test.Benchmark):
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3c025881cb8..87f8d148601 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -460,7 +460,11 @@ def _GatherNdGrad(op, grad):
   ref = op.inputs[0]
   indices = op.inputs[1]
   ref_shape = array_ops.shape(ref, out_type=indices.dtype)
-  ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
+  if indices.shape.ndims == 2 and indices.shape[-1].value == 1:
+    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
+                                 ref_shape)
+  else:
+    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
   return [ref_grad, None]
 
 

From d1dc152b5c97b5b58314a6959543311ced35deed Mon Sep 17 00:00:00 2001
From: Igor Ganichev <iga@google.com>
Date: Wed, 8 Nov 2017 16:59:13 -0800
Subject: [PATCH 061/115] Improve error message for @graph_callable argument
 check

PiperOrigin-RevId: 175088248
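For context, a hedged sketch of a well-formed argument to the decorator; it assumes the `ShapeAndDtype(shape, dtype)` helper referenced by the docstrings in this module:

```python
from tensorflow.python.eager import graph_callable
from tensorflow.python.framework import dtypes

# One ShapeAndDtype entry per argument of the wrapped function (a scalar
# float here); any non-ShapeAndDtype entry now triggers the improved
# error message below.
@graph_callable.graph_callable(
    [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def plus_one(x):
  return x + 1.0
```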
---
 tensorflow/python/eager/graph_callable.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index a7f1061d18b..ce51d17cfca 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -247,7 +247,9 @@ def _get_graph_callable_inputs(shape_and_dtypes):
       ret.append(_get_graph_callable_inputs(x))
     else:
       raise errors.InvalidArgumentError(
-          None, None, "shape_and_dtypes not ShapeAndDtype, type: %s " % type(x))
+          None, None, "Expected the argument to @graph_callable to be a "
+          "(possibly nested) list or tuple of ShapeAndDtype objects, "
+          "but got an object of type: %s" % type(x))
 
   return tuple(ret) if isinstance(shape_and_dtypes, tuple) else ret
 
@@ -267,7 +269,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
 
   Args:
     func: The tfe Python function to compile.
-    shape_and_dtypes: A list of type ShapeAndDtype.
+    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects.
 
   Raises:
     ValueError: If any one of func's outputs is not a Tensor.
@@ -430,9 +432,10 @@ def graph_callable(shape_and_dtypes):
   ret = foo(tfe.Tensor(2.0))  # `ret` here now is a Tensor with value 9.0.
   ```
   Args:
-    shape_and_dtypes: A list of type ShapeAndDtype that specifies shape and type
-      information for each of the callable's arguments. The length of this list
-      must be equal to the number of arguments accepted by the wrapped function.
+    shape_and_dtypes: A possibly nested list or tuple of ShapeAndDtype objects
+      that specifies shape and type information for each of the callable's
+      arguments. The length of this list must be equal to the number of
+      arguments accepted by the wrapped function.
 
   Returns:
     A callable graph object.

From c58da5291a6b1344de8e3e7e7ea59d770701fc15 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 8 Nov 2017 17:01:26 -0800
Subject: [PATCH 062/115] Remove extra copy of literal in client
 TransferToOutfeed

PiperOrigin-RevId: 175088538
---
 tensorflow/compiler/xla/client/client.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index 92cd8e729d6..66937d64aff 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -142,8 +142,7 @@ StatusOr<std::unique_ptr<Literal>> Client::TransferFromOutfeed(
         "TransferToClient request");
   }
 
-  Literal literal(response.literal());
-  return MakeUnique<Literal>(literal);
+  return MakeUnique<Literal>(response.literal());
 }
 
 Status Client::ResetDevice() {

From 1c79da73c193944878025a6f49565c54f63da4f2 Mon Sep 17 00:00:00 2001
From: Shanqing Cai <cais@google.com>
Date: Wed, 8 Nov 2017 19:01:15 -0800
Subject: [PATCH 063/115] Add hooks keyword argument to slim evaluate_once

This enables TFDBG debugging of slim.evaluation.evaluate_once().

Fixes: #13444
PiperOrigin-RevId: 175101022
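A usage sketch of the new keyword; `update_op` and `value_op` stand in for the caller's metric ops (as in the test added below), and the paths are illustrative:

```python
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.python.debug.wrappers import hooks as debug_hooks

dumping_hook = debug_hooks.DumpingDebugHook('/tmp/tfdbg_dumps',
                                            log_usage=False)
accuracy = evaluation.evaluate_once(
    '', '/tmp/model.ckpt', '/tmp/eval_logs',
    eval_op=update_op,   # assumed: metric update op built by the caller
    final_op=value_op,   # assumed: metric value tensor built by the caller
    hooks=[dumping_hook])
```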
---
 tensorflow/contrib/slim/BUILD                 |  2 +
 .../contrib/slim/python/slim/evaluation.py    | 15 ++++--
 .../slim/python/slim/evaluation_test.py       | 46 +++++++++++++++++--
 .../docs_src/programmers_guide/debugger.md    | 26 ++++++++++-
 4 files changed, 77 insertions(+), 12 deletions(-)

diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD
index 23c23af2f48..c2f106c2b28 100644
--- a/tensorflow/contrib/slim/BUILD
+++ b/tensorflow/contrib/slim/BUILD
@@ -39,6 +39,8 @@ py_test(
         "//tensorflow/python:summary",
         "//tensorflow/python:training",
         "//tensorflow/python:variables",
+        "//tensorflow/python/debug:debug_data",
+        "//tensorflow/python/debug:hooks",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/slim/python/slim/evaluation.py b/tensorflow/contrib/slim/python/slim/evaluation.py
index 2d4b08df61a..cdb720b36ba 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation.py
@@ -153,7 +153,8 @@ def evaluate_once(master,
                   summary_op=_USE_DEFAULT,
                   summary_op_feed_dict=None,
                   variables_to_restore=None,
-                  session_config=None):
+                  session_config=None,
+                  hooks=None):
   """Evaluates the model at the given checkpoint path.
 
   Args:
@@ -177,6 +178,8 @@ def evaluate_once(master,
       slim.variables.GetVariablesToRestore() is used.
     session_config: An instance of `tf.ConfigProto` that will be used to
       configure the `Session`. If left as `None`, the default will be used.
+    hooks: A list of additional `SessionRunHook` objects to pass during the
+      evaluation.
 
   Returns:
     The value of `final_op` or `None` if `final_op` is `None`.
@@ -184,11 +187,13 @@ def evaluate_once(master,
   if summary_op == _USE_DEFAULT:
     summary_op = summary.merge_all()
 
-  hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
+  all_hooks = [evaluation.StopAfterNEvalsHook(num_evals),]
 
   if summary_op is not None:
-    hooks.append(evaluation.SummaryAtEndHook(
+    all_hooks.append(evaluation.SummaryAtEndHook(
         log_dir=logdir, summary_op=summary_op, feed_dict=summary_op_feed_dict))
+  if hooks is not None:
+    all_hooks.extend(hooks)
 
   saver = None
   if variables_to_restore is not None:
@@ -203,7 +208,7 @@ def evaluate_once(master,
       feed_dict=eval_op_feed_dict,
       final_ops=final_op,
       final_ops_feed_dict=final_op_feed_dict,
-      hooks=hooks,
+      hooks=all_hooks,
       config=session_config)
 
 
@@ -256,7 +261,7 @@ def evaluation_loop(master,
       configure the `Session`. If left as `None`, the default will be used.
     timeout: The maximum amount of time to wait between checkpoints. If left as
       `None`, then the process will wait indefinitely.
-    hooks: A list of additional SessionRunHook objects to pass during
+    hooks: A list of additional `SessionRunHook` objects to pass during
       repeated evaluations.
 
   Returns:
diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py
index d9e0f54b724..870f504d103 100644
--- a/tensorflow/contrib/slim/python/slim/evaluation_test.py
+++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import glob
 import os
+import shutil
 import time
 
 import numpy as np
@@ -29,6 +30,8 @@ from tensorflow.contrib.metrics.python.ops import metric_ops
 from tensorflow.contrib.slim.python.slim import evaluation
 from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
 from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.wrappers import hooks
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -230,11 +233,7 @@ class SingleEvaluationTest(test.TestCase):
     with self.assertRaises(errors.NotFoundError):
       evaluation.evaluate_once('', checkpoint_path, log_dir)
 
-  def testRestoredModelPerformance(self):
-    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
-    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
-
-    # First, save out the current model to a checkpoint:
+  def _prepareCheckpoint(self, checkpoint_path):
     init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                      variables.local_variables_initializer())
     saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
@@ -242,6 +241,13 @@ class SingleEvaluationTest(test.TestCase):
       sess.run(init_op)
       saver.save(sess, checkpoint_path)
 
+  def testRestoredModelPerformance(self):
+    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+    # First, save out the current model to a checkpoint:
+    self._prepareCheckpoint(checkpoint_path)
+
     # Next, determine the metric to evaluate:
     value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                         self._labels)
@@ -251,6 +257,36 @@ class SingleEvaluationTest(test.TestCase):
         '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
     self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
 
+  def testAdditionalHooks(self):
+    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
+    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')
+
+    # First, save out the current model to a checkpoint:
+    self._prepareCheckpoint(checkpoint_path)
+
+    # Next, determine the metric to evaluate:
+    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
+                                                        self._labels)
+
+    dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
+    dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
+    try:
+      # Run the evaluation and verify the results:
+      accuracy_value = evaluation.evaluate_once(
+          '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op,
+          hooks=[dumping_hook])
+      self.assertAlmostEqual(accuracy_value, self._expected_accuracy)
+
+      dump = debug_data.DebugDumpDir(
+          glob.glob(os.path.join(dumping_root, 'run_*'))[0])
+      # Here we simply assert that the dumped data has been loaded and is
+      # non-empty. We do not care about the detailed model-internal tensors or
+      # their values.
+      self.assertTrue(dump.dumped_tensor_data)
+    finally:
+      if os.path.isdir(dumping_root):
+        shutil.rmtree(dumping_root)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/docs_src/programmers_guide/debugger.md b/tensorflow/docs_src/programmers_guide/debugger.md
index 36a016e8802..1f856bbf3f2 100644
--- a/tensorflow/docs_src/programmers_guide/debugger.md
+++ b/tensorflow/docs_src/programmers_guide/debugger.md
@@ -509,8 +509,12 @@ model.fit(...)  # This will break into the TFDBG CLI.
 
 ## Debugging tf-slim with TFDBG
 
-TFDBG currently supports only training with
+TFDBG supports debugging of training and evaluation with
 [tf-slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim).
+As detailed below, training and evaluation require slightly different debugging
+workflows.
+
+### Debugging training in tf-slim
 To debug the training process, provide `LocalCLIDebugWrapperSession` to the
 `session_wrapper` argument of `slim.learning.train()`. For example:
 
@@ -519,13 +523,31 @@ import tensorflow as tf
 from tensorflow.python import debug as tf_debug
 
 # ... Code that creates the graph and the train_op ...
-tf.contrib.slim.learning_train(
+tf.contrib.slim.learning.train(
     train_op,
     logdir,
     number_of_steps=10,
     session_wrapper=tf_debug.LocalCLIDebugWrapperSession)
 ```
 
+### Debugging evaluation in tf-slim
+To debug the evaluation process, provide `LocalCLIDebugHook` to the
+`hooks` argument of `slim.evaluation.evaluate_once()`. For example:
+
+``` python
+import tensorflow as tf
+from tensorflow.python import debug as tf_debug
+
+# ... Code that creates the graph and the eval and final ops ...
+tf.contrib.slim.evaluation.evaluate_once(
+    '',
+    checkpoint_path,
+    logdir,
+    eval_op=my_eval_op,
+    final_op=my_value_op,
+    hooks=[tf_debug.LocalCLIDebugHook()])
+```
+
 ## Offline Debugging of Remotely-Running Sessions
 
 Often, your model is running on a remote machine or a process that you don't

From 29833cac91cb2f7c5016db9fc82f47124d2c94da Mon Sep 17 00:00:00 2001
From: Yao Zhang <yaozhang@google.com>
Date: Wed, 8 Nov 2017 22:23:01 -0800
Subject: [PATCH 064/115] Simplify graph construction with an option to not
 validate colocation constraints (for graph optimizations, colocation
 constraints have already been validated and device placement of nodes has
 completed, so there is no need to validate again).

PiperOrigin-RevId: 175113956
---
 tensorflow/core/graph/graph_constructor.cc    | 10 +++++---
 tensorflow/core/graph/graph_constructor.h     |  3 +++
 .../core/graph/graph_constructor_test.cc      | 15 +++++++++++
 .../core/grappler/costs/graph_properties.cc   |  5 ++++
 .../grappler/costs/graph_properties_test.cc   | 25 +++++++++++++++++++
 5 files changed, 55 insertions(+), 3 deletions(-)
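
For reviewers, a minimal sketch of how a caller would use the new option; the
helper name is hypothetical, and only `ImportGraphDefOptions` and
`ImportGraphDef` are part of the actual API touched below:

```c++
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"

// Hypothetical helper: import a GraphDef whose nodes have already been
// placed, skipping colocation re-validation.
tensorflow::Status ImportPlacedGraph(const tensorflow::GraphDef& graph_def,
                                     tensorflow::Graph* graph) {
  tensorflow::ImportGraphDefOptions options;
  // Optimization passes may have removed nodes that '_class' colocation
  // attributes still reference; placement already ran, so skip the check.
  options.validate_colocation_constraints = false;
  return tensorflow::ImportGraphDef(options, graph_def, graph,
                                    /*refiner=*/nullptr);
}
```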

diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 753cb260e51..2ee409768b5 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -68,7 +68,8 @@ class GraphConstructor {
     Options(const GraphConstructorOptions& in)  // NOLINT(runtime/explicit)
         : allow_internal_ops(in.allow_internal_ops),
           expect_device_spec(in.expect_device_spec),
-          importing(false) {}
+          importing(false),
+          validate_colocation_constraints(false) {}
     Options(const ImportGraphDefOptions& in)  // NOLINT(runtime/explicit)
         : allow_internal_ops(false),
           expect_device_spec(false),
@@ -81,7 +82,8 @@ class GraphConstructor {
           control_dependencies(in.control_dependencies),
           return_tensors(in.return_tensors),
           return_nodes(in.return_nodes),
-          importing(true) {}
+          importing(true),
+          validate_colocation_constraints(in.validate_colocation_constraints) {}
 
     bool allow_internal_ops;
     bool expect_device_spec;
@@ -103,6 +105,7 @@ class GraphConstructor {
     // applicable to ConvertGraphDefToGraph as well, so make an attempt to
     // remove this.
     bool importing;
+    bool validate_colocation_constraints;
   };
 
   typedef gtl::ArraySlice<const NodeDef*> NodeDefSlice;
@@ -492,7 +495,8 @@ Status GraphConstructor::InitFromEdges() {
 
 Status GraphConstructor::ValidateColocationConstraints(
     const NodeDef& node_def) {
-  if (!opts_.importing) return Status::OK();
+  if (!opts_.validate_colocation_constraints || !opts_.importing)
+    return Status::OK();
   const auto iter = node_def.attr().find(kColocationAttrName);
   if (iter == node_def.attr().end()) return Status::OK();
   for (const string& c : iter->second.list().s()) {
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index 416c0ee9ae8..4b418b86229 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -119,6 +119,9 @@ struct ImportGraphDefOptions {
   // TODO(skyewm): make this work with `skip_mapped_nodes` if there's a need.
   std::vector<string> return_nodes;
 
+  // If true, checks that all colocation constraints are nodes in the GraphDef.
+  bool validate_colocation_constraints = true;
+
   // TODO(ashankar): Enable handling of GraphDefs produced by newer binaries
   // with ops that are not defined in the binary calling ImportGraphDef.
   // Similar to the producer_op_list argument to import_graph_def in the
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index cd541c7d86f..893826da3ed 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -2978,5 +2978,20 @@ versions {
   EXPECT_EQ(17, refiner.graph_def_version());
 }
 
+TEST_F(GraphConstructorTest, ImportGraphDef_ValidateColocationConstraints) {
+  GraphDef def;
+  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
+      "node { name: 'A' op: 'TestInput' attr { key: '_class' value { list { "
+      "s:'loc:@missing' } } } }",
+      &def));
+  ImportGraphDefOptions options;
+  // TODO(yaozhang): Extend ExpectError to check error type and use ExpectError
+  // and ExpectOK to replace the code below.
+  Status s = ImportGraphDef(options, def, &graph_, nullptr);
+  EXPECT_TRUE(errors::IsInvalidArgument(s)) << s;
+  options.validate_colocation_constraints = false;
+  TF_EXPECT_OK(ImportGraphDef(options, def, &graph_, nullptr));
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index a59879f53cd..8654a2a3ed0 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -447,6 +447,11 @@ Status GraphProperties::InferStatically() {
   shape_refiner.set_disable_constant_propagation(true);
   shape_refiner.set_function_library_for_shape_inference(&function_library);
   ImportGraphDefOptions options;
+  // Graph optimization happens at a late stage of graph execution, after
+  // colocation constraints have already been validated and the device
+  // placement of nodes has completed, so there is no need to validate
+  // colocation constraints again.
+  options.validate_colocation_constraints = false;
   Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
   TF_RETURN_IF_ERROR(s);
 
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index a33cdacc092..acd0b598aef 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
 #include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -784,6 +785,30 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
   EXPECT_EQ(shape_f.dim(1).size(), shape_a.dim(1).size());
 }
 
+TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1.0f, {1});
+  Output b = ops::Const(s.WithOpName("b"), 2.0f, {1});
+  Output c = ops::Const(s.WithOpName("c").ColocateWith(a), 3.0f, {1});
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  // Create a graph with node a removed (say, by some graph optimization
+  // pass), noting that node c is colocated with a. This is fine because, at
+  // this late stage of graph execution, the colocation constraints have
+  // already been validated and the device placement of nodes has completed.
+  GraphDef optimized_graph;
+  for (const auto& node : item.graph.node()) {
+    if (node.name() != "a") {
+      *optimized_graph.add_node() = node;
+    }
+  }
+  item.graph.Swap(&optimized_graph);
+  GraphProperties properties(item);
+  // This function should return OK, since it doesn't validate the colocation
+  // constraints internally.
+  TF_EXPECT_OK(properties.InferStatically());
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow

From 71bd045af1ebe74c7e3b1b968b5a5b86e0a153c3 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 02:32:44 -0800
Subject: [PATCH 065/115] When sharding a tuple, we typically want to describe
 the data sharding of each subtensor individually. Tuples are essentially
 just containers; the tensors they contain should be able to be sharded
 differently.

Tuples are hierarchically structured, but shardings were designed to
not contain the sharded type (the sharded type is inferred from the
output type of the instruction the sharding is applied to). Therefore,
shardings for tuples contain shardings for each subtensor as a
non-structured list.

This list is ordered as a preorder walk of the tuple shape, and of
course only the leaf nodes of the tuple shape are stored. The
structure is reapplied when the sharded instruction's shape is known.

PiperOrigin-RevId: 175132692
---
 .../compiler/xla/service/hlo_sharding.cc      | 71 +++++++++++++++-
 .../compiler/xla/service/hlo_sharding.h       | 83 +++++++++++++++++--
 .../compiler/xla/service/hlo_sharding_test.cc | 68 +++++++++++++++
 tensorflow/compiler/xla/shape_tree.h          |  3 +
 .../compiler/xla/tools/parser/hlo_parser.cc   | 41 ++++++++-
 .../xla/tools/parser/hlo_parser_test.cc       | 15 +++-
 tensorflow/compiler/xla/xla_data.proto        | 13 ++-
 7 files changed, 278 insertions(+), 16 deletions(-)
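
A sketch, using only the APIs added in this patch, of building a sharding for
the nested tuple shape `(f32[], (f32[3]), f32[4,6])` and serializing it; the
variable names are illustrative:

```c++
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"

using xla::HloSharding;

xla::Shape shape = xla::ShapeUtil::MakeTupleShape(
    {xla::ShapeUtil::MakeShape(xla::F32, {}),
     xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {3})}),
     xla::ShapeUtil::MakeShape(xla::F32, {4, 6})});

// Assign a sharding to each leaf; interior tuple nodes carry no sharding.
xla::ShapeTree<HloSharding> tree(shape, HloSharding::Replicate());
*tree.mutable_element({1, 0}) = HloSharding::AssignDevice(0);
*tree.mutable_element({2}) = HloSharding::AssignDevice(1);

HloSharding sharding = HloSharding::Tuple(tree);
xla::OpSharding proto = sharding.ToProto();
// proto.type() == TUPLE and proto.tuple_shardings_size() == 3: the leaves at
// indices {0}, {1,0}, {2}, flattened in pre-order. The tuple structure is
// reapplied later from the shape of the instruction the sharding is
// attached to.
```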

diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 0d019d22f5d..bc5663513b9 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
 
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
 
 namespace xla {
 
@@ -38,6 +39,15 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
 }
 
 string HloSharding::ToString() const {
+  if (IsTuple()) {
+    std::vector<string> parts;
+    parts.reserve(tuple_elements_.size());
+    for (const HloSharding& element : tuple_elements_) {
+      parts.push_back(element.ToString());
+    }
+    return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
+  }
+
   string result = StrCat("{", (replicated_ ? " replicated" : ""),
                          (maximal_ ? " maximal" : ""));
 
@@ -53,6 +63,11 @@ string HloSharding::ToString() const {
 }
 
 bool HloSharding::UsesDevice(int64 device) const {
+  if (IsTuple()) {
+    return std::any_of(
+        tuple_elements_.begin(), tuple_elements_.end(),
+        [&](const HloSharding& s) { return s.UsesDevice(device); });
+  }
   const auto& devices = tile_assignment_;
   return replicated_ ||
          std::find(devices.begin(), devices.end(), device) != devices.end();
@@ -61,6 +76,7 @@ bool HloSharding::UsesDevice(int64 device) const {
 std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
   CHECK(!ShapeUtil::IsTuple(tile_shape_));
   CHECK(!maximal_);
+  CHECK(!IsTuple());
   std::vector<int64> ret_index;
   tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
     if (d == device) {
@@ -74,6 +90,7 @@ std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
 int64 HloSharding::DeviceForTileIndex(
     tensorflow::gtl::ArraySlice<int64> index) const {
   CHECK(!replicated_);
+  CHECK(!IsTuple());
   if (maximal_) {
     return *tile_assignment_.begin();
   }
@@ -82,7 +99,7 @@ int64 HloSharding::DeviceForTileIndex(
 }
 
 std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
-  CHECK(!ShapeUtil::IsTuple(tile_shape_));
+  CHECK(!IsTuple());
 
   std::vector<int64> index = TileIndexForDevice(device);
   if (maximal_) {
@@ -97,7 +114,7 @@ std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
 }
 
 std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
-  CHECK(!ShapeUtil::IsTuple(tile_shape_));
+  CHECK(!IsTuple());
   CHECK(!maximal_);  // Maximal shardings do not have a valid tile shape.
 
   std::vector<int64> index = TileIndexForDevice(device);
@@ -108,13 +125,41 @@ std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
 }
 
 StatusOr<int64> HloSharding::UniqueDevice() const {
-  if (!replicated_ && maximal_) {
+  if (IsTuple()) {
+    if (tuple_elements_.empty()) {
+      return tensorflow::errors::InvalidArgument(
+          "UniqueDevice() called on empty tuple");
+    }
+    std::vector<StatusOr<int64>> results;
+    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
+                   std::back_inserter(results),
+                   [](const HloSharding& s) { return s.UniqueDevice(); });
+    if (std::all_of(results.begin(), results.end(),
+                    [&](const StatusOr<int64>& s) {
+                      return s.ok() && results[0].ok() &&
+                             s.ValueOrDie() == results[0].ValueOrDie();
+                    })) {
+      return results[0];
+    } else {
+      return tensorflow::errors::InvalidArgument(
+          "Tuple did not contain a unique device");
+    }
+  }
+  if (!replicated_ && maximal_ && !IsTuple()) {
     return static_cast<int64>(*tile_assignment_.begin());
   }
   return tensorflow::errors::InvalidArgument(
       "UniqueDevice() called on sharding that executes on multiple devices");
 }
 
+bool HloSharding::HasUniqueDevice() const {
+  if (IsTuple()) {
+    return UniqueDevice().status().ok();
+  } else {
+    return !IsReplicated() && IsTileMaximal();
+  }
+}
+
 Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
   if (replicated_) {
     return Status::OK();
@@ -193,7 +238,16 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
 
 /*static*/ StatusOr<HloSharding> HloSharding::FromProto(
     const OpSharding& proto) {
-  if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
+  if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
+    std::vector<HloSharding> tuple_shardings;
+    tuple_shardings.reserve(proto.tuple_shardings().size());
+    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
+      TF_ASSIGN_OR_RETURN(HloSharding sharding,
+                          HloSharding::FromProto(tuple_sharding_proto));
+      tuple_shardings.push_back(sharding);
+    }
+    return HloSharding(tuple_shardings);
+  } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
     return Replicate();
   } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
     return HloSharding(proto.tile_assignment_devices(0));
@@ -212,6 +266,15 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
 
 OpSharding HloSharding::ToProto() const {
   OpSharding result;
+
+  if (IsTuple()) {
+    for (const HloSharding& element : tuple_elements_) {
+      *result.add_tuple_shardings() = element.ToProto();
+    }
+    result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+    return result;
+  }
+
   *result.mutable_tile_shape() = tile_shape_;
   for (int64 dim : tile_assignment_.dimensions()) {
     result.add_tile_assignment_dimensions(dim);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h
index d7ada30c70b..f8ef2a3d059 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.h
+++ b/tensorflow/compiler/xla/service/hlo_sharding.h
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/array.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/hash/hash.h"
@@ -67,6 +68,18 @@ class HloSharding {
   // `num_tiles` tiles.
   static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles);
 
+  // Creates a new sharding for a tuple type. The given ShapeTree must have
+  // elements for every leaf shape contained in the tuple.
+  static HloSharding Tuple(const ShapeTree<HloSharding>& sub_shardings) {
+    std::vector<HloSharding> flattened_list;
+    flattened_list.reserve(
+        std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end()));
+    for (const auto& index_to_sharding : sub_shardings.leaves()) {
+      flattened_list.push_back(index_to_sharding.second);
+    }
+    return HloSharding(flattened_list);
+  }
+
   // Create a new sharding from a protobuf OpSharding.
   static StatusOr<HloSharding> FromProto(const OpSharding& proto);
 
@@ -76,47 +89,89 @@ class HloSharding {
   // Validate that this sharding can be applied to a tensor with shape `shape`.
   Status Validate(const Shape& shape, int64 num_devices) const;
 
+  // Returns true if the sharding has tuple type.
+  bool IsTuple() const { return tuple_; }
+
   // Returns true if the sharding is trivial: replicate on all devices.
-  bool IsReplicated() const { return replicated_; }
+  bool IsReplicated() const {
+    if (!IsTuple()) {
+      return replicated_;
+    }
+    return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
+                       [](const HloSharding& s) { return s.IsReplicated(); });
+  }
 
   // Returns true if the tile size is the same as the input size.
-  bool IsTileMaximal() const { return maximal_; }
+  bool IsTileMaximal() const {
+    if (!IsTuple()) {
+      return maximal_;
+    }
+    return std::all_of(tuple_elements_.begin(), tuple_elements_.end(),
+                       [](const HloSharding& s) { return s.IsTileMaximal(); });
+  }
 
   // Returns true if the sharding defines an operation on the given device.
   bool UsesDevice(int64 device) const;
 
   // Returns the tile that should be executed on the given device.
+  // REQUIRES: !IsTuple()
   std::vector<int64> TileIndexForDevice(int64 device) const;
 
   // Returns the device that should execute the given tile.
   // It is an error to call this if is_replicated() is true.
+  // REQUIRES: !IsTuple()
   int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice<int64> index) const;
 
   // Given a device ID, returns the offset within the input space of the
   // tile that should be executed on the given core. This returns the lower
   // extent of the tile in the input space.
+  // REQUIRES: !IsTuple()
   std::vector<int64> TileOffsetForDevice(int64 device) const;
 
   // Given a device ID, returns the limit within the input space of the
   // tile that should be executed on the given core. This returns the upper
   // extent of the tile in the input space.
+  // REQUIRES: !IsTuple()
   std::vector<int64> TileLimitForDevice(int64 device) const;
 
   // Returns the single device this op operates on.
-  // Requires !Replicated() && IsTileMaximal().
+  // REQUIRES: !IsTuple() && !Replicated() && IsTileMaximal()
   StatusOr<int64> UniqueDevice() const;
 
   // Returns true if this op only uses a single device.
-  bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); }
+  bool HasUniqueDevice() const;
+
+  // Returns the ShapeTree containing the shardings for each element of this
+  // tuple. Only the leaf elements are populated. This creates a new ShapeTree
+  // object, so it is not cheap. REQUIRES: IsTuple()
+  ShapeTree<HloSharding> GetTupleShardingsAsShapeTree(
+      const Shape& tuple_shape) const {
+    ShapeTree<HloSharding> result(tuple_shape, HloSharding::Replicate());
+    CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()),
+             tuple_elements_.size());
+    auto it = tuple_elements_.begin();
+    for (auto& index_to_sharding : result.leaves()) {
+      index_to_sharding.second = *it++;
+    }
+    return result;
+  }
 
   bool operator==(const HloSharding& other) const {
     return replicated_ == other.replicated_ && maximal_ == other.maximal_ &&
            protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) &&
-           tile_assignment_ == other.tile_assignment_;
+           tile_assignment_ == other.tile_assignment_ &&
+           tuple_elements_ == other.tuple_elements_;
   }
   bool operator!=(const HloSharding& other) const { return !(*this == other); }
 
   size_t Hash() const {
+    if (tuple_) {
+      size_t h = 0;
+      for (const auto& element : tuple_elements_) {
+        h = tensorflow::Hash64Combine(h, element.Hash());
+      }
+      return h;
+    }
     if (replicated_) {
       return 0;
     }
@@ -131,33 +186,47 @@ class HloSharding {
   }
 
   // Gets the tile shape.
-  // It is an error to call this if IsTileMaximal() is true.
+  // REQUIRES: !IsTileMaximal() && !IsTuple()
   const Shape& tile_shape() const { return tile_shape_; }
   // Gets the tile assignment tensor.
-  // It is an error to call this if IsReplicated() is true.
+  // REQUIRES: !IsReplicated() && !IsTuple()
   const Array<int64>& tile_assignment() const { return tile_assignment_; }
 
  private:
   HloSharding()
       : replicated_(true),
         maximal_(true),
+        tuple_(false),
         tile_shape_(),
         tile_assignment_({0}) {}
   explicit HloSharding(int64 device_id)
       : replicated_(false),
         maximal_(true),
+        tuple_(false),
         tile_shape_(),
         tile_assignment_({1}, device_id) {}
   HloSharding(const Shape& tile_shape, const Array<int64>& tile_assignment)
       : replicated_(false),
         maximal_(false),
+        tuple_(false),
         tile_shape_(tile_shape),
         tile_assignment_(tile_assignment) {}
+  HloSharding(const std::vector<HloSharding>& tuple_shardings)
+      : replicated_(false),
+        maximal_(false),
+        tuple_(true),
+        tile_assignment_({0}),
+        tuple_elements_(tuple_shardings) {}
 
   bool replicated_;
   bool maximal_;
+  bool tuple_;
   Shape tile_shape_;
   Array<int64> tile_assignment_;
+  // Only non-empty when tuple_ is true, though it may be empty even then
+  // because empty tuples are allowed. This is a flattened list of all the
+  // leaf shardings in a tuple shape, in pre-order walk order (ShapeTree
+  // iterator order).
+  std::vector<HloSharding> tuple_elements_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index d0a20471a0f..00ea38480ee 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -132,6 +132,29 @@ TEST_F(HloShardingTest, Tile) {
   }
 }
 
+TEST_F(HloShardingTest, NestedTuple) {
+  // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6])
+  Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({
+      ShapeUtil::MakeShape(F32, {}),
+      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}),
+      ShapeUtil::MakeShape(F32, {4, 6}),
+  });
+
+  OpSharding proto;
+  proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+  *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto();
+  *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto();
+  *proto.add_tuple_shardings() = HloSharding::AssignDevice(1).ToProto();
+  HloSharding tuple_sharding =
+      HloSharding::FromProto(proto).ConsumeValueOrDie();
+
+  ShapeTree<HloSharding> shape_tree =
+      tuple_sharding.GetTupleShardingsAsShapeTree(nested_tuple_shape);
+  EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate());
+  EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0));
+  EXPECT_EQ(shape_tree.element({2}), HloSharding::AssignDevice(1));
+}
+
 TEST_F(HloShardingTest, Hash) {
   auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) {
     if (a.Hash() != b.Hash()) {
@@ -184,6 +207,51 @@ TEST_F(HloShardingTest, Hash) {
                                               MakeArray({2, 2}, {0, 3, 1, 2}));
     EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
   }
+
+  HloSharding default_sharding = HloSharding::Replicate();
+  {
+    ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+                                      default_sharding);
+    HloSharding sharding1 = HloSharding::Replicate();
+    HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+    EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+  }
+
+  {
+    ShapeTree<HloSharding> shape_tree(ShapeUtil::MakeTupleShape({}),
+                                      default_sharding);
+    HloSharding sharding1 = HloSharding::Tuple(shape_tree);
+    HloSharding sharding2 = HloSharding::Tuple(shape_tree);
+    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+  }
+
+  {
+    ShapeTree<HloSharding> shape_tree1(
+        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+        default_sharding);
+    *shape_tree1.mutable_element({0}) = HloSharding::Replicate();
+    ShapeTree<HloSharding> shape_tree2(
+        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+        default_sharding);
+    *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+    HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+    HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+    EXPECT_FALSE(hash_compare_equal(sharding1, sharding2));
+  }
+
+  {
+    ShapeTree<HloSharding> shape_tree1(
+        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+        default_sharding);
+    *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0);
+    ShapeTree<HloSharding> shape_tree2(
+        ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}),
+        default_sharding);
+    *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0);
+    HloSharding sharding1 = HloSharding::Tuple(shape_tree1);
+    HloSharding sharding2 = HloSharding::Tuple(shape_tree2);
+    EXPECT_TRUE(hash_compare_equal(sharding1, sharding2));
+  }
 }
 
 }  // namespace
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
index 64a36471b9f..a898a4d3757 100644
--- a/tensorflow/compiler/xla/shape_tree.h
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -116,6 +116,7 @@ class ShapeTree {
   ShapeTree(const Shape* shape, const T& init_value);
 
   ShapeTree(const ShapeTree& other) { *this = other; }
+  ShapeTree(ShapeTree&&) = default;
 
   ShapeTree& operator=(const ShapeTree& other) {
     root_ = other.root_;
@@ -132,6 +133,8 @@ class ShapeTree {
     return *this;
   }
 
+  ShapeTree& operator=(ShapeTree&& other) = default;
+
   // Returns the data element associated with the array in the shape at the
   // given index (see ShapeUtil::GetSubshape for how indexes are defined).
   const T& element(const ShapeIndex& index) const;
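
The move operations matter because `HloSharding::GetTupleShardingsAsShapeTree()`
in this patch returns a `ShapeTree` by value; a minimal sketch of the pattern
they enable (`sharding` and `tuple_shape` are assumed to be in scope):

```c++
// Move-assignment now replaces a deep copy of the tree nodes.
xla::ShapeTree<xla::HloSharding> shardings(tuple_shape,
                                           xla::HloSharding::Replicate());
shardings = sharding.GetTupleShardingsAsShapeTree(tuple_shape);
```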
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index f1e987cb15c..df07e069a04 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -60,6 +60,7 @@ class HloParser {
   bool ParseInstructionList(HloComputation::Builder* builder,
                             string* root_name);
   bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+  bool ParseControlPredecessors(HloInstruction* instruction);
   bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
   bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
   bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
@@ -123,6 +124,7 @@ class HloParser {
   bool ParseWindow(Window* window);
   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
   bool ParseSharding(OpSharding* sharding);
+  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
 
   // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
   bool ParseDxD(const string& name, std::vector<int64>* result);
@@ -548,14 +550,49 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
   return AddInstruction(name, instruction);
 }
 
-// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
-// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
+// ::= '{' (single_sharding | tuple_sharding) '}'
+//
+// tuple_sharding ::= single_sharding (',' single_sharding)*
 bool HloParser::ParseSharding(OpSharding* sharding) {
+  // A single sharding starts with '{' and is not followed by '{'.
+  // A tuple sharding starts with '{' and is followed by '{', or is '{' '}'
+  // for an empty tuple.
   if (!ParseToken(TokKind::kLbrace,
                   "expected '{' to start sharding attribute")) {
     return false;
   }
 
+  if (lexer_.GetKind() != TokKind::kLbrace &&
+      lexer_.GetKind() != TokKind::kRbrace) {
+    return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
+  }
+
+  // Tuple sharding.
+  // Allow empty tuple shardings.
+  if (lexer_.GetKind() != TokKind::kRbrace) {
+    do {
+      if (!ParseSingleSharding(sharding->add_tuple_shardings(),
+                               /*lbrace_pre_lexed=*/false)) {
+        return false;
+      }
+    } while (EatIfPresent(TokKind::kComma));
+  }
+  sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
+
+  return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
+}
+
+//  ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
+//          ('devices=' ('[' dims ']')* device_list)? '}'
+// dims ::= int_list device_list ::= int_list
+bool HloParser::ParseSingleSharding(OpSharding* sharding,
+                                    bool lbrace_pre_lexed) {
+  if (!lbrace_pre_lexed &&
+      !ParseToken(TokKind::kLbrace,
+                  "expected '{' to start sharding attribute")) {
+    return false;
+  }
+
   bool maximal = false;
   bool replicated = false;
   std::vector<int64> devices;
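
Concretely, the grammar accepts sharding attributes like the following (all
taken from the parser tests in this patch): the first two are single
shardings, the third is a tuple sharding with three leaf shardings, and the
last parses as an empty tuple:

```
sharding={replicated}
sharding={maximal device=1}
sharding={{replicated}, {maximal device=0}, {replicated}}
sharding={}
```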
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index 62b4385e76f..a9dc3609784 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -152,7 +152,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3
   %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
   %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
   %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated}
-  ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
+  ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
 }
 
 )"
@@ -180,6 +180,19 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f
   ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
 }
 
+)"
+},
+{
+"ShardedTupleCreate",
+R"(HloModule ShardedTupleCreate_module:
+
+ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
+  %v1 = f32[] parameter(0)
+  %v2 = f32[3]{0} parameter(1)
+  %v3 = f32[2,3]{1,0} parameter(2)
+  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
+}
+
 )"
 },
 // int32 result = 0;
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 06987e0044d..71466047080 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -825,8 +825,10 @@ message OpSharding {
     REPLICATED = 0;
     // This sharding is maximal - one device runs the entire operation.
     MAXIMAL = 1;
-    // Neither of the above; tile_shape and tile_assignment are both used.
-    OTHER = 2;
+    // This sharding is a tuple - only the tuple_shardings field is valid.
+    TUPLE = 2;
+    // None of the above; tile_shape and tile_assignment are both used.
+    OTHER = 3;
   }
   Type type = 1;
   // The shape of the sharded tile.
@@ -838,6 +840,13 @@ message OpSharding {
   // Flattened list of device IDs. The order of flattening is the same as used
   // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
   repeated int64 tile_assignment_devices = 4;
+  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
+  // in pre-order. The tuple shape could be nested; here we store just a
+  // flattened list of all leaves in the tuple shape. Note that the tuple shape
+  // is not stored here; shardings do not store the shapes to which they are
+  // applied, this is inferred from the instruction this sharding gets attached
+  // to.
+  repeated OpSharding tuple_shardings = 5;
 }
 
 message OpRequest {

From 18d5c3e4cf1ea8459d4eb12eb741283263c1a065 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 05:36:43 -0800
Subject: [PATCH 066/115] Previously we had a large number of
 ComputeAndCompare* methods to run a computation and then compare the result
 to a specified value (Array or Literal). The new method takes advantage of
 the recently added ComputeConstant method to calculate the expected value
 using the HloEvaluator, eliminating the need to do the calculation manually.

As a usage example, I converted the convolution tests to the new method,
which simplified them quite a bit. If there is interest, we can migrate
the other tests as well and then remove the old-style ComputeAndCompare*
methods.

PiperOrigin-RevId: 175145596
---
 .../xla/tests/client_library_test_base.cc     |  54 ++++++
 .../xla/tests/client_library_test_base.h      |  17 ++
 .../compiler/xla/tests/convolution_test.cc    | 158 ++++++------------
 3 files changed, 125 insertions(+), 104 deletions(-)
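
The resulting test pattern, distilled from the conversions below (`client_`,
`error_spec_`, and `TestName()` come from the existing test fixture):

```c++
ComputationBuilder builder(client_, TestName());
auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {1, 1, 4, 4}),
                               "input");
auto filter = builder.Parameter(1, ShapeUtil::MakeShape(F32, {1, 1, 2, 2}),
                                "filter");
auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);

Array4D<float> input_data(1, 1, 4, 4);
input_data.FillWithYX(Array2D<float>({{1, 2, 3, 4},
                                      {5, 6, 7, 8},
                                      {9, 10, 11, 12},
                                      {13, 14, 15, 16}}));
Array4D<float> filter_data(1, 1, 2, 2);
filter_data.FillWithYX(Array2D<float>({{5, 6}, {7, 8}}));

// No ReferenceUtil call and no manual TransferToServer: the reference value
// is computed by the HloEvaluator via ComputeConstant, the actual value by
// executing the computation, and the two are compared within error_spec_.
ComputeAndCompare(&builder, conv,
                  {*Literal::CreateFromArray(input_data),
                   *Literal::CreateFromArray(filter_data)},
                  error_spec_);
```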

diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 065bce7e314..ef54714e46f 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -346,6 +346,60 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
   LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
 }
 
+void ClientLibraryTestBase::ComputeAndCompare(
+    ComputationBuilder* builder, const ComputationDataHandle& operand,
+    tensorflow::gtl::ArraySlice<Literal> arguments) {
+  auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
+  EXPECT_IS_OK(status_or_data);
+  if (!status_or_data.ok()) {
+    return;
+  }
+  std::unique_ptr<Literal> reference, result;
+  std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
+  LiteralTestUtil::ExpectEqual(*reference, *result);
+}
+
+void ClientLibraryTestBase::ComputeAndCompare(
+    ComputationBuilder* builder, const ComputationDataHandle& operand,
+    tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
+  auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
+  EXPECT_IS_OK(status_or_data);
+  if (!status_or_data.ok()) {
+    return;
+  }
+  std::unique_ptr<Literal> reference, result;
+  std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
+  LiteralTestUtil::ExpectNear(*reference, *result, error);
+}
+
+StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+ClientLibraryTestBase::ComputeValueAndReference(
+    ComputationBuilder* builder, const ComputationDataHandle& operand,
+    tensorflow::gtl::ArraySlice<Literal> arguments) {
+  // Transfer the arguments to the executor service. We put the unique_ptrs
+  // into a vector to keep the data alive on the service until the end of this
+  // function.
+  std::vector<std::unique_ptr<GlobalData>> argument_data;
+  for (const auto& arg : arguments) {
+    TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
+    argument_data.push_back(std::move(data));
+  }
+
+  // Create raw pointers to the GlobalData for the rest of the call stack.
+  std::vector<GlobalData*> argument_data_ptr;
+  std::transform(
+      argument_data.begin(), argument_data.end(),
+      std::back_inserter(argument_data_ptr),
+      [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
+
+  TF_ASSIGN_OR_RETURN(
+      auto reference,
+      builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
+  TF_ASSIGN_OR_RETURN(auto result,
+                      ExecuteAndTransfer(builder, argument_data_ptr));
+  return std::make_pair(std::move(reference), std::move(result));
+}
+
 Computation ClientLibraryTestBase::CreateScalarRelu() {
   ComputationBuilder builder(client_, "relu");
   auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 2c37466ff20..b5786677353 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -196,6 +196,16 @@ class ClientLibraryTestBase : public ::testing::Test {
       ComputationBuilder* builder, const Literal& expected,
       tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
 
+  // Convenience methods for running a built computation and comparing the
+  // result against a reference computed with the HloEvaluator.
+  void ComputeAndCompare(ComputationBuilder* builder,
+                         const ComputationDataHandle& operand,
+                         tensorflow::gtl::ArraySlice<Literal> arguments);
+  void ComputeAndCompare(ComputationBuilder* builder,
+                         const ComputationDataHandle& operand,
+                         tensorflow::gtl::ArraySlice<Literal> arguments,
+                         ErrorSpec error);
+
   // Create scalar operations for use in reductions.
   Computation CreateScalarRelu();
   Computation CreateScalarMax();
@@ -298,6 +308,13 @@ class ClientLibraryTestBase : public ::testing::Test {
       const std::function<void(const Literal& actual,
                                const string& error_message)>& verify_output,
       const Shape* output_with_layout = nullptr);
+
+  // Executes the computation and calculates the expected reference value using
+  // the HloEvaluator. Returns two literals in the order (expected, actual).
+  StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
+  ComputeValueAndReference(ComputationBuilder* builder,
+                           const ComputationDataHandle& operand,
+                           tensorflow::gtl::ArraySlice<Literal> arguments);
 };
 
 template <typename NativeT>
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 0cc2e5fb7e6..7425f778a63 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
   ComputationBuilder builder(client_, TestName());
   auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
   auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
-  builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+  auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
 
-  std::unique_ptr<Array4D<float>> aexpected =
-      ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
-
-  ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
+  ComputeAndCompare(&builder, conv, {}, error_spec_);
 }
 
 TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
   ComputationBuilder builder(client_, TestName());
-  {
-    Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
-    Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-    builder.Conv(input, filter, {1, 1}, Padding::kValid);
-  }
+  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto filter = builder.Parameter(1, filter_shape, "filter");
+  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
 
-  Array4D<float> input(1, 1, 1, 2);
-  input.FillWithYX(Array2D<float>({
+  Array4D<float> input_data(1, 1, 1, 2);
+  input_data.FillWithYX(Array2D<float>({
       {1, 2},
   }));
-  Array4D<float> filter(1, 1, 1, 2);
-  filter.FillWithYX(Array2D<float>({
+  Array4D<float> filter_data(1, 1, 1, 2);
+  filter_data.FillWithYX(Array2D<float>({
       {5, 6},
   }));
 
-  std::unique_ptr<Array4D<float>> aexpected =
-      ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
-
-  auto input_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
-          .ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
-          .ConsumeValueOrDie();
-
-  ComputeAndCompareR4<float>(&builder, *aexpected,
-                             {input_literal.get(), filter_literal.get()},
-                             error_spec_);
+  ComputeAndCompare(&builder, conv,
+                    {*Literal::CreateFromArray(input_data),
+                     *Literal::CreateFromArray(filter_data)},
+                    error_spec_);
 }
 
 // Tests valid padding for 2D convolution in raster space.
 TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
   ComputationBuilder builder(client_, TestName());
-  {
-    Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
-    Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-    builder.Conv(input, filter, {1, 1}, Padding::kValid);
-  }
+  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto filter = builder.Parameter(1, filter_shape, "filter");
+  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
 
-  Array4D<float> input(1, 1, 4, 4);
+  Array4D<float> input_data(1, 1, 4, 4);
   // clang-format off
-  input.FillWithYX(Array2D<float>({
+  input_data.FillWithYX(Array2D<float>({
     {1,  2,  3,  4 },
     {5,  6,  7,  8 },
     {9,  10, 11, 12},
     {13, 14, 15, 16},
   }));
   // clang-format on
-  Array4D<float> filter(1, 1, 2, 2);
+  Array4D<float> filter_data(1, 1, 2, 2);
   // clang-format off
-  filter.FillWithYX(Array2D<float>({
+  filter_data.FillWithYX(Array2D<float>({
     {5, 6},
     {7, 8},
   }));
   // clang-format on
-
-  std::unique_ptr<Array4D<float>> aexpected =
-      ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
-
-  auto input_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
-          .ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
-          .ConsumeValueOrDie();
-
-  ComputeAndCompareR4<float>(&builder, *aexpected,
-                             {input_literal.get(), filter_literal.get()},
-                             error_spec_);
+  ComputeAndCompare(&builder, conv,
+                    {*Literal::CreateFromArray(input_data),
+                     *Literal::CreateFromArray(filter_data)},
+                    error_spec_);
 }
 
 // Tests same padding for 2D convolution in raster space.
 TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
   ComputationBuilder builder(client_, TestName());
-  {
-    Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
-    Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-    builder.Conv(input, filter, {1, 1}, Padding::kSame);
-  }
+  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto filter = builder.Parameter(1, filter_shape, "filter");
+  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
 
-  Array4D<float> input(1, 1, 4, 4);
+  Array4D<float> input_data(1, 1, 4, 4);
   // clang-format off
-  input.FillWithYX(Array2D<float>({
+  input_data.FillWithYX(Array2D<float>({
     {1,  2,  3,  4 },
     {5,  6,  7,  8 },
     {9,  10, 11, 12},
     {13, 14, 15, 16},
   }));
   // clang-format on
-  Array4D<float> filter(1, 1, 2, 2);
+  Array4D<float> filter_data(1, 1, 2, 2);
   // clang-format off
-  filter.FillWithYX(Array2D<float>({
+  filter_data.FillWithYX(Array2D<float>({
     {5, 6},
     {7, 8},
   }));
   // clang-format on
-
-  std::unique_ptr<Array4D<float>> aexpected =
-      ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
-
-  auto input_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
-          .ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
-          .ConsumeValueOrDie();
-
-  ComputeAndCompareR4<float>(&builder, *aexpected,
-                             {input_literal.get(), filter_literal.get()},
-                             error_spec_);
+  ComputeAndCompare(&builder, conv,
+                    {*Literal::CreateFromArray(input_data),
+                     *Literal::CreateFromArray(filter_data)},
+                    error_spec_);
 }
 
 // Tests same padding for 2D convolution in raster space with an odd sized
 // kernel.
 TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
   ComputationBuilder builder(client_, TestName());
-  {
-    Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
-    Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
-    auto input = builder.Parameter(0, input_shape, "input");
-    auto filter = builder.Parameter(1, filter_shape, "filter");
-    builder.Conv(input, filter, {1, 1}, Padding::kSame);
-  }
+  Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+  Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
+  auto input = builder.Parameter(0, input_shape, "input");
+  auto filter = builder.Parameter(1, filter_shape, "filter");
+  auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
 
-  Array4D<float> input(1, 1, 4, 4);
+  Array4D<float> input_data(1, 1, 4, 4);
   // clang-format off
-  input.FillWithYX(Array2D<float>({
+  input_data.FillWithYX(Array2D<float>({
     {1,  2,  3,  4 },
     {5,  6,  7,  8 },
     {9,  10, 11, 12},
     {13, 14, 15, 16},
   }));
   // clang-format on
-  Array4D<float> filter(1, 1, 3, 3);
+  Array4D<float> filter_data(1, 1, 3, 3);
   // clang-format off
-  filter.FillWithYX(Array2D<float>({
+  filter_data.FillWithYX(Array2D<float>({
     { 5,  6,  7},
     { 8,  9, 10},
     {11, 12, 13},
   }));
   // clang-format on
-
-  std::unique_ptr<Array4D<float>> aexpected =
-      ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
-
-  auto input_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
-          .ConsumeValueOrDie();
-  auto filter_literal =
-      client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
-          .ConsumeValueOrDie();
-
-  ComputeAndCompareR4<float>(&builder, *aexpected,
-                             {input_literal.get(), filter_literal.get()},
-                             error_spec_);
+  ComputeAndCompare(&builder, conv,
+                    {*Literal::CreateFromArray(input_data),
+                     *Literal::CreateFromArray(filter_data)},
+                    error_spec_);
 }
 
 XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {

From efab2e1d91507b948d545d14f942b15250c8bb92 Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Thu, 9 Nov 2017 07:37:15 -0800
Subject: [PATCH 067/115] Removes void*s from the tape gradient code, replacing
 them with templates.

PiperOrigin-RevId: 175155685
---
 tensorflow/c/eager/BUILD                  |   1 -
 tensorflow/c/eager/tape.cc                | 410 -------------------
 tensorflow/c/eager/tape.h                 | 473 ++++++++++++++++++++--
 tensorflow/python/eager/pywrap_tfe_src.cc |  60 +--
 4 files changed, 479 insertions(+), 465 deletions(-)
 delete mode 100644 tensorflow/c/eager/tape.cc
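
Schematically, the change replaces the untyped backward-function handle with a
template parameter. The sketch below is illustrative only; the class and
member names are invented and are not the actual tape.h interface:

```c++
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Before, RecordOperation received the backward function as a void* plus a
// deleter; parameterizing the tape on the backward-function type moves that
// convention into the type system and removes the casts at the call sites.
template <typename BackwardFunction>
class TypedTapeSketch {
 public:
  void RecordOperation(const std::string& op_type,
                       const std::vector<int64_t>& input_tensor_ids,
                       BackwardFunction* backward_function,
                       std::function<void()> deleter) {
    // A real tape stores a full entry per op; this sketch keeps only enough
    // state to show the typed ownership handoff.
    backward_function_ = backward_function;
    deleter_ = std::move(deleter);
  }

 private:
  BackwardFunction* backward_function_ = nullptr;
  std::function<void()> deleter_;
};
```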

diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 74e94be8d68..d533758e360 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -106,7 +106,6 @@ tf_cc_test(
 
 cc_library(
     name = "tape",
-    srcs = ["tape.cc"],
     hdrs = ["tape.h"],
     visibility = ["//tensorflow:internal"],
     deps = [
diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc
deleted file mode 100644
index 459499bb694..00000000000
--- a/tensorflow/c/eager/tape.cc
+++ /dev/null
@@ -1,410 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <unordered_set>
-
-#include "tensorflow/c/eager/tape.h"
-
-namespace tensorflow {
-namespace eager {
-
-bool GradientTape::ShouldRecord(gtl::ArraySlice<int64> tensor_ids) {
-  for (int64 i : tensor_ids) {
-    if (tensor_tape_.find(i) != tensor_tape_.end()) {
-      return true;
-    }
-  }
-  return false;
-}
-
-void GradientTape::Watch(int64 tensor_id) {
-  tensor_tape_.emplace(tensor_id, -1);
-}
-
-void GradientTape::RecordOperation(
-    const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
-    gtl::ArraySlice<int64> input_tensor_id, void* backward_function,
-    const std::function<void()>& backward_function_deleter) {
-  if (!ShouldRecord(input_tensor_id)) {
-    backward_function_deleter();
-    return;
-  }
-  std::vector<int64> ids;
-  ids.reserve(input_tensor_id.size());
-  for (int64 i : input_tensor_id) {
-    tensor_usage_[i]++;
-    ids.push_back(i);
-  }
-  const int64 op_id = next_op_id_++;
-  std::vector<TapeTensor> tensors;
-  tensors.reserve(output_tensors.size());
-  for (const TapeTensor& o : output_tensors) {
-    // Note: the tensor can have already been watched and hence be in the tape,
-    // so we cannot check that we're inserting it here.
-    tensor_tape_[o.id] = op_id;
-    tensor_usage_[o.id] = 1;
-    tensors.push_back(o);
-  }
-  op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function,
-                                backward_function_deleter};
-}
-
-void GradientTape::DeleteTrace(int64 tensor_id) {
-  auto it = tensor_usage_.find(tensor_id);
-  if (it == tensor_usage_.end()) {
-    return;
-  }
-  it->second--;
-  if (it->second != 0) {
-    return;
-  }
-  tensor_usage_.erase(it);
-  auto tensor_op_it = tensor_tape_.find(tensor_id);
-  if (tensor_op_it == tensor_tape_.end()) {
-    return;
-  }
-  const int64 op_id = tensor_op_it->second;
-  if (op_id == -1) {
-    // Do not delete watched tensors.
-    return;
-  }
-  tensor_tape_.erase(tensor_op_it);
-  auto op_it = op_tape_.find(op_id);
-  CHECK(op_it != op_tape_.end());
-  for (const auto& output : op_it->second.output_tensor_info) {
-    if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
-      // Found a usage for an output, so cannot delete the op.
-      return;
-    }
-  }
-  for (int64 id : op_it->second.input_tensor_id) {
-    DeleteTrace(id);
-  }
-  op_it->second.backward_function_deleter();
-  op_tape_.erase(op_it);
-}
-
-// Terminology:
-//
-//  - op: a possibly composite operation, which has an entry in the tape
-//  - target: dy in dx/dy
-//  - source: dx in dx/dy
-//  - tensor: one of the many inputs or outputs of an operation
-//
-// Below here we do the gradient algorithm. It works as follows:
-//
-// First we filter the tape to just the subset of operations we want to
-// differentiate. In the process of doing so we count how many times each Tensor
-// is used as an input to an op (so we know when we're done computing gradients
-// for that Tensor). We also count, for each tape entry, how many of its output
-// Tensors need gradients to be computed (Tensors which are not used do not need
-// any gradients to be computed).
-//
-// Finally, we start a backprop stack with a set of tape entries for which we
-// have all gradients available. This set usually is a subset of the set of
-// targets (not all since targets which have outputs in the tape will not have
-// gradients available initially).
-//
-// Then we repeatedly pop an entry from the stack, run its backprop, and update
-// the gradients of its inputs. Once we have computed all gradients for a single
-// input we can mark this input as done, and this can trigger adding an entry to
-// the stack if all outputs of that entry are now done.
-//
-// When the stack is empty we have gradients for all tensors we're interested
-// in.
-
-struct BackpropInitialState {
-  OpTape op_tape;
-
-  // Map from tensor ID to how many references still exist for this tensor in
-  // the tape.
-  std::unordered_map<int64, int64> tensor_usage_counts;
-
-  // Maps from op ID to how many output tensors of this op still need to have
-  // their gradients computed.
-  std::unordered_map<int64, int64> op_missing_tensor;
-};
-
-BackpropInitialState PrepareBackprop(
-    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
-    OpTape op_tape, const std::unordered_set<int64>& sources_set) {
-  std::vector<int64> tensor_stack;
-  tensor_stack.reserve(target.size());
-  for (auto t : target) {
-    tensor_stack.push_back(t);
-  }
-  BackpropInitialState result;
-  while (!tensor_stack.empty()) {
-    int64 tensor_id = tensor_stack.back();
-    tensor_stack.pop_back();
-    auto op_id_it = tensor_tape.find(tensor_id);
-    if (op_id_it == tensor_tape.end()) {
-      continue;
-    }
-    int64 op_id = op_id_it->second;
-    auto op_it = op_tape.find(op_id);
-    auto result_op_it = result.op_tape.find(op_id);
-    if (op_id == -1 || op_it == op_tape.end() ||
-        result_op_it != result.op_tape.end()) {
-      continue;
-    }
-    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
-    for (auto it : op_it->second.input_tensor_id) {
-      auto count_it = result.tensor_usage_counts.find(it);
-      if (count_it != result.tensor_usage_counts.end()) {
-        count_it->second++;
-      } else {
-        result.tensor_usage_counts[it] = 1;
-        if (sources_set.find(it) == sources_set.end() &&
-            tensor_tape.find(it) != tensor_tape.end()) {
-          tensor_stack.push_back(it);
-        }
-      }
-    }
-    op_tape.erase(op_it);
-  }
-  for (auto& pair : result.tensor_usage_counts) {
-    auto it = tensor_tape.find(pair.first);
-    if (it != tensor_tape.end() && it->second != -1) {
-      result.op_missing_tensor[it->second] += 1;
-    }
-  }
-  // Call destructors for all unneeded gradient functions.
-  for (const auto& op_pair : op_tape) {
-    op_pair.second.backward_function_deleter();
-  }
-  return result;
-}
-
-std::vector<int64> InitialStack(
-    const OpTape& op_tape,
-    const std::unordered_map<int64, int64>& op_missing_tensor) {
-  std::vector<int64> result;
-  for (auto& op_entry : op_tape) {
-    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
-      result.push_back(op_entry.first);
-    }
-  }
-  return result;
-}
-
-Status InitialGradients(const VSpace& vspace, gtl::ArraySlice<void*> target,
-                        gtl::ArraySlice<void*> output_gradients,
-                        std::unordered_map<int64, int64> tensor_usage_counts,
-                        std::unordered_map<int64, std::vector<void*>>* result) {
-  for (int i = 0; i < target.size(); ++i) {
-    int64 id = vspace.TensorId(target[i]);
-    if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
-      if (!output_gradients.empty() && output_gradients[i] != nullptr) {
-        // TODO(apassos) figure out how to print debugging information here.
-        return errors::InvalidArgument(
-            "A gradient was provided for a tensor which is used as part of the "
-            "computation.");
-      }
-    } else {
-      if (output_gradients.empty() || output_gradients[i] == nullptr) {
-        (*result)[id].push_back(vspace.OnesLike(target[i]));
-      } else {
-        (*result)[id].push_back(output_gradients[i]);
-      }
-    }
-  }
-  return Status::OK();
-}
-
-// If over kMinAggregateCount gradients are accumulated and the total
-// memory consumption is over kMinAggregateBytes, do an early aggregation
-// so as to release the gradient tensor to save memory.
-static const int kMinAggregateCount = 4;
-static const int kMinAggregateBytes = 128 * 1024 * 1024;
-
-Status GradientTape::Gradient(const VSpace& vspace,
-                              gtl::ArraySlice<void*> target,
-                              gtl::ArraySlice<void*> sources,
-                              gtl::ArraySlice<void*> output_gradients,
-                              std::vector<void*>* result) {
-  std::vector<int64> id_sources;
-  id_sources.reserve(sources.size());
-  for (void* s : sources) {
-    id_sources.push_back(vspace.TensorId(s));
-  }
-  std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
-  std::vector<int64> id_targets;
-  id_sources.reserve(target.size());
-  for (void* t : target) {
-    id_targets.push_back(vspace.TensorId(t));
-  }
-  BackpropInitialState state = PrepareBackprop(
-      id_targets, tensor_tape_, std::move(op_tape_), sources_set);
-  std::vector<int64> op_stack =
-      InitialStack(state.op_tape, state.op_missing_tensor);
-  std::unordered_map<int64, std::vector<void*>> gradients;
-  Status s = InitialGradients(vspace, target, output_gradients,
-                              state.tensor_usage_counts, &gradients);
-  auto cleanup = [&state]() {
-    // Release all backprop functions
-    for (const auto& pair : state.op_tape) {
-      pair.second.backward_function_deleter();
-    }
-  };
-  if (!s.ok()) {
-    cleanup();
-    return s;
-  }
-  std::unordered_map<int64, int64> gradients_size;
-  // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
-  // time, for better CPU backprop performance.
-  VLOG(1) << "Initial stack:";
-  if (VLOG_IS_ON(1)) {
-    for (auto t : op_stack) {
-      VLOG(1) << "  " << t;
-    }
-  }
-  std::unordered_map<string, std::unordered_set<int>>
-      functions_accept_none_for_indices({
-          {"SoftmaxCrossEntropyWithLogits", {1}},
-          {"FusedBatchNorm", {1, 2, 3, 4}},
-      });
-  while (!op_stack.empty()) {
-    const int64 op = op_stack.back();
-    VLOG(1) << "Popped " << op;
-    op_stack.pop_back();
-    auto op_it = state.op_tape.find(op);
-    if (op_it == state.op_tape.end()) {
-      // It is possible for ops to end up on the stack if they are unrelated to
-      // the target; we should just skip them.
-      continue;
-    }
-    auto trace = std::move(op_it->second);
-    state.op_tape.erase(op_it);
-    std::vector<void*> out_gradients;
-    out_gradients.reserve(trace.output_tensor_info.size());
-    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
-      const int64 id = trace.output_tensor_info[i].id;
-      auto grad_it = gradients.find(id);
-      if (grad_it == gradients.end()) {
-        auto func_name_it =
-            functions_accept_none_for_indices.find(trace.op_type);
-        if (func_name_it != functions_accept_none_for_indices.end() &&
-            func_name_it->second.find(i) != func_name_it->second.end()) {
-          out_gradients.push_back(nullptr);
-        } else {
-          out_gradients.push_back(
-              vspace.Zeros(trace.output_tensor_info[i].shape,
-                           trace.output_tensor_info[i].dtype));
-        }
-      } else {
-        out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
-        if (sources_set.find(grad_it->first) == sources_set.end()) {
-          gradients.erase(grad_it);
-        }
-      }
-    }
-    std::vector<void*> in_gradients;
-    Status s = vspace.CallBackwardFunction(trace.backward_function,
-                                           out_gradients, &in_gradients);
-    if (!s.ok()) {
-      VLOG(1) << "Gradient function failed.";
-      cleanup();
-      return s;
-    }
-    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
-            << trace.input_tensor_id.size() << " sources";
-    for (int i = 0; i < in_gradients.size(); ++i) {
-      const int64 id = trace.input_tensor_id[i];
-      if (in_gradients[i] != nullptr) {
-        auto& unaggregated_grads = gradients[id];
-        unaggregated_grads.push_back(in_gradients[i]);
-        if (unaggregated_grads.size() > kMinAggregateCount) {
-          auto size_it = gradients_size.find(id);
-          int64 size;
-          if (size_it == gradients_size.end()) {
-            size = vspace.NumElements(unaggregated_grads[0]);
-            gradients_size.emplace(id, size);
-          } else {
-            size = size_it->second;
-          }
-          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
-            void* tensor = vspace.AggregateGradients(unaggregated_grads);
-            unaggregated_grads.clear();
-            unaggregated_grads.push_back(tensor);
-          }
-        }
-      }
-      auto usage_count_it = state.tensor_usage_counts.find(id);
-      if (usage_count_it == state.tensor_usage_counts.end()) {
-        VLOG(1) << "Tensor " << id << " not used";
-        continue;
-      }
-      usage_count_it->second--;
-      if (usage_count_it->second > 0) {
-        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
-        continue;
-      }
-      auto tape_it = tensor_tape_.find(id);
-      if (tape_it == tensor_tape_.end()) {
-        VLOG(1) << "Tensor " << id
-                << " has no associated op. Deleting gradient";
-        auto grad_it = gradients.find(id);
-        if (grad_it != gradients.end()) {
-          for (auto g : grad_it->second) {
-            vspace.DeleteTensor(g);
-          }
-          gradients.erase(grad_it);
-        }
-        continue;
-      }
-      const int64 op_id = tape_it->second;
-      if (op_id == -1) {
-        VLOG(1) << "Tensor " << id << " is source";
-        continue;
-      }
-      auto missing_it = state.op_missing_tensor.find(op_id);
-      if (missing_it != state.op_missing_tensor.end()) {
-        missing_it->second--;
-        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
-                << " output gradients";
-        if (missing_it->second == 0) {
-          op_stack.push_back(op_id);
-        }
-      }
-    }
-  }
-  CHECK(state.op_tape.empty());
-  result->reserve(sources.size());
-  for (auto is : id_sources) {
-    auto grad_it = gradients.find(is);
-    if (grad_it == gradients.end()) {
-      result->push_back(nullptr);
-    } else {
-      if (grad_it->second.size() == 1) {
-        result->push_back(grad_it->second[0]);
-      } else {
-        result->push_back(vspace.AggregateGradients(grad_it->second));
-      }
-      gradients.erase(grad_it);
-    }
-  }
-  VLOG(1) << "Final gradients size: " << gradients.size();
-  for (auto grad_pair : gradients) {
-    for (const auto& g : grad_pair.second) {
-      vspace.DeleteTensor(g);
-    }
-  }
-  return Status::OK();
-}
-
-}  // namespace eager
-}  // namespace tensorflow
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 2bb62a7ab37..654ceb7bec4 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -19,6 +19,7 @@ limitations under the License.
 // maintains the data structures required to do so.
 
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
@@ -36,13 +37,14 @@ struct TapeTensor {
 };
 
 // Represents an entry in the tape.
+template <typename BackwardFunction>
 struct OpTapeEntry {
   string op_type;
   std::vector<TapeTensor> output_tensor_info;
   std::vector<int64> input_tensor_id;
 
   // TODO(apassos) consider narrowing down this interface.
-  void* backward_function;
+  BackwardFunction* backward_function;
 
   // Should be called before deleting the backward function. TODO(apassos) use
   // unique_ptrs to ensure this happens.
@@ -55,51 +57,67 @@ struct OpTapeEntry {
 using TensorTape = std::unordered_map<int64, int64>;
 
 // Map from operation-id to tape entry.
-using OpTape = std::unordered_map<int64, OpTapeEntry>;
+template <typename BackwardFunction>
+using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
 
 // Operations the tape needs to perform on tensors to do backpropagation. Named
 // "vspace" because a subset of these are related to a vector space, such as
 // adding gradients, getting zeroes, etc. Currently cannot be implemented
 // without using tensorflow python code, hence left unspecified here.
 //
-// We currently use void* for tensors, backward functions, and gradients (which
-// can be but are not required to be tensors). TODO(apassos) replace this first
-// with templates to allow for pyobject specialization in the client followed by
-// a TFE_TensorHandle specialization, which is blocked by quite a few things
-// still.
+// Tensor is a representation of a tensor. We need to take its ID, and it needs
+// to match IDs in the tape.
+//
+// Gradient is the type returned by gradient functions. In Python TF it's either
+// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
+// to allow their size to be computed and they need to be passable to a backward
+// function and deleted (as the backprop code creates lots of gradients the user
+// is not interested in).
+//
+// BackwardFunction needs to be a closure which stores intermediate activations
+// from the forward computation and calls a vector-jacobian product function
+// (also known as adjoint function) to compute, given downstream gradients,
+// upstream gradients.
+//
+// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
+// specialization, which is blocked by quite a few things needing to loop back
+// into python now.
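+//
+// (For reference: the Python bindings in this same patch instantiate this
+// interface as VSpace<PyObject, PyObject, PyObject>, using Python objects
+// for the tensor, gradient, and backward-function roles alike.)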
+template <typename Tensor, typename Gradient, typename BackwardFunction>
 class VSpace {
  public:
   virtual ~VSpace() {}
 
-  // Returns the number of elements in the tensor.
-  virtual int64 NumElements(void* tensor) const = 0;
+  // Returns the number of elements in the gradient tensor.
+  virtual int64 NumElements(Gradient* tensor) const = 0;
 
   // Consumes references to the tensors in the gradient_tensors list and returns
   // a tensor with the result.
-  virtual void* AggregateGradients(
-      gtl::ArraySlice<void*> gradient_tensors) const = 0;
+  virtual Gradient* AggregateGradients(
+      gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
 
   // Returns a tensor of the right shape and dtype filled with zeros.
-  virtual void* Zeros(TensorShape shape, DataType dtype) const = 0;
+  virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
 
   // Returns a Tensor which is filled with ones and like the input.
-  virtual void* OnesLike(void*) const = 0;
+  virtual Gradient* OnesLike(Tensor*) const = 0;
 
   // Returns an integer which is a unique-to-within-this-program handle for this
   // tensor.
-  virtual int64 TensorId(void* tensor) const = 0;
+  virtual int64 TensorId(Tensor* tensor) const = 0;
 
   // Calls the passed-in backward function.
-  virtual Status CallBackwardFunction(void* backward_function,
-                                      gtl::ArraySlice<void*> output_gradients,
-                                      std::vector<void*>* result) const = 0;
+  virtual Status CallBackwardFunction(
+      BackwardFunction* backward_function,
+      gtl::ArraySlice<Gradient*> output_gradients,
+      std::vector<Gradient*>* result) const = 0;
 
   // Deletes the input tensor.
-  virtual void DeleteTensor(void* tensor) const = 0;
+  virtual void DeleteGradient(Gradient* gradient) const = 0;
 };
 
 // Traces the execution of operations, doing eager garbage collection, and
 // exporting a full trace so other code can do backpropagation. Not thread-safe.
+template <typename Tensor, typename Gradient, typename BackwardFunction>
 class GradientTape {
  public:
   GradientTape() {}
@@ -116,7 +134,7 @@ class GradientTape {
   void RecordOperation(const string& op_type,
                        gtl::ArraySlice<TapeTensor> output_tensors,
                        gtl::ArraySlice<int64> input_tensor_id,
-                       void* backward_function,
+                       BackwardFunction* backward_function,
                        const std::function<void()>& backward_function_deleter);
 
   void DeleteTrace(int64 tensor_id);
@@ -125,14 +143,15 @@ class GradientTape {
   // once) and produces the gradient of the target tensors with respect to the
   // source tensors. The output gradients are used if not empty and not
   // null. The result is populated with one tensor per target element.
-  Status Gradient(const VSpace& vspace, gtl::ArraySlice<void*> target,
-                  gtl::ArraySlice<void*> sources,
-                  gtl::ArraySlice<void*> output_gradients,
-                  std::vector<void*>* result);
+  Status ComputeGradient(
+      const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
+      gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources,
+      gtl::ArraySlice<Gradient*> output_gradients,
+      std::vector<Gradient*>* result);
 
  private:
   TensorTape tensor_tape_;
-  OpTape op_tape_;
+  OpTape<BackwardFunction> op_tape_;
   int64 next_op_id_{0};
 
   // Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -140,6 +159,412 @@ class GradientTape {
   std::unordered_map<int64, int64> tensor_usage_;
 };
 
+// Definitions of the templated methods follow.
+
+template <typename Tensor, typename Gradient, typename BackwardFunction>
+bool GradientTape<Tensor, Gradient, BackwardFunction>::ShouldRecord(
+    gtl::ArraySlice<int64> tensor_ids) {
+  for (int64 i : tensor_ids) {
+    if (tensor_tape_.find(i) != tensor_tape_.end()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+template <typename Tensor, typename Gradient, typename BackwardFunction>
+void GradientTape<Tensor, Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+  tensor_tape_.emplace(tensor_id, -1);
+}
+
+template <typename Tensor, typename Gradient, typename BackwardFunction>
+void GradientTape<Tensor, Gradient, BackwardFunction>::RecordOperation(
+    const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+    gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
+    const std::function<void()>& backward_function_deleter) {
+  if (!ShouldRecord(input_tensor_id)) {
+    backward_function_deleter();
+    return;
+  }
+  std::vector<int64> ids;
+  ids.reserve(input_tensor_id.size());
+  for (int64 i : input_tensor_id) {
+    tensor_usage_[i]++;
+    ids.push_back(i);
+  }
+  const int64 op_id = next_op_id_++;
+  std::vector<TapeTensor> tensors;
+  tensors.reserve(output_tensors.size());
+  for (const TapeTensor& o : output_tensors) {
+    // Note: the tensor can have already been watched and hence be in the tape,
+    // so we cannot check that we're inserting it here.
+    tensor_tape_[o.id] = op_id;
+    tensor_usage_[o.id] = 1;
+    tensors.push_back(o);
+  }
+  op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
+      op_type, tensors, ids, backward_function, backward_function_deleter};
+}
+
+template <typename Tensor, typename Gradient, typename BackwardFunction>
+void GradientTape<Tensor, Gradient, BackwardFunction>::DeleteTrace(
+    int64 tensor_id) {
+  auto it = tensor_usage_.find(tensor_id);
+  if (it == tensor_usage_.end()) {
+    return;
+  }
+  it->second--;
+  if (it->second != 0) {
+    return;
+  }
+  tensor_usage_.erase(it);
+  auto tensor_op_it = tensor_tape_.find(tensor_id);
+  if (tensor_op_it == tensor_tape_.end()) {
+    return;
+  }
+  const int64 op_id = tensor_op_it->second;
+  if (op_id == -1) {
+    // Do not delete watched tensors.
+    return;
+  }
+  tensor_tape_.erase(tensor_op_it);
+  auto op_it = op_tape_.find(op_id);
+  CHECK(op_it != op_tape_.end());
+  for (const auto& output : op_it->second.output_tensor_info) {
+    if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+      // Found a usage for an output, so cannot delete the op.
+      return;
+    }
+  }
+  for (int64 id : op_it->second.input_tensor_id) {
+    DeleteTrace(id);
+  }
+  op_it->second.backward_function_deleter();
+  op_tape_.erase(op_it);
+}
+
+// Terminology:
+//
+//  - op: a possibly composite operation, which has an entry in the tape
+//  - target: dy in dy/dx
+//  - source: dx in dy/dx
+//  - tensor: one of the many inputs or outputs of an operation
+//
+// Below here we do the gradient algorithm. It works as follows:
+//
+// First we filter the tape to just the subset of operations we want to
+// differentiate. In the process of doing so we count how many times each Tensor
+// is used as an input to an op (so we know when we're done computing gradients
+// for that Tensor). We also count, for each tape entry, how many of its output
+// Tensors need gradients to be computed (Tensors which are not used do not need
+// any gradients to be computed).
+//
+// Finally, we start a backprop stack with a set of tape entries for which we
+// have all gradients available. This set usually is a subset of the set of
+// targets (not all since targets which have outputs in the tape will not have
+// gradients available initially).
+//
+// Then we repeatedly pop an entry from the stack, run its backprop, and update
+// the gradients of its inputs. Once we have computed all gradients for a single
+// input we can mark this input as done, and this can trigger adding an entry to
+// the stack if all outputs of that entry are now done.
+//
+// When the stack is empty we have gradients for all tensors we're interested
+// in.
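+//
+// Illustrative example: for y = Mul(x, x) with target {y} and source {x},
+// the filtering pass keeps the Mul entry and counts x as used twice; the
+// initial stack holds just the Mul entry (its output gradient dy is
+// available immediately); popping it runs Mul's backward function once,
+// producing two partial gradients for x, which are summed into the final
+// result.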
+
+namespace {
+
+template <typename BackwardFunction>
+struct BackpropInitialState {
+  OpTape<BackwardFunction> op_tape;
+
+  // Map from tensor ID to how many references still exist for this tensor in
+  // the tape.
+  std::unordered_map<int64, int64> tensor_usage_counts;
+
+  // Maps from op ID to how many output tensors of this op still need to have
+  // their gradients computed.
+  std::unordered_map<int64, int64> op_missing_tensor;
+};
+
+template <typename BackwardFunction>
+BackpropInitialState<BackwardFunction> PrepareBackprop(
+    gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
+    OpTape<BackwardFunction> op_tape,
+    const std::unordered_set<int64>& sources_set) {
+  std::vector<int64> tensor_stack;
+  tensor_stack.reserve(target.size());
+  for (auto t : target) {
+    tensor_stack.push_back(t);
+  }
+  BackpropInitialState<BackwardFunction> result;
+  while (!tensor_stack.empty()) {
+    int64 tensor_id = tensor_stack.back();
+    tensor_stack.pop_back();
+    auto op_id_it = tensor_tape.find(tensor_id);
+    if (op_id_it == tensor_tape.end()) {
+      continue;
+    }
+    int64 op_id = op_id_it->second;
+    auto op_it = op_tape.find(op_id);
+    auto result_op_it = result.op_tape.find(op_id);
+    if (op_id == -1 || op_it == op_tape.end() ||
+        result_op_it != result.op_tape.end()) {
+      continue;
+    }
+    CHECK(result.op_tape.emplace(op_id, op_it->second).second);
+    for (auto it : op_it->second.input_tensor_id) {
+      auto count_it = result.tensor_usage_counts.find(it);
+      if (count_it != result.tensor_usage_counts.end()) {
+        count_it->second++;
+      } else {
+        result.tensor_usage_counts[it] = 1;
+        if (sources_set.find(it) == sources_set.end() &&
+            tensor_tape.find(it) != tensor_tape.end()) {
+          tensor_stack.push_back(it);
+        }
+      }
+    }
+    op_tape.erase(op_it);
+  }
+  for (auto& pair : result.tensor_usage_counts) {
+    auto it = tensor_tape.find(pair.first);
+    if (it != tensor_tape.end() && it->second != -1) {
+      result.op_missing_tensor[it->second] += 1;
+    }
+  }
+  // Call destructors for all unneeded gradient functions.
+  for (const auto& op_pair : op_tape) {
+    op_pair.second.backward_function_deleter();
+  }
+  return result;
+}
+
+template <typename BackwardFunction>
+std::vector<int64> InitialStack(
+    const OpTape<BackwardFunction>& op_tape,
+    const std::unordered_map<int64, int64>& op_missing_tensor) {
+  std::vector<int64> result;
+  for (auto& op_entry : op_tape) {
+    if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) {
+      result.push_back(op_entry.first);
+    }
+  }
+  return result;
+}
+
+template <typename Tensor, typename Gradient, typename BackwardFunction>
+Status InitialGradients(
+    const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
+    gtl::ArraySlice<Tensor*> target,
+    gtl::ArraySlice<Gradient*> output_gradients,
+    std::unordered_map<int64, int64> tensor_usage_counts,
+    std::unordered_map<int64, std::vector<Gradient*>>* result) {
+  for (int i = 0; i < target.size(); ++i) {
+    int64 id = vspace.TensorId(target[i]);
+    if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
+      if (!output_gradients.empty() && output_gradients[i] != nullptr) {
+        // TODO(apassos) figure out how to print debugging information here.
+        return errors::InvalidArgument(
+            "A gradient was provided for a tensor which is used as part of the "
+            "computation.");
+      }
+    } else {
+      if (output_gradients.empty() || output_gradients[i] == nullptr) {
+        (*result)[id].push_back(vspace.OnesLike(target[i]));
+      } else {
+        (*result)[id].push_back(output_gradients[i]);
+      }
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+// If over kMinAggregateCount gradients are accumulated and the total
+// memory consumption is over kMinAggregateBytes, do an early aggregation
+// so as to release the gradient tensor to save memory.
+constexpr int kMinAggregateCount = 4;
+constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
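+//
+// For example, with 4-byte elements (the factor assumed in the "* 4" check
+// below), five accumulated gradients of 8M elements each occupy 160MB,
+// which exceeds kMinAggregateBytes, so they are summed into a single tensor
+// before accumulation continues.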
+
+template <typename Tensor, typename Gradient, typename BackwardFunction>
+Status GradientTape<Tensor, Gradient, BackwardFunction>::ComputeGradient(
+    const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
+    gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources,
+    gtl::ArraySlice<Gradient*> output_gradients,
+    std::vector<Gradient*>* result) {
+  std::vector<int64> id_sources;
+  id_sources.reserve(sources.size());
+  for (Tensor* s : sources) {
+    id_sources.push_back(vspace.TensorId(s));
+  }
+  std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
+  std::vector<int64> id_targets;
+  id_targets.reserve(target.size());
+  for (Tensor* t : target) {
+    id_targets.push_back(vspace.TensorId(t));
+  }
+  BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+      id_targets, tensor_tape_, std::move(op_tape_), sources_set);
+  std::vector<int64> op_stack =
+      InitialStack(state.op_tape, state.op_missing_tensor);
+  std::unordered_map<int64, std::vector<Gradient*>> gradients;
+  Status s = InitialGradients(vspace, target, output_gradients,
+                              state.tensor_usage_counts, &gradients);
+  auto cleanup = [&state]() {
+    // Release all backprop functions.
+    for (const auto& pair : state.op_tape) {
+      pair.second.backward_function_deleter();
+    }
+  };
+  if (!s.ok()) {
+    cleanup();
+    return s;
+  }
+  std::unordered_map<int64, int64> gradients_size;
+  // TODO(apassos) multiple threads could be dequeuing from op_stack at the same
+  // time, for better CPU backprop performance.
+  VLOG(1) << "Initial stack:";
+  if (VLOG_IS_ON(1)) {
+    for (auto t : op_stack) {
+      VLOG(1) << "  " << t;
+    }
+  }
+  std::unordered_map<string, std::unordered_set<int>>
+      functions_accept_none_for_indices({
+          {"SoftmaxCrossEntropyWithLogits", {1}},
+          {"FusedBatchNorm", {1, 2, 3, 4}},
+      });
+  while (!op_stack.empty()) {
+    const int64 op = op_stack.back();
+    VLOG(1) << "Popped " << op;
+    op_stack.pop_back();
+    auto op_it = state.op_tape.find(op);
+    if (op_it == state.op_tape.end()) {
+      // It is possible for ops to end up on the stack if they are unrelated to
+      // the target; we should just skip them.
+      continue;
+    }
+    auto trace = std::move(op_it->second);
+    state.op_tape.erase(op_it);
+    std::vector<Gradient*> out_gradients;
+    out_gradients.reserve(trace.output_tensor_info.size());
+    for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
+      const int64 id = trace.output_tensor_info[i].id;
+      auto grad_it = gradients.find(id);
+      if (grad_it == gradients.end()) {
+        auto func_name_it =
+            functions_accept_none_for_indices.find(trace.op_type);
+        if (func_name_it != functions_accept_none_for_indices.end() &&
+            func_name_it->second.find(i) != func_name_it->second.end()) {
+          out_gradients.push_back(nullptr);
+        } else {
+          out_gradients.push_back(
+              vspace.Zeros(trace.output_tensor_info[i].shape,
+                           trace.output_tensor_info[i].dtype));
+        }
+      } else {
+        out_gradients.push_back(vspace.AggregateGradients(grad_it->second));
+        if (sources_set.find(grad_it->first) == sources_set.end()) {
+          gradients.erase(grad_it);
+        }
+      }
+    }
+    std::vector<Gradient*> in_gradients;
+    Status s = vspace.CallBackwardFunction(trace.backward_function,
+                                           out_gradients, &in_gradients);
+    if (!s.ok()) {
+      VLOG(1) << "Gradient function failed.";
+      cleanup();
+      return s;
+    }
+    VLOG(1) << "Got " << in_gradients.size() << " in_gradients for "
+            << trace.input_tensor_id.size() << " sources";
+    for (int i = 0; i < in_gradients.size(); ++i) {
+      const int64 id = trace.input_tensor_id[i];
+      if (in_gradients[i] != nullptr) {
+        auto& unaggregated_grads = gradients[id];
+        unaggregated_grads.push_back(in_gradients[i]);
+        if (unaggregated_grads.size() > kMinAggregateCount) {
+          auto size_it = gradients_size.find(id);
+          int64 size;
+          if (size_it == gradients_size.end()) {
+            size = vspace.NumElements(unaggregated_grads[0]);
+            gradients_size.emplace(id, size);
+          } else {
+            size = size_it->second;
+          }
+          if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) {
+            Gradient* grad = vspace.AggregateGradients(unaggregated_grads);
+            unaggregated_grads.clear();
+            unaggregated_grads.push_back(grad);
+          }
+        }
+      }
+      auto usage_count_it = state.tensor_usage_counts.find(id);
+      if (usage_count_it == state.tensor_usage_counts.end()) {
+        VLOG(1) << "Tensor " << id << " not used";
+        continue;
+      }
+      usage_count_it->second--;
+      if (usage_count_it->second > 0) {
+        VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second;
+        continue;
+      }
+      auto tape_it = tensor_tape_.find(id);
+      if (tape_it == tensor_tape_.end()) {
+        VLOG(1) << "Tensor " << id
+                << " has no associated op. Deleting gradient";
+        auto grad_it = gradients.find(id);
+        if (grad_it != gradients.end()) {
+          for (auto g : grad_it->second) {
+            vspace.DeleteGradient(g);
+          }
+          gradients.erase(grad_it);
+        }
+        continue;
+      }
+      const int64 op_id = tape_it->second;
+      if (op_id == -1) {
+        VLOG(1) << "Tensor " << id << " is source";
+        continue;
+      }
+      auto missing_it = state.op_missing_tensor.find(op_id);
+      if (missing_it != state.op_missing_tensor.end()) {
+        missing_it->second--;
+        VLOG(1) << "Op " << op_id << " missing " << missing_it->second
+                << " output gradients";
+        if (missing_it->second == 0) {
+          op_stack.push_back(op_id);
+        }
+      }
+    }
+  }
+  CHECK(state.op_tape.empty());
+  result->reserve(sources.size());
+  for (auto is : id_sources) {
+    auto grad_it = gradients.find(is);
+    if (grad_it == gradients.end()) {
+      result->push_back(nullptr);
+    } else {
+      if (grad_it->second.size() == 1) {
+        result->push_back(grad_it->second[0]);
+      } else {
+        result->push_back(vspace.AggregateGradients(grad_it->second));
+      }
+      gradients.erase(grad_it);
+    }
+  }
+  VLOG(1) << "Final gradients size: " << gradients.size();
+  for (auto grad_pair : gradients) {
+    for (const auto& g : grad_pair.second) {
+      vspace.DeleteGradient(g);
+    }
+  }
+  return Status::OK();
+}
+
 }  // namespace eager
 }  // namespace tensorflow
 
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index a00a7615d7d..d67c3b18f7b 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -443,10 +443,13 @@ void TFE_DeleteContextCapsule(PyObject* context) {
   TF_DeleteStatus(status);
 }
 
+using GradientTape =
+    tensorflow::eager::GradientTape<PyObject, PyObject, PyObject>;
+
 typedef struct {
   PyObject_HEAD
       /* Type-specific fields go here. */
-      tensorflow::eager::GradientTape* tape;
+      GradientTape* tape;
 } TFE_Py_Tape;
 
 static void TFE_Py_Tape_Delete(PyObject* tape) {
@@ -481,7 +484,7 @@ PyObject* TFE_Py_NewTape() {
   TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
   if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
   TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
-  tape->tape = new tensorflow::eager::GradientTape();
+  tape->tape = new GradientTape();
   return reinterpret_cast<PyObject*>(tape);
 }
 
@@ -627,9 +630,8 @@ void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
 }
 
-// TODO(apassos): cache the attribute lookups as member variables and decref
-// them in the destructor.
-class PyVSpace : public tensorflow::eager::VSpace {
+class PyVSpace
+    : public tensorflow::eager::VSpace<PyObject, PyObject, PyObject> {
  public:
   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
 
@@ -661,7 +663,7 @@ class PyVSpace : public tensorflow::eager::VSpace {
     Py_XDECREF(ones_like_);
   }
 
-  tensorflow::int64 NumElements(void* tensor) const final {
+  tensorflow::int64 NumElements(PyObject* tensor) const final {
     PyObject* arglist =
         Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
     PyObject* result = PyEval_CallObject(num_elements_, arglist);
@@ -671,8 +673,8 @@ class PyVSpace : public tensorflow::eager::VSpace {
     return r;
   }
 
-  void* AggregateGradients(
-      tensorflow::gtl::ArraySlice<void*> gradient_tensors) const final {
+  PyObject* AggregateGradients(
+      tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
     PyObject* list = PyList_New(gradient_tensors.size());
     for (int i = 0; i < gradient_tensors.size(); ++i) {
       // Note: stealing a reference to the gradient tensors.
@@ -689,8 +691,8 @@ class PyVSpace : public tensorflow::eager::VSpace {
     return result;
   }
 
-  void* Zeros(tensorflow::TensorShape shape,
-              tensorflow::DataType dtype) const final {
+  PyObject* Zeros(tensorflow::TensorShape shape,
+                  tensorflow::DataType dtype) const final {
     PyObject* py_shape = PyTuple_New(shape.dims());
     for (int i = 0; i < shape.dims(); ++i) {
       PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
@@ -701,20 +703,20 @@ class PyVSpace : public tensorflow::eager::VSpace {
     Py_DECREF(arg_list);
     Py_DECREF(py_dtype);
     Py_DECREF(py_shape);
-    return reinterpret_cast<void*>(result);
+    return reinterpret_cast<PyObject*>(result);
   }
 
-  void* OnesLike(void* tensor) const final {
+  PyObject* OnesLike(PyObject* tensor) const final {
     PyObject* arg_list = Py_BuildValue("(O)", tensor);
     PyObject* result = PyEval_CallObject(ones_like_, arg_list);
     if (result == nullptr) {
       VLOG(1) << "Call to ones_like failed";
     }
     Py_DECREF(arg_list);
-    return reinterpret_cast<void*>(result);
+    return result;
   }
 
-  tensorflow::int64 TensorId(void* tensor) const final {
+  tensorflow::int64 TensorId(PyObject* tensor) const final {
     PyObject* py_tensor = reinterpret_cast<PyObject*>(tensor);
     PyObject* id_field = PyObject_GetAttrString(py_tensor, "_id");
     tensorflow::int64 id = MakeInt(id_field);
@@ -723,9 +725,9 @@ class PyVSpace : public tensorflow::eager::VSpace {
   }
 
   tensorflow::Status CallBackwardFunction(
-      void* backward_function,
-      tensorflow::gtl::ArraySlice<void*> output_gradients,
-      std::vector<void*>* result) const final {
+      PyObject* backward_function,
+      tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
+      std::vector<PyObject*>* result) const final {
     PyObject* grads = PyTuple_New(output_gradients.size());
     for (int i = 0; i < output_gradients.size(); ++i) {
       if (output_gradients[i] == nullptr) {
@@ -771,9 +773,7 @@ class PyVSpace : public tensorflow::eager::VSpace {
     return tensorflow::Status::OK();
   }
 
-  void DeleteTensor(void* tensor) const final {
-    Py_XDECREF(reinterpret_cast<PyObject*>(tensor));
-  }
+  void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
 
  private:
   PyObject* py_vspace_;
@@ -784,13 +784,13 @@ class PyVSpace : public tensorflow::eager::VSpace {
   PyObject* ones_like_;
 };
 
-std::vector<void*> MakeTensorList(PyObject* tensors) {
+std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
   if (seq == nullptr) {
     return {};
   }
   int len = PySequence_Fast_GET_SIZE(seq);
-  std::vector<void*> list;
+  std::vector<PyObject*> list;
   list.reserve(len);
   for (int i = 0; i < len; ++i) {
     list.push_back(PySequence_Fast_GET_ITEM(seq, i));
@@ -807,30 +807,30 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
     return nullptr;
   }
 
-  std::vector<void*> target_vec = MakeTensorList(target);
+  std::vector<PyObject*> target_vec = MakeTensorList(target);
   if (PyErr_Occurred()) {
     return nullptr;
   }
-  std::vector<void*> sources_vec = MakeTensorList(sources);
+  std::vector<PyObject*> sources_vec = MakeTensorList(sources);
   if (PyErr_Occurred()) {
     return nullptr;
   }
-  std::vector<void*> outgrad_vec;
+  std::vector<PyObject*> outgrad_vec;
   if (output_gradients != Py_None) {
     outgrad_vec = MakeTensorList(output_gradients);
     if (PyErr_Occurred()) {
       return nullptr;
     }
-    for (void* tensor : outgrad_vec) {
+    for (PyObject* tensor : outgrad_vec) {
       // Calling the backward function will eat a reference to the tensors in
       // outgrad_vec, so we need to increase their reference count.
-      Py_INCREF(reinterpret_cast<PyObject*>(tensor));
+      Py_INCREF(tensor);
     }
   }
   TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
-  std::vector<void*> result;
-  status->status = tape_obj->tape->Gradient(c_vspace, target_vec, sources_vec,
-                                            outgrad_vec, &result);
+  std::vector<PyObject*> result;
+  status->status = tape_obj->tape->ComputeGradient(
+      c_vspace, target_vec, sources_vec, outgrad_vec, &result);
   if (!status->status.ok()) {
     return nullptr;
   }

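A rough usage sketch of the templated tape above (illustrative only;
MyTensor, MyGradient, MyBackwardFn, and MyVSpace are hypothetical stand-ins
for a concrete specialization such as the PyObject one in this patch):

    #include <vector>
    #include "tensorflow/c/eager/tape.h"

    // Hypothetical client types; MyVSpace is assumed to derive from
    // tensorflow::eager::VSpace<MyTensor, MyGradient, MyBackwardFn>.
    using MyTape =
        tensorflow::eager::GradientTape<MyTensor, MyGradient, MyBackwardFn>;

    void ComputeDyDx(const MyVSpace& vspace, MyTensor* x, MyTensor* y) {
      MyTape tape;
      tape.Watch(vspace.TensorId(x));  // mark x as a source
      // ... run the forward computation, calling tape.RecordOperation() for
      // each differentiable op with its output tensors, input tensor ids,
      // and backward closure ...
      std::vector<MyGradient*> grads;
      TF_CHECK_OK(tape.ComputeGradient(vspace, /*target=*/{y}, /*sources=*/{x},
                                       /*output_gradients=*/{}, &grads));
      // grads[0] now holds dy/dx, or nullptr if y does not depend on x.
    }
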
From bcf2ce97591e0cf6b76148e64cf073dd122f41f6 Mon Sep 17 00:00:00 2001
From: Yifei Feng <yifeif@google.com>
Date: Thu, 9 Nov 2017 08:30:44 -0800
Subject: [PATCH 068/115] Fix typo in tensorflow/python/client/timeline.py

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/14386 from yifeif:yifeif-patch-2 8391d3b0369f823fc94ea75aef2df04c611a1671
PiperOrigin-RevId: 175161296
---
 tensorflow/python/client/timeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/client/timeline.py b/tensorflow/python/client/timeline.py
index f3ba4244cec..1e96ac5ed48 100644
--- a/tensorflow/python/client/timeline.py
+++ b/tensorflow/python/client/timeline.py
@@ -275,7 +275,7 @@ class _TensorTracker(object):
       name:  The name of the Tensor as a string.
       object_id:  Chrome Trace object identifier assigned for this Tensor.
       timestamp:  The creation timestamp of this event as a long integer.
-      pid:  Process identifier of the assicaiated device, as an integer.
+      pid:  Process identifier of the associated device, as an integer.
       allocator:  Name of the allocator used to create the Tensor.
       num_bytes:  Number of bytes allocated (long integer).
 

From f11999586914467f510de2fc3b33fac3c984e6d4 Mon Sep 17 00:00:00 2001
From: Michael Case <mikecase@google.com>
Date: Thu, 9 Nov 2017 08:46:31 -0800
Subject: [PATCH 069/115] Internal Change.

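Splits the opt-flag handling in configure.py: each user-provided
CC_OPT_FLAGS entry now gets its own "build:opt" line, the -march=native
host flags are written once rather than per flag, and a new
"with_default_optimizations" define (with a matching config_setting in
tensorflow/BUILD) records that the default optimization mode is in
effect. For a hypothetical CC_OPT_FLAGS of "-msse4.2 -O3", the generated
.bazelrc section would now read:

  build:opt --cxxopt=-msse4.2 --copt=-msse4.2
  build:opt --cxxopt=-O3 --copt=-O3
  build:opt --host_cxxopt=-march=native --host_copt=-march=native
  build:opt --define with_default_optimizations=true
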
PiperOrigin-RevId: 175163107
---
 configure.py     | 9 +++++----
 tensorflow/BUILD | 6 ++++++
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/configure.py b/configure.py
index 650541770af..e98367ef9fb 100644
--- a/configure.py
+++ b/configure.py
@@ -487,10 +487,11 @@ def set_cc_opt_flags(environ_cp):
   cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS',
                                                  question, default_cc_opt_flags)
   for opt in cc_opt_flags.split():
-    host_opt = '-march=native'  # It should be safe on the same build host.
-    write_to_bazelrc(
-        'build:opt --cxxopt=%s --copt=%s' % (opt, opt) +
-        ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
+    write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt))
+  host_opt = '-march=native'  # It should be safe on the same build host.
+  write_to_bazelrc(
+      'build:opt --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt))
+  write_to_bazelrc('build:opt --define with_default_optimizations=true')
 
 
 def set_tf_cuda_clang(environ_cp):
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 5a408db94e1..8d3d38b5a1f 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -178,6 +178,12 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
+config_setting(
+    name = "with_default_optimizations",
+    define_values = {"with_default_optimizations": "true"},
+    visibility = ["//visibility:public"],
+)
+
 config_setting(
     name = "with_gcp_support",
     define_values = {"with_gcp_support": "true"},

From a11d99f2ff4b3022f615d07b142b73571ff93b20 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <bsteiner@google.com>
Date: Thu, 9 Nov 2017 09:11:36 -0800
Subject: [PATCH 070/115] Implemented Processor<ShapeHandle>, which allows us
 to merge shapes of unknown rank with shapes of known rank. Made sure
 Processor<DimensionHandle>::Merge doesn't erase previously inferred
 dimensions.

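The shape half simply prefers whichever handle has known rank and keeps a
result whose rank is already known. The dimension half boils down to a
refinement rule: a dimension already known in the result is never clobbered
by an unknown one. A standalone sketch of that rule, mirroring the new
Processor<DimensionHandle>::RefineDim (illustrative; int64_t stands in for
tensorflow's int64):

  #include <cassert>
  #include <cstdint>

  // Known (non-negative) dims always win; among unknowns, the more
  // specific (more negative) symbolic id wins; a known result is never
  // overwritten.
  void RefineDim(int64_t dim, int64_t* result) {
    if (*result >= 0) {
      assert(*result == dim || dim < 0);  // known dims must agree
    } else if (dim >= 0) {
      *result = dim;
    } else if (dim < *result) {
      *result = dim;
    }
  }

  int main() {
    int64_t d = -1;     // unknown dimension
    RefineDim(7, &d);   // merging in a known dimension refines it...
    assert(d == 7);
    RefineDim(-1, &d);  // ...and a later unknown does not erase it
    assert(d == 7);
    return 0;
  }
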
PiperOrigin-RevId: 175166217
---
 .../core/grappler/costs/graph_properties.cc   | 103 +++++++-----------
 .../grappler/costs/graph_properties_test.cc   |   7 +-
 2 files changed, 45 insertions(+), 65 deletions(-)

diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 8654a2a3ed0..151455778a4 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -50,13 +50,9 @@ template <typename Handle>
 struct HandleToObject {};
 template <>
 struct HandleToObject<ShapeHandle> {
-  typedef TensorShapeProto Object;
+  typedef ShapeHandle Object;
 
-  static TensorShapeProto Unknown() {
-    TensorShapeProto result;
-    result.set_unknown_rank(true);
-    return result;
-  }
+  static ShapeHandle Unknown() { return ShapeHandle(); }
 };
 
 template <>
@@ -67,13 +63,24 @@ struct HandleToObject<DimensionHandle> {
 };
 
 template <typename Handle>
-struct Processor {
+struct Processor {};
+
+template <>
+struct Processor<ShapeHandle> {
   // Extract the shape or dim denoted by the handle.
-  void ExtractValue(Handle /*t1*/,
-                    typename HandleToObject<Handle>::Object* result) {}
+  void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
   // Merge the shapes or dims.
-  Status Merge(Handle /*t1*/, Handle /*t2*/,
-               typename HandleToObject<Handle>::Object* result) {
+  Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
+    if (InferenceContext::RankKnown(*result)) {
+      // The result was initialized in a previous merge to a shape of known
+      // rank, make sure we preserve that information.
+      return Status::OK();
+    }
+    if (InferenceContext::RankKnown(h1)) {
+      *result = h1;
+    } else {
+      *result = h2;
+    }
     return Status::OK();
   }
 };
@@ -101,24 +108,34 @@ struct Processor<DimensionHandle> {
 
     if (dim1 >= 0 && dim2 >= 0) {
       CHECK_EQ(dim1, dim2);
-      *result = dim1;
+      RefineDim(dim1, result);
     } else if (dim1 >= 0 && dim2 < 0) {
-      *result = dim1;
+      RefineDim(dim1, result);
     } else if (dim1 < 0 && dim2 >= 0) {
-      *result = dim2;
+      RefineDim(dim2, result);
     } else if (dim1 < -1) {
-      *result = dim1;
+      RefineDim(dim1, result);
     } else if (dim2 < -1) {
-      *result = dim2;
+      RefineDim(dim2, result);
     } else {
       CHECK_EQ(dim1, dim2);
       CHECK_EQ(-1, dim1);
-      *result = -1;
+      RefineDim(-1, result);
     }
     return Status::OK();
   }
 
  private:
+  void RefineDim(int64 dim, int64* result) {
+    if (*result >= 0) {
+      CHECK(*result == dim || dim < 0);
+    } else if (dim >= 0) {
+      *result = dim;
+    } else if (dim < *result) {
+      *result = dim;
+    }
+  }
+
   int64 counter = 2;
 };
 
@@ -354,18 +371,17 @@ class SymbolicShapeManager {
     return dims_.Merge(d1, d2);
   }
 
-  int64 Value(DimensionHandle d) { return dims_.GetMergedValue(d); }
-
   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
-                          InferenceContext* ctx,
                           OpInfo::TensorProperties* properties) {
     properties->set_dtype(type);
-    if (!ctx->RankKnown(shape)) {
+    ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
+    if (!InferenceContext::RankKnown(actual_shape)) {
       properties->mutable_shape()->set_unknown_rank(true);
     } else {
-      for (int j = 0; j < ctx->Rank(shape); ++j) {
-        shape_inference::DimensionHandle dim = ctx->Dim(shape, j);
-        int64 d = Value(dim);
+      for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
+        shape_inference::DimensionHandle dim =
+            InferenceContext::DimKnownRank(actual_shape, j);
+        int64 d = dims_.GetMergedValue(dim);
         properties->mutable_shape()->add_dim()->set_size(d);
       }
     }
@@ -477,41 +493,6 @@ Status GraphProperties::InferStatically() {
         }
       }
     }
-
-    // Infer output shape for Restore op.
-    if (node->op_def().name() == "Restore" ||
-        node->op_def().name() == "RestoreV2" ||
-        node->op_def().name() == "RestoreSlice") {
-      auto ctx = shape_refiner.GetContext(node);
-      for (const Edge* out_edge : node->out_edges()) {
-        const Node* output = out_edge->dst();
-        int output_idx = out_edge->src_output();
-        if (output_idx < 0) {
-          continue;
-        }
-        if (!ctx->FullyDefined(ctx->output(output_idx)) &&
-            output->op_def().name() == "Assign") {
-          if (!output->attrs().Find("validate_shape") ||
-              !output->attrs().Find("validate_shape")->b()) {
-            continue;
-          }
-          auto output_ctx = shape_refiner.GetContext(output);
-          if (output_ctx->FullyDefined(output_ctx->output(0))) {
-            ctx->set_output(output_idx, output_ctx->output(0));
-            output_ctx->MergeInput(1, output_ctx->output(0));
-          } else {
-            const Node* var;
-            TF_CHECK_OK(node->input_node(0, &var));
-            if (node->IsVariable()) {
-              auto var_ctx = shape_refiner.GetContext(var);
-              CHECK(var_ctx->FullyDefined(var_ctx->output(0)));
-              ctx->set_output(output_idx, var_ctx->output(0));
-              output_ctx->MergeInput(1, var_ctx->output(0));
-            }
-          }
-        }
-      }
-    }
   }
 
   // Propagate the initial shapes of Enter nodes manually (the Enter shape
@@ -691,7 +672,7 @@ Status GraphProperties::InferStatically() {
       input_properties.resize(ctx->num_inputs());
       for (int i = 0; i < ctx->num_inputs(); ++i) {
         shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i),
-                                         ctx, &input_properties[i]);
+                                         &input_properties[i]);
       }
       for (const auto& edge : node->in_edges()) {
         if (!edge->src()->IsConstant()) {
@@ -718,7 +699,7 @@ Status GraphProperties::InferStatically() {
       output_properties.resize(ctx->num_outputs());
       for (int i = 0; i < ctx->num_outputs(); ++i) {
         shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i),
-                                         ctx, &output_properties[i]);
+                                         &output_properties[i]);
       }
     }
   }
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index acd0b598aef..f785f627e12 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -296,10 +296,9 @@ TEST_F(GraphPropertiesTest, Queues) {
   ASSERT_EQ(1, props2.size());
   EXPECT_EQ("float: [3,7]", PropToString(props2[0]));
 
-  // The dequeue3 op shape is unknown.
   const auto props3 = properties.GetOutputProperties("Dequeue3");
   ASSERT_EQ(1, props3.size());
-  EXPECT_EQ("float: ?", PropToString(props3[0]));
+  EXPECT_EQ("float: [3,7]", PropToString(props3[0]));
 
   // The dequeue3 op shape is unknown. The square2 op shape is known. Verify
   // that we merge the 2 properly to determine the shape of the data coming out
@@ -678,8 +677,8 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
 
 TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  Output var =
-      ops::Variable(s.WithOpName("var"), TensorShape(), DataType::DT_FLOAT);
+  Output var = ops::Variable(s.WithOpName("var"), PartialTensorShape(),
+                             DataType::DT_FLOAT);
   Output var2 = ops::Variable(s.WithOpName("var2"), TensorShape({128, 256}),
                               DataType::DT_FLOAT);
   Output filename =

From 2598f7b6b3770cafb4b047740bac6d53e33ea2f7 Mon Sep 17 00:00:00 2001
From: Sanjoy Das <sanjoy@google.com>
Date: Thu, 9 Nov 2017 09:24:00 -0800
Subject: [PATCH 071/115] Explicitly disable vectorization in the LLVM IR
 generated for Dot.

In practice this does not seem to make a difference, but I did it
anyway for completeness.

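With both knobs enabled, the back-branch of an emitted loop now carries a
self-referential !llvm.loop metadata node roughly of the following form
(illustrative IR; metadata numbering arbitrary):

  br label %loop.header, !llvm.loop !0

  !0 = distinct !{!0, !1, !2}
  !1 = !{!"llvm.loop.unroll.disable"}
  !2 = !{!"llvm.loop.vectorize.enable", i1 false}
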
PiperOrigin-RevId: 175167706
---
 .../service/llvm_ir/kernel_support_library.cc |  4 +-
 .../service/llvm_ir/kernel_support_library.h  |  8 ++-
 .../compiler/xla/service/llvm_ir/llvm_loop.cc | 68 +++++++++++++------
 .../compiler/xla/service/llvm_ir/llvm_loop.h  | 29 +++++---
 4 files changed, 75 insertions(+), 34 deletions(-)

diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index 123a327d4db..29cc0f81bd2 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -40,7 +40,9 @@ void KernelSupportLibrary::For(
         });
   } else {
     std::unique_ptr<llvm_ir::ForLoop> loop = llvm_ir::ForLoop::EmitForLoop(
-        name, start, end, step, ir_builder_, prevent_unrolling_);
+        name, start, end, step, ir_builder_,
+        /*prevent_unrolling=*/prevent_unrolling_,
+        /*prevent_vectorization=*/prevent_vectorization_);
     ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back());
     for_body_generator(loop->GetIndVarValue(),
                        /*is_first_iteration=*/ir_builder_->CreateICmpEQ(
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 25aa2291a66..9bafb7b5774 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -33,8 +33,11 @@ class KernelSupportLibrary {
   // If `prevent_unrolling` is true then unrolling is explicitly disabled on
   // every loop generated by this instance of KernelSupportLibrary.
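+  // Likewise, if `prevent_vectorization` is true then vectorization is
+  // explicitly disabled on every such loop.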
   explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
-                                bool prevent_unrolling = true)
-      : ir_builder_(ir_builder), prevent_unrolling_(prevent_unrolling) {}
+                                bool prevent_unrolling = true,
+                                bool prevent_vectorization = true)
+      : ir_builder_(ir_builder),
+        prevent_unrolling_(prevent_unrolling),
+        prevent_vectorization_(prevent_vectorization) {}
 
   // Generates the following control flow structure:
   //
@@ -118,6 +121,7 @@ class KernelSupportLibrary {
  private:
   llvm::IRBuilder<>* ir_builder_;
   bool prevent_unrolling_;
+  bool prevent_vectorization_;
 };
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index 83d35cb9efc..7b227ce2941 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -34,21 +34,24 @@ namespace llvm_ir {
 
 ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
                  llvm::Value* start_index, llvm::Value* end_index,
-                 llvm::Value* step, bool prevent_unrolling)
+                 llvm::Value* step, bool prevent_unrolling,
+                 bool prevent_vectorization)
     : prefix_(prefix.ToString()),
       suffix_(suffix.ToString()),
       start_index_(start_index),
       end_index_(end_index),
       step_(step),
       insert_before_bb_(nullptr),
-      prevent_unrolling_(prevent_unrolling) {}
+      prevent_unrolling_(prevent_unrolling),
+      prevent_vectorization_(prevent_vectorization) {}
 
 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
     tensorflow::StringPiece prefix, llvm::Value* start_index,
     llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
-    bool prevent_unrolling) {
-  std::unique_ptr<ForLoop> loop(new ForLoop(
-      prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling));
+    bool prevent_unrolling, bool prevent_vectorization) {
+  std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
+                                            end_index, step, prevent_unrolling,
+                                            prevent_vectorization));
   loop->Emit(ir_builder);
   return loop;
 }
@@ -127,14 +130,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
   ir_builder->CreateStore(indvar_inc, indvar_address);
   llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_);
 
-  if (prevent_unrolling_) {
-    const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
-    llvm::LLVMContext* ctx = &back_branch->getContext();
-
+  std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(ir_builder);
+  if (!loop_metadata.empty()) {
+    llvm::LLVMContext* ctx = &start_index_->getContext();
     auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
-    auto no_unroll_node = llvm::MDNode::get(
-        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)});
-    auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node});
+    loop_metadata.insert(loop_metadata.begin(), temp_node.get());
+    auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
     loop_id->replaceOperandWith(0, loop_id);
     back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
   }
@@ -143,6 +144,27 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
   ir_builder->SetInsertPoint(exit_bb_);
 }
 
+std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(
+    llvm::IRBuilder<>* ir_builder) {
+  const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
+  const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
+  llvm::LLVMContext* ctx = &start_index_->getContext();
+
+  std::vector<llvm::Metadata*> result;
+  if (prevent_unrolling_) {
+    result.push_back(llvm::MDNode::get(
+        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
+  }
+
+  if (prevent_vectorization_) {
+    result.push_back(llvm::MDNode::get(
+        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
+               llvm::ConstantAsMetadata::get(ir_builder->getFalse())}));
+  }
+
+  return result;
+}
+
 string ForLoop::GetQualifiedName(tensorflow::StringPiece name) {
   return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
 }
@@ -156,23 +178,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name,
 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
                                               llvm::Value* start_index,
                                               llvm::Value* end_index,
-                                              bool prevent_unrolling) {
+                                              bool prevent_unrolling,
+                                              bool prevent_vectorization) {
   return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1),
-                 prevent_unrolling);
+                 prevent_unrolling, prevent_vectorization);
 }
 
 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
                                               llvm::Value* start_index,
                                               llvm::Value* end_index,
                                               llvm::Value* stride,
-                                              bool prevent_unrolling) {
+                                              bool prevent_unrolling,
+                                              bool prevent_vectorization) {
   if (inner_loop_body_bb_ != nullptr) {
     // Create this loop inside the previous one.
     ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
   }
   std::unique_ptr<ForLoop> loop(new ForLoop(
       /*prefix=*/name_, suffix, start_index, end_index, stride,
-      prevent_unrolling));
+      prevent_unrolling, prevent_vectorization));
   loop->Emit(ir_builder_);
 
   if (outer_loop_preheader_bb_ == nullptr) {
@@ -191,20 +215,24 @@ std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
                                               int64 end_index,
                                               tensorflow::StringPiece suffix,
-                                              bool prevent_unrolling) {
+                                              bool prevent_unrolling,
+                                              bool prevent_vectorization) {
   CHECK_LE(start_index, end_index);
   return AddLoop(suffix, ir_builder_->getInt64(start_index),
-                 ir_builder_->getInt64(end_index), prevent_unrolling);
+                 ir_builder_->getInt64(end_index), prevent_unrolling,
+                 prevent_vectorization);
 }
 
 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
                                               int64 end_index, int64 stride,
                                               tensorflow::StringPiece suffix,
-                                              bool prevent_unrolling) {
+                                              bool prevent_unrolling,
+                                              bool prevent_vectorization) {
   CHECK_LE(start_index, end_index);
   return AddLoop(suffix, ir_builder_->getInt64(start_index),
                  ir_builder_->getInt64(end_index),
-                 ir_builder_->getInt64(stride), prevent_unrolling);
+                 ir_builder_->getInt64(stride), prevent_unrolling,
+                 prevent_vectorization);
 }
 
 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
index 90f7c7df9e2..20069ce5a28 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -71,12 +71,10 @@ class ForLoop {
   //
   // If `prevent_unrolling` is true then emit metadata that directs LLVM to not
   // unroll the generated loop.
-  static std::unique_ptr<ForLoop> EmitForLoop(tensorflow::StringPiece prefix,
-                                              llvm::Value* start_index,
-                                              llvm::Value* end_index,
-                                              llvm::Value* step,
-                                              llvm::IRBuilder<>* ir_builder,
-                                              bool prevent_unrolling = false);
+  static std::unique_ptr<ForLoop> EmitForLoop(
+      tensorflow::StringPiece prefix, llvm::Value* start_index,
+      llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder,
+      bool prevent_unrolling = false, bool prevent_vectorization = false);
 
   // The names of the blocks follow LLVM's conventions. Control flow amongst the
   // blocks for the example C code looks like:
@@ -130,7 +128,7 @@ class ForLoop {
 
   ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix,
           llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step,
-          bool prevent_unrolling);
+          bool prevent_unrolling, bool prevent_vectorization);
 
   // Emit the loop at the insert point of the builder.
   void Emit(llvm::IRBuilder<>* ir_builder);
@@ -142,6 +140,10 @@ class ForLoop {
   // they are set.
   string GetQualifiedName(tensorflow::StringPiece name);
 
+  // Return a list of metadata nodes that should be associated with the
+  // llvm::Loop for this `ForLoop`.
+  std::vector<llvm::Metadata*> GetLoopMetadata(llvm::IRBuilder<>* ir_builder);
+
   string prefix_;
   string suffix_;
   llvm::Value* start_index_;
@@ -160,6 +162,7 @@ class ForLoop {
   llvm::BasicBlock* exit_bb_;
   llvm::Value* indvar_;
   bool prevent_unrolling_;
+  bool prevent_vectorization_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(ForLoop);
 };
@@ -185,24 +188,28 @@ class ForLoopNest {
   std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
                                    llvm::Value* start_index,
                                    llvm::Value* end_index, llvm::Value* stride,
-                                   bool prevent_unrolling = false);
+                                   bool prevent_unrolling = false,
+                                   bool prevent_vectorization = false);
 
   // Like the above, except that it defaults to a stride of one.
   std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
                                    llvm::Value* start_index,
                                    llvm::Value* end_index,
-                                   bool prevent_unrolling = false);
+                                   bool prevent_unrolling = false,
+                                   bool prevent_vectorization = false);
 
   // A convenient wrapper of the other flavor of AddLoop. The given start and
   // end index are constant.
   std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
                                    int64 stride, tensorflow::StringPiece suffix,
-                                   bool prevent_unrolling = false);
+                                   bool prevent_unrolling = false,
+                                   bool prevent_vectorization = false);
 
   // Like the above, except that it defaults to a stride of one.
   std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
                                    tensorflow::StringPiece suffix,
-                                   bool prevent_unrolling = false);
+                                   bool prevent_unrolling = false,
+                                   bool prevent_vectorization = false);
 
   // Add loops to iterate through the indices within the specified
   // shape. The returned index collects the induction variables of the

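For context on the change above: `GetLoopMetadata` now collects the per-loop
directives, and `Emit` splices them into a self-referential loop ID attached to
the loop's back branch (the `replaceOperandWith(0, loop_id)` call closes the
cycle). With both `prevent_unrolling` and `prevent_vectorization` set, the
printed IR looks roughly like this (metadata numbering illustrative):

    br label %loop.header, !llvm.loop !0
    ...
    !0 = distinct !{!0, !1, !2}
    !1 = !{!"llvm.loop.unroll.disable"}
    !2 = !{!"llvm.loop.vectorize.enable", i1 false}
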
From 86f723beca5e651af6f703a8f5720d0f038ae3f1 Mon Sep 17 00:00:00 2001
From: Michael Case <mikecase@google.com>
Date: Thu, 9 Nov 2017 09:25:54 -0800
Subject: [PATCH 072/115] Internal Change.

PiperOrigin-RevId: 175167946
---
 tensorflow/BUILD | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 8d3d38b5a1f..8cb7edcc502 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -355,7 +355,7 @@ config_setting(
     visibility = ["//visibility:public"],
 )
 
-# Make a dummy rule that we can chaqnge "default" in select statements to.
+# Make a dummy rule that we can change "default" in select statements to.
 # to disable dependencies in copybara.
 config_setting(
     name = "dummy_disabled_internal",

From 09f99427f96d96393e71a4bda378493e0a6817de Mon Sep 17 00:00:00 2001
From: Jianwei Xie <xiejw@google.com>
Date: Thu, 9 Nov 2017 10:12:58 -0800
Subject: [PATCH 073/115] Adds explicit docstring about TF version in
 examples.

PiperOrigin-RevId: 175174326
---
 tensorflow/examples/learn/iris.py                 | 5 ++++-
 tensorflow/examples/learn/wide_n_deep_tutorial.py | 5 ++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/tensorflow/examples/learn/iris.py b/tensorflow/examples/learn/iris.py
index 0a50b3ba87d..03e60972aa6 100644
--- a/tensorflow/examples/learn/iris.py
+++ b/tensorflow/examples/learn/iris.py
@@ -11,7 +11,10 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
-"""Example of DNNClassifier for Iris plant dataset."""
+"""Example of DNNClassifier for Iris plant dataset.
+
+This example uses APIs in TensorFlow 1.4 or above.
+"""
 
 from __future__ import absolute_import
 from __future__ import division
diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py
index e447b3e24e7..072353392a9 100644
--- a/tensorflow/examples/learn/wide_n_deep_tutorial.py
+++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Example code for TensorFlow Wide & Deep Tutorial using TF.Learn API."""
+"""Example code for TensorFlow Wide & Deep Tutorial using TF High Level API.
+
+This example uses APIs in TensorFlow 1.4 or above.
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

From 17532a3c5fd671a59002fac83c92344c451f9936 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 10:27:15 -0800
Subject: [PATCH 074/115] Supports multi-dimensional logits and labels in
 multi_label head and some cleanup.

PiperOrigin-RevId: 175176635
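
Concretely, the new contract is that `logits` and `labels` carry shape
`[D0, D1, ... DN, n_classes]` while weights carry `[D0, D1, ... DN]` (or with a
trailing 1). A minimal sketch mirroring the tests added below, assuming the
`tf.contrib.estimator.multi_label_head` export and the internal `create_loss`
method those tests exercise:

    import numpy as np
    import tensorflow as tf

    head = tf.contrib.estimator.multi_label_head(
        n_classes=3, weight_column='weights')

    # Shapes [2, 2, 3] for logits/labels and [2, 2] for weights, i.e.
    # D0 = D1 = 2 with n_classes = 3.
    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]],
                      dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)

    # Sigmoid cross-entropy is averaged over the 3 classes, then the
    # weighted sum is taken over the two leading dimensions.
    weighted_sum_loss, example_weight_sum, _ = head.create_loss(
        features={'weights': weights},
        mode=tf.estimator.ModeKeys.TRAIN,
        logits=logits,
        labels=labels)
    # Inside a session, weighted_sum_loss evaluates to ~39.667
    # (= 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) and example_weight_sum to 7.0,
    # per the arithmetic spelled out in the tests below.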
---
 .../estimator/python/estimator/head.py        | 143 ++++++++----
 .../estimator/python/estimator/head_test.py   | 206 +++++++++++++++++-
 tensorflow/python/estimator/canned/head.py    | 131 +++++------
 .../python/estimator/canned/head_test.py      |  10 +-
 4 files changed, 379 insertions(+), 111 deletions(-)

diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index e344ee3c3ea..a9311a20f12 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import lookup_ops
 from tensorflow.python.ops import math_ops
@@ -48,7 +49,20 @@ def multi_class_head(n_classes,
 
   Uses `sparse_softmax_cross_entropy` loss.
 
-  This head expects to be fed integer labels specifying the class index.
+  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.
+  In many applications, the shape is `[batch_size, n_classes]`.
+
+  `labels` must be a dense `Tensor` with shape matching `logits`, namely
+  `[D0, D1, ... DN, 1]`. If `label_vocabulary` is given, `labels` must be a
+  string `Tensor` with values from the vocabulary. If it is not given,
+  `labels` must be an integer `Tensor` specifying the class index.
+
+  If `weight_column` is specified, weights must be of shape
+  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
+
+  The loss is the weighted sum over the input dimensions. Namely, if the input
+  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
+  `batch_size`.
 
   Args:
     n_classes: Number of classes, must be greater than 2 (for 2 classes, use
@@ -57,11 +71,11 @@ def multi_class_head(n_classes,
       `tf.feature_column.numeric_column` defining feature column representing
       weights. It is used to down weight or boost examples during training. It
       will be multiplied by the loss of the example.
-    label_vocabulary: A list of strings represents possible label values. If it
-      is not given, that means labels are already encoded as integer within
-      [0, n_classes). If given, labels must be string type and have any value in
-      `label_vocabulary`. Also there will be errors if vocabulary is not
-      provided and labels are string.
+    label_vocabulary: A list or tuple of strings representing possible label
+      values. If it is not given, that means labels are already encoded as an
+      integer within [0, n_classes). If given, labels must be of string type and
+      have any value in `label_vocabulary`. Note that errors will be raised if
+      `label_vocabulary` is not provided but labels are strings.
     name: name of the head. If provided, summary and metrics keys will be
       suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
 
@@ -84,7 +98,20 @@ def binary_classification_head(
 
   This head uses `sigmoid_cross_entropy_with_logits` loss.
 
-  This head expects to be fed float labels of shape `(batch_size, 1)`.
+  The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
+  In many applications, the shape is `[batch_size, 1]`.
+
+  `labels` must be a dense `Tensor` with shape matching `logits`, namely
+  `[D0, D1, ... DN, 1]`. If `label_vocabulary` is given, `labels` must be a
+  string `Tensor` with values from the vocabulary. If it is not given,
+  `labels` must be a float `Tensor` with values in the interval `[0, 1]`.
+
+  If `weight_column` is specified, weights must be of shape
+  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
+
+  The loss is the weighted sum over the input dimensions. Namely, if the input
+  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
+  `batch_size`.
 
   Args:
     weight_column: A string or a `_NumericColumn` created by
@@ -96,11 +123,11 @@ def binary_classification_head(
       generated for each threshold value. This threshold is applied to the
       logistic values to determine the binary classification (i.e., above the
       threshold is `true`, below is `false`.
-    label_vocabulary: A list of strings represents possible label values. If it
-      is not given, that means labels are already encoded within [0, 1]. If
-      given, labels must be string type and have any value in
-      `label_vocabulary`. Also there will be errors if vocabulary is not
-      provided and labels are string.
+    label_vocabulary: A list or tuple of strings representing possible label
+      values. If it is not given, labels must be float with values within
+      [0, 1]. If given, labels must be string type and have any value in
+      `label_vocabulary`. Note that errors will be raised if `label_vocabulary`
+      is not provided but labels are strings.
     name: name of the head. If provided, summary and metrics keys will be
       suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
 
@@ -120,9 +147,22 @@ def binary_classification_head(
 def regression_head(weight_column=None,
                     label_dimension=1,
                     name=None):
-  """Creates a `_Head` for regression using the mean squared loss.
+  """Creates a `_Head` for regression using the `mean_squared_error` loss.
 
-  Uses `mean_squared_error` loss.
+  The loss is the weighted sum over all input dimensions. Namely, if the input
+  labels have shape `[batch_size, label_dimension]`, the loss is the weighted
+  sum over both `batch_size` and `label_dimension`.
+
+  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
+  In many applications, the shape is `[batch_size, label_dimension]`.
+
+  The `labels` shape must match `logits`, namely
+  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
+  `[D0, D1, ... DN]` is also supported.
+
+  If `weight_column` is specified, weights must be of shape
+  `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
+  `[D0, D1, ... DN, label_dimension]`.
 
   Args:
     weight_column: A string or a `_NumericColumn` created by
@@ -156,15 +196,29 @@ def multi_label_head(n_classes,
   or more associated labels, from a discrete set. This is distinct from
   `multi_class_head` which has exactly one label per example.
 
-  Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a
-  multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer
-  `SparseTensor` of class indices.
+  Uses `sigmoid_cross_entropy` loss averaged over classes and weighted sum over
+  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,
+  the loss is the average over `n_classes` and the weighted sum over
+  `batch_size`.
+
+  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
+  applications, the shape is `[batch_size, n_classes]`.
+
+  Labels can be:
+  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
+  * An integer `SparseTensor` of class indices. The `dense_shape` must be
+    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
+  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
+    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.
+
+  If `weight_column` is specified, weights must be of shape
+  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.
 
   Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
   `(labels, logits, features)` as arguments and returns unreduced loss with
-  shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape
-  `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the
-  input labels before passing them to `loss_fn`.
+  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
+  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
+  `label_vocabulary` to the input labels before passing them to `loss_fn`.
 
   Args:
     n_classes: Number of classes, must be greater than 1 (for 1 class, use
@@ -191,7 +245,7 @@ def multi_label_head(n_classes,
     An instance of `_Head` for multi-label classification.
 
   Raises:
-    ValueError: if `n_classes` or `thresholds` is invalid.
+    ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid.
   """
   thresholds = tuple(thresholds) if thresholds else tuple()
   if n_classes is None or n_classes < 2:
@@ -259,26 +313,36 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
             indices=labels.indices,
             values=label_ids_values,
             dense_shape=labels.dense_shape)
+        return math_ops.to_int64(
+            sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
       else:
-        label_ids = labels
-      return math_ops.to_int64(
-          sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
-    msg = ('labels shape must be [batch_size, {}]. '
-           'Given: ').format(self._n_classes)
-    labels_shape = array_ops.shape(labels)
-    check_rank_op = control_flow_ops.Assert(
-        math_ops.equal(array_ops.rank(labels), 2),
-        data=[msg, labels_shape])
-    check_label_dim = control_flow_ops.Assert(
-        math_ops.equal(labels_shape[-1], self._n_classes),
-        data=[msg, labels_shape])
-    with ops.control_dependencies([check_rank_op, check_label_dim]):
-      return array_ops.identity(labels)
+        err_msg = (
+            r'labels must be an integer SparseTensor with values in '
+            r'[0, {})'.format(self._n_classes))
+        assert_int = check_ops.assert_integer(
+            labels.values, message=err_msg)
+        assert_less = check_ops.assert_less(
+            labels.values,
+            ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
+            message=err_msg)
+        assert_greater = check_ops.assert_non_negative(
+            labels.values, message=err_msg)
+        with ops.control_dependencies(
+            [assert_int, assert_less, assert_greater]):
+          return math_ops.to_int64(
+              sparse_ops.sparse_to_indicator(labels, self._n_classes))
+    err_msg = (
+        r'labels must be an integer indicator Tensor with values in [0, 1]')
+    return head_lib._assert_range(labels, 2, message=err_msg)  # pylint:disable=protected-access,
 
   def create_loss(self, features, mode, logits, labels):
     """See `Head`."""
     del mode  # Unused for this head.
+    logits = ops.convert_to_tensor(logits)
     processed_labels = self._process_labels(labels)
+    processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
+        labels=processed_labels, logits=logits,
+        expected_labels_dimension=self.logits_dimension)
     if self._loss_fn:
       unweighted_loss = _call_loss_fn(
           loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
@@ -290,7 +354,8 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
       # Averages loss over classes.
       unweighted_loss = math_ops.reduce_mean(
           unweighted_loss, axis=-1, keep_dims=True)
-    weights = head_lib._weights(features, self._weight_column)  # pylint:disable=protected-access,
+    weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access,
+        features=features, weight_column=self._weight_column, logits=logits)
     weighted_sum_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     # _weights() can return 1.
@@ -305,7 +370,7 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
       self, features, mode, logits, labels=None, train_op_fn=None):
     """See `Head`."""
     with ops.name_scope(self._name, 'head'):
-      logits = head_lib._check_logits(logits, self.logits_dimension)  # pylint:disable=protected-access
+      logits = head_lib._check_logits_final_dim(logits, self.logits_dimension)  # pylint:disable=protected-access
 
       # Predict.
       pred_keys = prediction_keys.PredictionKeys
@@ -335,6 +400,8 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
 
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
+        weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access,
+            features=features, weight_column=self._weight_column, logits=logits)
         return model_fn.EstimatorSpec(
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
@@ -342,7 +409,7 @@ class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
             eval_metric_ops=self._eval_metric_ops(
                 labels=processed_labels,
                 probabilities=probabilities,
-                weights=head_lib._weights(features, self._weight_column),  # pylint:disable=protected-access,
+                weights=weights,
                 weighted_sum_loss=weighted_sum_loss,
                 example_weight_sum=example_weight_sum))
 
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index fd8c53f6a94..d1cf9090048 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -316,13 +316,14 @@ class MultiLabelHead(test.TestCase):
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
-          r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'):
+          r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'):
         actual_weighted_sum_loss.eval({
             labels_placeholder: np.array([[1], [1]], dtype=np.int64)
         })
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
-          r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'):
+          r'labels shape must be \[D0, D1, ... DN, 2\]\..*'
+          r'\[Received shape: \] \[2\]'):
         actual_weighted_sum_loss.eval({
             labels_placeholder: np.array([1, 1], dtype=np.int64)
         })
@@ -387,9 +388,11 @@ class MultiLabelHead(test.TestCase):
           logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
           labels=None)
 
-  def _test_eval(self, head, logits, labels, expected_loss, expected_metrics):
+  def _test_eval(
+      self, head, logits, labels, expected_loss, expected_metrics,
+      features=None):
     spec = head.create_estimator_spec(
-        features={'x': np.array(((42,),), dtype=np.int32)},
+        features=features or {},
         mode=model_fn.ModeKeys.EVAL,
         logits=logits,
         labels=labels)
@@ -655,6 +658,54 @@ class MultiLabelHead(test.TestCase):
           labels=None,
           train_op_fn=_no_op_train_fn)
 
+  def test_train_invalid_indicator_labels(self):
+    head = head_lib.multi_label_head(n_classes=2)
+    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+    # The value 2 is outside the allowed range.
+    labels = np.array([[2, 0], [1, 1]], dtype=np.int64)
+    def _train_op_fn(loss):
+      del loss
+      return control_flow_ops.no_op()
+
+    spec = head.create_estimator_spec(
+        features={},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        train_op_fn=_train_op_fn)
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          r'labels must be an integer indicator Tensor with values in '
+          r'\[0, 1\]'):
+        sess.run(spec.loss)
+
+  def test_train_invalid_sparse_labels(self):
+    head = head_lib.multi_label_head(n_classes=2)
+    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
+    # The value 2 is outside the allowed range.
+    labels = sparse_tensor.SparseTensor(
+        values=[2, 0, 1],
+        indices=[[0, 0], [1, 0], [1, 1]],
+        dense_shape=[2, 2])
+    def _train_op_fn(loss):
+      del loss
+      return control_flow_ops.no_op()
+
+    spec = head.create_estimator_spec(
+        features={},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        train_op_fn=_train_op_fn)
+    with self.test_session() as sess:
+      _initialize_variables(self, spec.scaffold)
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          r'labels must be an integer SparseTensor with values in \[0, 2\)'):
+        sess.run(spec.loss)
+
   def _test_train(self, head, logits, labels, expected_loss):
     expected_train_result = 'my_train_op'
     def _train_op_fn(loss):
@@ -791,6 +842,153 @@ class MultiLabelHead(test.TestCase):
           metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
       }, summary_str, tol)
 
+  def test_multi_dim_weighted_train_create_loss(self):
+    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
+    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+    labels = np.array([[[1, 0, 0], [1, 0, 0]],
+                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
+    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
+    #      = [[20/3, 10/3], [4, 8]]
+    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
+    expected_weighted_sum_loss = 39.6667
+    expected_example_weight_sum = np.sum(weights)
+    actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss(
+        features={'weights': weights},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels)
+    atol = 1.e-3
+    with self.test_session():
+      _initialize_variables(self, monitored_session.Scaffold())
+      self.assertAllClose(
+          expected_weighted_sum_loss, actual_weighted_sum_loss.eval(),
+          atol=atol)
+      self.assertAllClose(
+          expected_example_weight_sum, actual_example_weight_sum.eval(),
+          atol=atol)
+
+  def test_multi_dim_weighted_train(self):
+    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
+    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+    labels = np.array([[[1, 0, 0], [1, 0, 0]],
+                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
+    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
+    #      = [[20/3, 10/3], [4, 8]]
+    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
+    expected_loss = 39.6667
+    expected_train_result = 'my_train_op'
+    def _train_op_fn(loss):
+      return string_ops.string_join(
+          [constant_op.constant(expected_train_result),
+           string_ops.as_string(loss, precision=3)])
+
+    spec = head.create_estimator_spec(
+        features={'weights': weights},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        train_op_fn=_train_op_fn)
+
+    atol = 1.e-3
+    with self.test_session() as sess:
+      _initialize_variables(self, monitored_session.Scaffold())
+      loss, train_result = sess.run((spec.loss, spec.train_op))
+      self.assertAllClose(expected_loss, loss, atol=atol)
+      self.assertEqual(
+          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
+          train_result)
+
+  def test_multi_dim_weights_wrong_inner_dim(self):
+    """Logits and labels of shape [2, 2, 3], weights [2, 1]."""
+    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+    labels = np.array([[[1, 0, 0], [1, 0, 0]],
+                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+    weights = np.array([[1.], [2.]], dtype=np.float32)
+    def _train_op_fn(loss):
+      del loss
+      return control_flow_ops.no_op()
+
+    spec = head.create_estimator_spec(
+        features={'weights': weights},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        train_op_fn=_train_op_fn)
+    with self.test_session():
+      _initialize_variables(self, monitored_session.Scaffold())
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'):
+        spec.loss.eval()
+
+  def test_multi_dim_weights_wrong_outer_dim(self):
+    """Logits and labels of shape [2, 2, 3], weights [2, 2, 3]."""
+    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+    labels = np.array([[[1, 0, 0], [1, 0, 0]],
+                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+    weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]],
+                        [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32)
+    weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
+    def _train_op_fn(loss):
+      del loss
+      return control_flow_ops.no_op()
+
+    spec = head.create_estimator_spec(
+        features={'weights': weights_placeholder},
+        mode=model_fn.ModeKeys.TRAIN,
+        logits=logits,
+        labels=labels,
+        train_op_fn=_train_op_fn)
+    with self.test_session():
+      _initialize_variables(self, monitored_session.Scaffold())
+      with self.assertRaisesRegexp(
+          errors.InvalidArgumentError,
+          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'):
+        spec.loss.eval({weights_placeholder: weights})
+
+  def test_multi_dim_weighted_eval(self):
+    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
+    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')
+
+    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
+                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
+    labels = np.array([[[1, 0, 0], [1, 0, 0]],
+                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
+    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
+    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
+    #      = [[20/3, 10/3], [4, 8]]
+    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
+    expected_loss = 39.6667
+    keys = metric_keys.MetricKeys
+    expected_metrics = {
+        keys.LOSS_MEAN: expected_loss / np.sum(weights),
+        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
+        # this assert tests that the algorithm remains consistent.
+        keys.AUC: 0.4977,
+        keys.AUC_PR: 0.6645,
+    }
+    self._test_eval(
+        head=head,
+        features={'weights': weights},
+        logits=logits,
+        labels=labels,
+        expected_loss=expected_loss,
+        expected_metrics=expected_metrics)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 2c3e18cb12d..eaed412c8bc 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -264,26 +264,55 @@ def _check_dense_labels_match_logits_and_reshape(
         return array_ops.identity(labels, name=scope)
 
 
-def _check_weights_match_logits_and_reshape(weights, logits):
-  """Checks that weights shape matches logits and reshapes if needed.
+def _get_weights_and_check_match_logits(
+    features, weight_column, logits, allow_per_logit_weights=False):
+  """Fetches weights from features and checks that the shape matches logits.
 
   Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape
   can be either:
-  * [D0, D1, ... DN, logits_dimension]
+  * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.
   * [D0, D1, ... DN, 1]
   * [D0, D1, ... DN]: In this case, weights is reshaped into
     [D0, D1, ... DN, 1] to work with weight broadcasting rules.
 
   Args:
-    weights: weights Tensor.
+    features: The features dict that contains weights.
+    weight_column: The weight column. If not given, this method returns 1.
     logits: logits Tensor.
+    allow_per_logit_weights: Boolean. Whether we allow weights along the logits
+      dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.
   Returns:
     Validated and reshaped weights Tensor.
+  Raises:
+    ValueError: If the weights `Tensor` cannot be cast into float.
   """
-  err_msg = (
-      'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
-      '[D0, D1, ... DN, logits_dimension]')
-  with ops.name_scope(None, 'weights', (weights, logits)) as scope:
+  if allow_per_logit_weights:
+    err_msg = (
+        'weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '
+        '[D0, D1, ... DN, logits_dimension]')
+  else:
+    err_msg = (
+        'weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')
+  with ops.name_scope(
+      None, 'weights',
+      values=tuple(six.itervalues(features)) + (logits,)) as scope:
+    # Fetch the weights.
+    if weight_column is None:
+      return 1.
+    if isinstance(weight_column, six.string_types):
+      weight_column = feature_column_lib.numeric_column(
+          key=weight_column, shape=(1,))
+    if not isinstance(weight_column, feature_column_lib._NumericColumn):  # pylint: disable=protected-access
+      raise TypeError('Weight column must be either a string or _NumericColumn.'
+                      ' Given type: {}.'.format(type(weight_column)))
+    weights = weight_column._get_dense_tensor(  # pylint: disable=protected-access
+        feature_column_lib._LazyBuilder(features))  # pylint: disable=protected-access
+    if not (weights.dtype.is_floating or weights.dtype.is_integer):
+      raise ValueError('Weight column should be castable to float. '
+                       'Given dtype: {}'.format(weights.dtype))
+    weights = math_ops.to_float(weights, name='weights')
+
+    # Validate the weights shape.
     weights_shape = array_ops.shape(weights, name='weights_shape')
     logits_shape = array_ops.shape(logits, name='logits_shape')
     if (weights.shape.ndims is not None and logits.shape.ndims is not None and
@@ -295,42 +324,24 @@ def _check_weights_match_logits_and_reshape(weights, logits):
       with ops.control_dependencies([assert_dimension]):
         return array_ops.expand_dims(weights, -1, name=scope)
     supported_weights_shape = array_ops.concat([logits_shape[:-1], [1]], axis=0)
-    condition = math_ops.reduce_any(
-        [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
-         math_ops.reduce_all(math_ops.equal(
-             supported_weights_shape, weights_shape))])
-    assert_dimension = control_flow_ops.Assert(
-        condition=condition,
-        data=[err_msg, 'logits_shape: ', logits_shape,
-              'weights_shape: ', weights_shape])
+    if allow_per_logit_weights:
+      condition = math_ops.reduce_any(
+          [math_ops.reduce_all(math_ops.equal(logits_shape, weights_shape)),
+           math_ops.reduce_all(math_ops.equal(
+               supported_weights_shape, weights_shape))])
+      assert_dimension = control_flow_ops.Assert(
+          condition=condition,
+          data=[err_msg, 'logits_shape: ', logits_shape,
+                'weights_shape: ', weights_shape])
+    else:
+      assert_dimension = check_ops.assert_equal(
+          supported_weights_shape, weights_shape, message=err_msg,
+          data=['logits_shape: ', logits_shape,
+                'weights_shape: ', weights_shape])
     with ops.control_dependencies([assert_dimension]):
       return array_ops.identity(weights, name=scope)
 
 
-# TODO(roumposg): Delete once all heads support multi-dim input.
-def _check_logits(logits, expected_logits_dimension):
-  """Check logits type and shape."""
-  with ops.name_scope(None, 'logits', (logits,)) as scope:
-    logits = math_ops.to_float(logits)
-    logits_shape = array_ops.shape(logits)
-    assert_rank = check_ops.assert_rank(
-        logits, 2, data=[logits_shape],
-        message='logits shape must be [batch_size, logits_dimension]')
-    with ops.control_dependencies([assert_rank]):
-      static_shape = logits.shape
-      if static_shape is not None:
-        dim1 = static_shape[1]
-        if (dim1 is not None) and (dim1 != expected_logits_dimension):
-          raise ValueError(
-              'logits shape must be [batch_size, logits_dimension], got %s.' %
-              (static_shape,))
-      assert_dimension = check_ops.assert_equal(
-          expected_logits_dimension, logits_shape[1], data=[logits_shape],
-          message='logits shape must be [batch_size, logits_dimension]')
-      with ops.control_dependencies([assert_dimension]):
-        return array_ops.identity(logits, name=scope)
-
-
 def _check_logits_final_dim(logits, expected_logits_dimension):
   """Checks that logits shape is [D0, D1, ... DN, logits_dimension]."""
   with ops.name_scope(None, 'logits', (logits,)) as scope:
@@ -575,10 +586,8 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
         labels=label_ids, logits=logits, reduction=losses.Reduction.NONE)
     # Restore the squeezed dim, so unweighted_loss matches the weights shape.
     unweighted_loss = array_ops.expand_dims(unweighted_loss, axis=-1)
-    weights = _weights(features, self._weight_column)
-    if self._weight_column is not None:
-      weights = _check_weights_match_logits_and_reshape(
-          weights=weights, logits=logits)
+    weights = _get_weights_and_check_match_logits(
+        features=features, weight_column=self._weight_column, logits=logits)
     weighted_sum_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     # _weights() can return 1.
@@ -680,7 +689,7 @@ class _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):
 
 def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
     weight_column=None, thresholds=None, label_vocabulary=None, name=None):
-  """Creates a `Head` for single label binary classification.
+  """Creates a `_Head` for single label binary classification.
 
   This head uses `sigmoid_cross_entropy_with_logits` loss.
 
@@ -718,7 +727,7 @@ def _binary_logistic_head_with_sigmoid_cross_entropy_loss(
       suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
 
   Returns:
-    An instance of `Head` for binary classification.
+    An instance of `_Head` for binary classification.
 
   Raises:
     ValueError: if `thresholds` contains a value outside of `(0, 1)`.
@@ -852,10 +861,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
     labels = _assert_range(labels, 2)
     unweighted_loss = nn.sigmoid_cross_entropy_with_logits(
         labels=labels, logits=logits)
-    weights = _weights(features, self._weight_column)
-    if self._weight_column is not None:
-      weights = _check_weights_match_logits_and_reshape(
-          weights=weights, logits=logits)
+    weights = _get_weights_and_check_match_logits(
+        features=features, weight_column=self._weight_column, logits=logits)
     weighted_sum_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     # _weights() can return 1.
@@ -918,12 +925,8 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
 
       # Eval.
       if mode == model_fn.ModeKeys.EVAL:
-        weights = _weights(features, self._weight_column)
-        # TODO(roumposg): Merge this logic inside _weights once all heads
-        # support multi-dimensional inputs.
-        if self._weight_column is not None:
-          weights = _check_weights_match_logits_and_reshape(
-              weights=weights, logits=logits)
+        weights = _get_weights_and_check_match_logits(
+            features=features, weight_column=self._weight_column, logits=logits)
         return model_fn.EstimatorSpec(
             mode=model_fn.ModeKeys.EVAL,
             predictions=predictions,
@@ -957,7 +960,7 @@ class _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):
 def _regression_head_with_mean_squared_error_loss(weight_column=None,
                                                   label_dimension=1,
                                                   name=None):
-  """Creates a `_Head` for regression using the mean squared loss.
+  """Creates a `_Head` for regression using the `mean_squared_error` loss.
 
   The loss is the weighted sum over all input dimensions. Namely, if the input
   labels have shape `[batch_size, label_dimension]`, the loss is the weighted
@@ -1023,10 +1026,9 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
     labels = math_ops.to_float(labels)
     unweighted_loss = losses.mean_squared_error(
         labels=labels, predictions=logits, reduction=losses.Reduction.NONE)
-    weights = _weights(features, self._weight_column)
-    if self._weight_column is not None:
-      weights = _check_weights_match_logits_and_reshape(
-          weights=weights, logits=logits)
+    weights = _get_weights_and_check_match_logits(
+        features=features, weight_column=self._weight_column, logits=logits,
+        allow_per_logit_weights=True)
     weighted_sum_loss = losses.compute_weighted_loss(
         unweighted_loss, weights=weights, reduction=losses.Reduction.SUM)
     # _weights() can return 1.
@@ -1111,18 +1113,19 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head):
         train_op=train_op_fn(weighted_sum_loss))
 
 
-def _assert_range(labels, n_classes):
+def _assert_range(labels, n_classes, message=None):
   with ops.name_scope(None, 'assert_range', (labels,)):
     assert_less = check_ops.assert_less(
         labels,
         ops.convert_to_tensor(n_classes, dtype=labels.dtype),
-        message='Label IDs must < n_classes')
+        message=message or 'Label IDs must < n_classes')
     assert_greater = check_ops.assert_non_negative(
-        labels, message='Label IDs must >= 0')
+        labels, message=message or 'Label IDs must >= 0')
     with ops.control_dependencies((assert_less, assert_greater)):
       return array_ops.identity(labels)
 
 
+# TODO(b/69000400): Delete this method.
 def _weights(features, weight_column):
   """Fetches weights from features."""
   with ops.name_scope(None, 'weights', values=features.values()):
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 0a4ea7d81c9..4497cd26f2d 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -987,12 +987,14 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
         spec.loss.eval()
 
   def test_multi_dim_train_weights_wrong_outer_dim(self):
-    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 2]."""
+    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3]."""
     head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
         n_classes=3, weight_column='weights')
     logits = np.array([[[10, 0, 0], [12, 0, 0]],
                        [[0, 10, 0], [0, 15, 0]]], dtype=np.float32)
     labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
+    weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]],
+                        [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]])
     weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
     def _no_op_train_fn(loss):
       del loss
@@ -1008,10 +1010,8 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
       _initialize_variables(self, monitored_session.Scaffold())
       with self.assertRaisesRegexp(
           errors.InvalidArgumentError,
-          r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 2\]'):
-        spec.loss.eval({
-            weights_placeholder: np.array([[[1., 1.1], [1.5, 1.6]],
-                                           [[2., 2.1], [2.5, 2.6]]])})
+          r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 3\]'):
+        spec.loss.eval({weights_placeholder: weights})
 
   def test_multi_dim_weighted_eval(self):
     """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2]."""

From c4a2562dfcd8dd61f4d2c4ce88f3b72eeb888a5a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 10:35:13 -0800
Subject: [PATCH 075/115] Allow a key type without a constructor that takes an
 int in Squawd.

PiperOrigin-RevId: 175178089
---
 .../lib/quantiles/weighted_quantiles_buffer.h          | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
index 5e316538cef..70037d5bd8f 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
@@ -33,9 +33,9 @@ template <typename ValueType, typename WeightType,
 class WeightedQuantilesBuffer {
  public:
   struct BufferEntry {
-    BufferEntry(const ValueType& v, const WeightType& w)
-        : value(v), weight(w) {}
-    BufferEntry() : value(0), weight(0) {}
+    BufferEntry(ValueType v, WeightType w)
+        : value(std::move(v)), weight(std::move(w)) {}
+    BufferEntry() : value(), weight(0) {}
 
     bool operator<(const BufferEntry& other) const {
       return kCompFn(value, other.value);
@@ -67,7 +67,7 @@ class WeightedQuantilesBuffer {
 
   // Push entry to buffer and maintain a compact representation within
   // pre-defined size limit.
-  void PushEntry(const ValueType& value, const WeightType& weight) {
+  void PushEntry(ValueType value, WeightType weight) {
     // Callers are expected to act on a full compacted buffer after the
     // PushEntry call returns.
     QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
@@ -78,7 +78,7 @@ class WeightedQuantilesBuffer {
     }
 
     // Push back the entry to the buffer.
-    vec_.push_back(BufferEntry(value, weight));
+    vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
   }
 
   // Returns a sorted vector view of the base buffer and clears the buffer.

From 2bc1f2339ef7a82f868b1bbcbc80f5632800ef1c Mon Sep 17 00:00:00 2001
From: Yifei Feng <yifeif@google.com>
Date: Thu, 9 Nov 2017 12:01:53 -0800
Subject: [PATCH 076/115] Fix typo in tensorflow/python/layers/base_test.py

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/14412 from yifeif:yifeif-patch-3 4b91380c6fc1f995d48a5f184e7307f776541bd0
PiperOrigin-RevId: 175192097
---
 tensorflow/python/estimator/BUILD     | 2 --
 tensorflow/python/layers/base_test.py | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index dba77617008..03f386e9cf8 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -245,8 +245,6 @@ py_test(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
         "//tensorflow/python/feature_column",
-        "//third_party/py/numpy",
-        "//third_party/py/pandas",
         "@six_archive//:six",
     ],
 )
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 7ddfe37827d..509ad5a7afb 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -47,7 +47,7 @@ class BaseLayerTest(test.TestCase):
     self.assertEqual(layer.trainable_variables, [])
     self.assertEqual(layer.non_trainable_variables, [])
     if context.in_graph_mode():
-      # updates, losses only suppported in GRAPH mode
+      # updates, losses only supported in GRAPH mode
       self.assertEqual(layer.updates, [])
       self.assertEqual(layer.losses, [])
     self.assertEqual(layer.built, False)

From 534c6176f6b8704f0944ad17cc3fba5ff26784ed Mon Sep 17 00:00:00 2001
From: Igor Saprykin <isaprykin@google.com>
Date: Thu, 9 Nov 2017 12:22:01 -0800
Subject: [PATCH 077/115] `replicate_model_fn` supports aggregating gradients
 in IndexedSlices.

`gradients.gradients` may return computed gradients in IndexedSlices as opposed to a Tensor: https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/python/ops/gradients_impl.py#L881.

`replicate_model_fn` currently uses math_ops.add_n to aggregate gradients from all towers, but add_n doesn't work with IndexedSlices, so these need to be handled separately.
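
A minimal standalone sketch of the concat-based aggregation this change adds to
`_compute_sum_on_device` (the helper name below is hypothetical):

    import tensorflow as tf

    def sum_indexed_slices(slices):
      # IndexedSlices may repeat indices; duplicates are summed when the
      # slices are materialized into a dense tensor, so concatenating the
      # values and indices of all towers is equivalent to adding them.
      return tf.IndexedSlices(
          tf.concat([s.values for s in slices], axis=0),
          tf.concat([s.indices for s in slices], axis=0),
          slices[0].dense_shape)

    a = tf.IndexedSlices(tf.constant([1.0, 2.0]), [0, 1],
                         dense_shape=tf.constant([2]))
    b = tf.IndexedSlices(tf.constant([3.0, 4.0]), [0, 1])
    total = sum_indexed_slices([a, b])
    # tf.convert_to_tensor(total) evaluates to [4.0, 6.0].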

PiperOrigin-RevId: 175194893
---
 .../python/estimator/replicate_model_fn.py    | 25 +++++-
 .../estimator/replicate_model_fn_test.py      | 87 +++++++++++++++++--
 2 files changed, 104 insertions(+), 8 deletions(-)

diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index 7005a647db5..421bf18c45d 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -34,10 +34,12 @@ from tensorflow.python.estimator import util
 from tensorflow.python.estimator.export import export_output as export_output_lib
 from tensorflow.python.framework import device as framework_device
 from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients as gradients_lib
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
@@ -183,10 +185,17 @@ def _split_batch(features, labels, number_of_shards, device):
   """Split input features and labes into batches."""
 
   def split_dictionary(dictionary):
+    """Split a dictionary into shards."""
     shards = [{} for _ in range(number_of_shards)]
     for name, tensor in six.iteritems(dictionary):
-      for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
-        shards[i][name] = shard
+      if isinstance(tensor, sparse_tensor.SparseTensor):
+        for i, shard in enumerate(
+            sparse_ops.sparse_split(
+                sp_input=tensor, num_split=number_of_shards, axis=0)):
+          shards[i][name] = shard
+      else:
+        for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
+          shards[i][name] = shard
     return shards
 
   with ops_lib.name_scope('split_inputs'):
@@ -313,7 +322,17 @@ def _call_optimizer_fn(optimizer_fn, params):
 
 def _compute_sum_on_device(values, device, name=None):
   with ops_lib.device(device):
-    return math_ops.add_n(values, name=name)
+    if isinstance(values[0], ops_lib.IndexedSlices):
+      if name:
+        raise ValueError('The name {} is not expected to be given to '
+                         'IndexedSlices {}'.format(name, values))
+
+      values_concat = array_ops.concat([v.values for v in values], axis=0)
+      indices_concat = array_ops.concat([v.indices for v in values], axis=0)
+      return ops_lib.IndexedSlices(values_concat, indices_concat,
+                                   values[0].dense_shape)
+    else:
+      return math_ops.add_n(values, name=name)
 
 
 def _train_spec(tower_specs,
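
The `split_dictionary` change above shards `SparseTensor` features with
`sparse_ops.sparse_split` instead of `array_ops.split`. A toy sketch of that
call (values illustrative):

    import tensorflow as tf

    # A 4-row sparse feature split into 2 per-tower shards along the
    # batch axis.
    sp = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 0], [3, 1]],
                         values=[1., 2., 3., 4.], dense_shape=[4, 2])
    shard0, shard1 = tf.sparse_split(sp_input=sp, num_split=2, axis=0)
    # shard0 covers batch rows 0-1; shard1 covers rows 2-3.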
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index ce286c33b01..c90169af8ce 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -65,20 +65,35 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
     data = np.linspace(
         0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)
     x_data = data.reshape(batch_size, input_dimension)
+    categorical_data = np.random.random_integers(
+        0, len(x_data), size=len(x_data))
     y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))
     train_input_fn = numpy_io.numpy_input_fn(
-        x={'x': x_data},
+        x={'x': x_data,
+           'categories': categorical_data},
         y=y_data,
         batch_size=batch_size,
         num_epochs=None,
         shuffle=True)
     eval_input_fn = numpy_io.numpy_input_fn(
-        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)
+        x={'x': x_data,
+           'categories': categorical_data},
+        y=y_data,
+        batch_size=batch_size,
+        shuffle=False)
     predict_input_fn = numpy_io.numpy_input_fn(
-        x={'x': x_data}, batch_size=batch_size, shuffle=False)
+        x={'x': x_data,
+           'categories': categorical_data},
+        batch_size=batch_size,
+        shuffle=False)
 
     feature_columns = [
-        feature_column.numeric_column('x', shape=(input_dimension,))
+        feature_column.numeric_column('x', shape=(input_dimension,)),
+        feature_column.indicator_column(
+            feature_column.categorical_column_with_vocabulary_list(
+                'categories',
+                vocabulary_list=np.linspace(
+                    0., len(x_data), len(x_data), dtype=np.int64)))
     ]
 
     estimator = dnn.DNNClassifier(
@@ -858,7 +873,7 @@ class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
 
 class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
 
-  def test_example(self):
+  def test_vectors(self):
     with self.test_session() as session:
       total = replicate_model_fn._compute_sum_on_device(
           [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum')
@@ -867,6 +882,68 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
       self.assertEqual('test_sum', total.op.name)
       self.assertEqual(10.0, session.run(total))
 
+  def test_tensors(self):
+    with self.test_session() as session:
+      total = replicate_model_fn._compute_sum_on_device(
+          [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum')
+
+      self.assertEqual('/device:GPU:0', total.device)
+      self.assertEqual('test_sum', total.op.name)
+      self.assertAllEqual([4.0, 6.0], session.run(total))
+
+  def test_indexedslices(self):
+    with self.test_session() as session:
+      a = ops_lib.IndexedSlices(
+          constant_op.constant([1.0, 2.0]), [0, 1],
+          dense_shape=constant_op.constant([2]))
+      b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+      total = replicate_model_fn._compute_sum_on_device(
+          [a, b], device='/device:GPU:0')
+
+      self.assertEqual('/device:GPU:0', total.device)
+      self.assertAllEqual([4.0, 6.0],
+                          session.run(ops_lib.convert_to_tensor(total)))
+
+  def test_indexedslices_higher_dimensions(self):
+    with self.test_session() as session:
+      a = ops_lib.IndexedSlices(
+          constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1],
+          dense_shape=constant_op.constant([2, 4]))
+      b = ops_lib.IndexedSlices(
+          constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1])
+
+      total = replicate_model_fn._compute_sum_on_device(
+          [a, b], device='/device:GPU:0')
+
+      self.assertEqual('/device:GPU:0', total.device)
+      self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]],
+                          session.run(ops_lib.convert_to_tensor(total)))
+
+  def test_indexedslices_some_dont_overlap(self):
+    with self.test_session() as session:
+      a = ops_lib.IndexedSlices(
+          constant_op.constant([1.0, 2.0]), [0, 3],
+          dense_shape=constant_op.constant([4]))
+      b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+      total = replicate_model_fn._compute_sum_on_device(
+          [a, b], device='/device:GPU:0')
+
+      self.assertEqual('/device:GPU:0', total.device)
+      self.assertAllEqual([4.0, 4.0, 0.0, 2.0],
+                          session.run(ops_lib.convert_to_tensor(total)))
+
+  def test_no_name_for_indexslices(self):
+    a = ops_lib.IndexedSlices(
+        constant_op.constant([1.0, 2.0]), [0, 1],
+        dense_shape=constant_op.constant([2]))
+    b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1])
+
+    with self.assertRaisesRegexp(ValueError, ''):
+      _ = replicate_model_fn._compute_sum_on_device(
+          [a, b], device='/device:GPU:0', name='cant_name_indexslices')
+
 
 class ConcatTensorDictsTest(test_util.TensorFlowTestCase):
 

From c51b3c301d60697bb498d19ea5068ddfb2525f95 Mon Sep 17 00:00:00 2001
From: Yifei Feng <yifeif@google.com>
Date: Thu, 9 Nov 2017 12:24:28 -0800
Subject: [PATCH 078/115] Fix cmake build.

PiperOrigin-RevId: 175195239
---
 tensorflow/contrib/cmake/tf_c.cmake | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake
index f3882e8cf76..3ae28b7601a 100644
--- a/tensorflow/contrib/cmake/tf_c.cmake
+++ b/tensorflow/contrib/cmake/tf_c.cmake
@@ -21,7 +21,6 @@ set(tf_c_srcs
     "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc"
     "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc"
     "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h"
-    "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc"
     "${tensorflow_source_dir}/tensorflow/c/eager/tape.h"
     "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc"
     "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h"

From e384cf3822a95fad8a83d8b5e364321244a2c6dd Mon Sep 17 00:00:00 2001
From: Benoit Steiner <bsteiner@google.com>
Date: Thu, 9 Nov 2017 12:48:14 -0800
Subject: [PATCH 079/115] Use error status instead of assertions to ensure
 shape consistency

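For illustration only (the change itself is C++; the Python below is a
sketch with illustrative names): RefineDim now returns an error status on
inconsistent dimensions instead of CHECK-failing the process.

    def refine_dim(dim, result):
        # Mirrors the new RefineDim: report a mismatch instead of crashing.
        if result >= 0:
            if not (result == dim or dim < 0):
                raise ValueError('Inconsistent dimensions detected')
        elif dim >= 0:
            result = dim
        elif dim < result:
            result = dim
        return result
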
PiperOrigin-RevId: 175198248
---
 .../core/grappler/costs/graph_properties.cc   | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 151455778a4..35048a4fcff 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -108,32 +108,35 @@ struct Processor<DimensionHandle> {
 
     if (dim1 >= 0 && dim2 >= 0) {
       CHECK_EQ(dim1, dim2);
-      RefineDim(dim1, result);
+      return RefineDim(dim1, result);
     } else if (dim1 >= 0 && dim2 < 0) {
-      RefineDim(dim1, result);
+      return RefineDim(dim1, result);
     } else if (dim1 < 0 && dim2 >= 0) {
-      RefineDim(dim2, result);
+      return RefineDim(dim2, result);
     } else if (dim1 < -1) {
-      RefineDim(dim1, result);
+      return RefineDim(dim1, result);
     } else if (dim2 < -1) {
-      RefineDim(dim2, result);
+      return RefineDim(dim2, result);
     } else {
       CHECK_EQ(dim1, dim2);
       CHECK_EQ(-1, dim1);
-      RefineDim(-1, result);
+      return RefineDim(-1, result);
     }
     return Status::OK();
   }
 
  private:
-  void RefineDim(int64 dim, int64* result) {
+  Status RefineDim(int64 dim, int64* result) {
     if (*result >= 0) {
-      CHECK(*result == dim || dim < 0);
+      if (!(*result == dim || dim < 0)) {
+        return errors::InvalidArgument("Inconsistent dimensions detected");
+      }
     } else if (dim >= 0) {
       *result = dim;
     } else if (dim < *result) {
       *result = dim;
     }
+    return Status::OK();
   }
 
   int64 counter = 2;

From c693b3130fabde91b09c160f36f3ac1eed6311f6 Mon Sep 17 00:00:00 2001
From: Yifei Feng <yifeif@google.com>
Date: Thu, 9 Nov 2017 12:54:10 -0800
Subject: [PATCH 080/115] Disable
 tensorflow/contrib/data/python/kernel_tests:prefetching_ops_test. Flaky in
 open source build.

PiperOrigin-RevId: 175199083
---
 tensorflow/contrib/data/python/kernel_tests/BUILD | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index c1f1d90c5da..d811683ecda 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -448,6 +448,7 @@ py_test(
     size = "small",
     srcs = ["prefetching_ops_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_oss"],  # b/68785503
     deps = [
         "//tensorflow/contrib/data/python/ops:dataset_ops",
         "//tensorflow/contrib/data/python/ops:prefetching_py",

From 10a2b450d26eca33b880fdc4887946d60064ef50 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 13:02:38 -0800
Subject: [PATCH 081/115] Add per-host input for multi-host setup.

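For illustration (plain Python, values made up): with per-host deployment
the input_fn runs once per host rather than once per core, so each shard
now sees global_batch_size // num_hosts instead of // num_cores.

    def shard_batch_size(global_batch_size, num_cores, num_hosts,
                         sharded_per_core):
        # Mirrors the batch-size logic changed in tpu_estimator.py below.
        if sharded_per_core:
            return global_batch_size // num_cores
        return global_batch_size // num_hosts

    # e.g. 2 hosts with 8 cores each and a global batch of 1024:
    assert shard_batch_size(1024, 16, 2, sharded_per_core=True) == 64
    assert shard_batch_size(1024, 16, 2, sharded_per_core=False) == 512
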
PiperOrigin-RevId: 175200199
---
 .../contrib/tpu/python/tpu/tpu_config.py      |   5 +-
 .../contrib/tpu/python/tpu/tpu_estimator.py   | 111 ++++++++++--------
 2 files changed, 61 insertions(+), 55 deletions(-)

diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
index 097acd5ee73..916b9b3082f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py
@@ -45,10 +45,7 @@ class TPUConfig(
       is invoked once on each host. To be precise, with a global batch size
       `train_batch_size` in `TPUEstimator` constructor, the batch size for each
       shard is `train_batch_size` // #hosts. With Per-Core input pipeline
-      deployment, the shard batch size is `train_batch_size` // #cores.  Note
-      that this only works for single-host TPU training now (tracked in
-      b/67051042). For multi-host, please use Per-Core, i.e., `False` for
-      `per_host_input_for_training`.
+      deployment, the shard batch size is `train_batch_size` // #cores.
     tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred
       within TPUEstimator, however when using ClusterSpec propagation in more
       esoteric cluster configurations, you may need to specify the job name as a
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 16d712af9e2..07877fcc761 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -232,8 +232,10 @@ class _TPUContext(object):
                          mode == model_fn_lib.ModeKeys.TRAIN
                          else self._eval_batch_size)
     # On TPU
-    return (global_batch_size // self.num_cores
-            if self.is_input_sharded_per_core() else global_batch_size)
+    if self.is_input_sharded_per_core():
+      return global_batch_size // self.num_cores
+    else:
+      return global_batch_size // self.num_hosts
 
   @property
   def batch_size_for_model_fn(self):
@@ -682,6 +684,40 @@ def generate_per_core_enqueue_ops_fn_for_host(
   return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
 
 
+def generate_per_host_enqueue_ops_fn_for_host(
+    ctx, input_fn, inputs_structure_recorder, batch_axis, device):
+  """Generates infeed enqueue ops for per-host input_fn on a single host."""
+  infeed_queue_holder = {'instance': None}
+
+  def enqueue_ops_fn():
+    with ops.device(device):
+      num_cores_per_host = ctx.num_of_cores_per_host
+      inputs = input_fn()
+      if isinstance(inputs, tuple):
+        features, labels = inputs
+      else:
+        features, labels = inputs, None
+      inputs_structure_recorder.validate_and_record_structure(
+          features, labels)
+      unsharded_tensor_list = (
+          inputs_structure_recorder.flatten_features_and_labels(
+              features, labels))
+
+      infeed_queue = tpu_feed.InfeedQueue(
+          tuple_types=[t.dtype for t in unsharded_tensor_list],
+          tuple_shapes=[t.shape for t in unsharded_tensor_list],
+          shard_dimensions=batch_axis)
+      infeed_queue_holder['instance'] = infeed_queue
+      infeed_queue.set_number_of_shards(num_cores_per_host)
+
+      per_host_enqueue_ops = (
+          infeed_queue.split_inputs_and_generate_enqueue_ops(
+              unsharded_tensor_list,
+              placement_function=lambda x: device))
+      return per_host_enqueue_ops
+  return enqueue_ops_fn, (lambda: infeed_queue_holder['instance'])
+
+
 class _InputPipeline(object):
   """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.
 
@@ -856,15 +892,15 @@ class _InputPipeline(object):
     return (enqueue_ops, dequeue_fn)
 
   def _invoke_input_fn_and_record_structure(self):
+    """Deploys the input pipeline and record input structure."""
+    enqueue_ops = []
+    infeed_queues = []
+    num_hosts = self._ctx.num_hosts
+    tpu_host_placement_fn = self._ctx.tpu_host_placement_function
     if self._sharded_per_core:
       # Per-Core input pipeline deployment.
-      tpu_host_placement_fn = self._ctx.tpu_host_placement_function
-      enqueue_ops = []
-      infeed_queues = []
-
       # Invoke input pipeline for each core and placed on the corresponding
       # host.
-      num_hosts = self._ctx.num_hosts
       for host_id in range(num_hosts):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
@@ -881,48 +917,27 @@ class _InputPipeline(object):
             # Infeed_queue_getter must be called after enqueue_ops_fn is called.
             infeed_queues.append(infeed_queue_getter())
 
-      # infeed_queue is used to generate dequeue ops. The only thing it uses for
-      # dequeue is dtypes and types. So, any one can be used. Here, grab the
-      # first one.
-      self._infeed_queue = infeed_queues[0]
-      return enqueue_ops
-
     else:
-      # TODO(b/67051042): Extend this to multi-host support.
-      host_id = 0
-      host_device = self._ctx.tpu_host_placement_function(host_id=host_id)
-      def enqueue_fn():
+      for host_id in range(num_hosts):
+        host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
-            inputs = self._input_fn()
-            if isinstance(inputs, tuple):
-              features, labels = inputs
+            enqueue_ops_fn, infeed_queue_getter = (
+                generate_per_host_enqueue_ops_fn_for_host(
+                    self._ctx, self._input_fn, self._inputs_structure_recorder,
+                    self._batch_axis, host_device))
+
+            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
+              enqueue_ops.append(_wrap_computation_in_while_loop(
+                  device=host_device, op_fn=enqueue_ops_fn))
             else:
-              features, labels = inputs, None
-            self._inputs_structure_recorder.validate_and_record_structure(
-                features, labels)
-            unsharded_tensor_list = (
-                self._inputs_structure_recorder.flatten_features_and_labels(
-                    features, labels))
-
-            self._infeed_queue = tpu_feed.InfeedQueue(
-                tuple_types=[t.dtype for t in unsharded_tensor_list],
-                tuple_shapes=[t.shape for t in unsharded_tensor_list],
-                shard_dimensions=self._batch_axis)
-            self._infeed_queue.set_number_of_shards(self._ctx.num_cores)
-
-            def placement_fn(core_id):
-              return self._ctx.tpu_host_placement_function(core_id=core_id)
-            return (
-                self._infeed_queue.split_inputs_and_generate_enqueue_ops(
-                    unsharded_tensor_list,
-                    placement_function=placement_fn))
-
-      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
-        return _wrap_computation_in_while_loop(device=host_device,
-                                               op_fn=enqueue_fn)
-      else:
-        return enqueue_fn()
+              enqueue_ops.append(enqueue_ops_fn())
+            infeed_queues.append(infeed_queue_getter())
+    # infeed_queue is used to generate dequeue ops. The only thing it uses for
+    # dequeue is dtypes and shapes. So, any one can be used. Here, grab the
+    # first one.
+    self._infeed_queue = infeed_queues[0]
+    return enqueue_ops
 
   def _validate_input_pipeline(self):
     # Perform some sanity checks to log user friendly information. We should
@@ -1425,12 +1440,6 @@ class TPUEstimator(estimator_lib.Estimator):
               'eval batch size {} must be divisible by number of shards {}'
               .format(eval_batch_size, config.tpu_config.num_shards))
 
-      if (config.tpu_config.num_shards > 8 and
-          config.tpu_config.per_host_input_for_training):
-        # TODO(b/67051042): Support per_host input pipelines when num_shards > 8
-        raise NotImplementedError(
-            'Per-host input pipelines only available for num_shards <= 8')
-
     # Verifies the model_fn signature according to Estimator framework.
     estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
     # We cannot store config and params in this constructor as parent

From 73b7d47031dc53ef52ef028dc0a830de8ec18238 Mon Sep 17 00:00:00 2001
From: Igor Saprykin <isaprykin@google.com>
Date: Thu, 9 Nov 2017 13:10:02 -0800
Subject: [PATCH 082/115] Disable flaky tests in replicate_model_fn_test.py.

I suspect that reducing the local variables for eval metrics across more than one tower is the source of the flakiness, but I haven't figured out why yet.

PiperOrigin-RevId: 175201241
---
 .../estimator/replicate_model_fn_test.py      | 108 +++++++++---------
 1 file changed, 55 insertions(+), 53 deletions(-)

diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index c90169af8ce..bb06700160d 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -223,33 +223,34 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
                                            features, labels, self.params)
       del estimator_spec
 
-  def test_eval(self):
-    features = np.array([[0.01], [0.002]])
-    labels = np.array([[0.01], [0.02]])
-
-    with self.test_session() as session:
-      replicated_model_fn = replicate_model_fn.replicate_model_fn(
-          self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
-      estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
-                                           labels, self.params)
-      session.run(variables.local_variables_initializer())
-      session.run(variables.global_variables_initializer())
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      session.run([a, b])
-      accuracy = session.run(accuracy)
-      auc = session.run(auc)
-
-      # Accuracy is 0.0 (no match) in the first tower.
-      # Accuracy is 1.0 (match) in the second tower, since the feature
-      # times weight "c" happened to be equal to the label.
-      total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
-
-      self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+# TODO(isaprykin): Resolve the source of flakiness.
+#   def test_eval(self):
+#     features = np.array([[0.01], [0.002]])
+#     labels = np.array([[0.01], [0.02]])
+#
+#     with self.test_session() as session:
+#       replicated_model_fn = replicate_model_fn.replicate_model_fn(
+#           self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+#       estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL,
+#                                            features, labels, self.params)
+#       session.run(variables.local_variables_initializer())
+#       session.run(variables.global_variables_initializer())
+#
+#       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+#       auc, b = estimator_spec.eval_metric_ops['auc']
+#
+#       session.run([a, b])
+#       accuracy = session.run(accuracy)
+#       auc = session.run(auc)
+#
+#       # Accuracy is 0.0 (no match) in the first tower.
+#       # Accuracy is 1.0 (match) in the second tower, since the feature
+#       # times weight "c" happened to be equal to the label.
+#       total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
+#
+#       self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+#       self.assertEqual(0, auc)
+#       self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
 
   def test_predict(self):
     features = np.array([[0.01], [0.002]])
@@ -523,32 +524,33 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
     }
     return metrics
 
-  def test_example(self):
-    with self.test_session() as session:
-      tower_losses = map(self.create_constant_loss, [2, 4, 6])
-      tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
-      tower_specs = [
-          self.create_estimator_spec(l, m)
-          for l, m in zip(tower_losses, tower_metrics)
-      ]
-      session.run(variables.local_variables_initializer())
-
-      estimator_spec = replicate_model_fn._eval_spec(
-          tower_specs, aggregation_device='/device:GPU:0')
-
-      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-      auc, b = estimator_spec.eval_metric_ops['auc']
-
-      self.assertEqual('/device:CPU:0', accuracy.device)
-      self.assertEqual('/device:CPU:0', auc.device)
-
-      session.run([a, b])
-      accuracy = session.run(accuracy)
-      auc = session.run(auc)
-
-      self.assertNear((12 - 2) / 12, accuracy, 0.01)
-      self.assertEqual(0, auc)
-      self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
+# TODO(isaprykin): Resolve the source of flakiness.
+#   def test_example(self):
+#     with self.test_session() as session:
+#       tower_losses = map(self.create_constant_loss, [2, 4, 6])
+#       tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
+#       tower_specs = [
+#           self.create_estimator_spec(l, m)
+#           for l, m in zip(tower_losses, tower_metrics)
+#       ]
+#       session.run(variables.local_variables_initializer())
+#
+#       estimator_spec = replicate_model_fn._eval_spec(
+#           tower_specs, aggregation_device='/device:GPU:0')
+#
+#       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+#       auc, b = estimator_spec.eval_metric_ops['auc']
+#
+#       self.assertEqual('/device:CPU:0', accuracy.device)
+#       self.assertEqual('/device:CPU:0', auc.device)
+#
+#       session.run([a, b])
+#       accuracy = session.run(accuracy)
+#       auc = session.run(auc)
+#
+#       self.assertNear((12 - 2) / 12, accuracy, 0.01)
+#       self.assertEqual(0, auc)
+#       self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
 
   def test_handles_single_tower(self):
     with self.test_session() as session:

From 11b9c430fac6a68972012d8b34b3f216a7b9e650 Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Thu, 9 Nov 2017 13:26:45 -0800
Subject: [PATCH 083/115] EagerVariableStore.trainable_variables()

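Rough usage, following the new unit test (eager mode assumed):

    store = variable_scope.EagerVariableStore()
    with store.as_default():
      v = variable_scope.get_variable("v", shape=(), trainable=True)
      w = variable_scope.get_variable("w", shape=(), trainable=False)

    assert v in store.variables() and w in store.variables()
    assert v in store.trainable_variables()   # filtered on the trainable flag
    assert w not in store.trainable_variables()
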
PiperOrigin-RevId: 175203593
---
 .../python/kernel_tests/variable_scope_test.py       | 12 ++++++++++++
 tensorflow/python/ops/variable_scope.py              |  5 +++++
 2 files changed, 17 insertions(+)

diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index bd4b12b7e8a..53962149561 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -117,6 +117,18 @@ class VariableScopeTest(test.TestCase):
         w = variable_scope.get_variable("w", [])
         self.assertEqual(w.dtype.base_dtype, dtypes.float16)
 
+  def testEagerVariableStore(self):
+    with context.eager_mode():
+      store = variable_scope.EagerVariableStore()
+      with store.as_default():
+        v = variable_scope.get_variable("v", shape=(), trainable=True)
+        w = variable_scope.get_variable("w", shape=(), trainable=False)
+
+      self.assertTrue(v in store.variables())
+      self.assertTrue(w in store.variables())
+      self.assertTrue(v in store.trainable_variables())
+      self.assertFalse(w in store.trainable_variables())
+
   @test_util.run_in_graph_and_eager_modes()
   def testInitFromNonTensorValue(self):
     v = variable_scope.get_variable("v4", initializer=4, dtype=dtypes.int32)
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 92fa928eede..9a0ff755941 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1227,6 +1227,11 @@ class EagerVariableStore(object):
   def variables(self):
     return self._store._vars.values()  # pylint: disable=protected-access
 
+  def trainable_variables(self):
+    # pylint: disable=protected-access
+    return [x for x in self._store._vars.values() if x._trainable]
+    # pylint: enable=protected-access
+
 
 def get_variable(name,
                  shape=None,

From e830e6ddcb20ff2f7391b7c896bdb5004d5dda88 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 13:29:56 -0800
Subject: [PATCH 084/115] Modify quantization to support add ops that occur
 after all quantizable types, not just Conv2D.

PiperOrigin-RevId: 175204002
---
 .../contrib/quantize/python/quantize.py       |  4 +--
 .../contrib/quantize/python/quantize_test.py  | 25 +++++++++++++++++++
 2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 6382d3f7b41..7db2d863aa4 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -89,8 +89,8 @@ def Quantize(graph,
           op.name[:-len('/depthwise')])
       if separable_conv and separable_conv.type == 'Conv2D':
         continue
-    if op.type == 'Conv2D':
-      # Quantize add ops that come after Conv2D
+    # Quantize add ops that come after Conv2D or DepthwiseConv2dNative.
+    if op.type in ['Conv2D', 'DepthwiseConv2dNative']:
       add_context_re = re.search(r'^(.*)/[^/]+/', op.name)
       if add_context_re is not None:
         context.add_contexts.add(add_context_re.group(1))
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index eb141a21bd8..1e4dd7cf67d 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import nn_ops
 from tensorflow.python.platform import googletest
 
 conv2d = layers.conv2d
+separable_conv2d = layers.separable_conv2d
 
 
 class QuantizeTest(test_util.TensorFlowTestCase):
@@ -77,6 +78,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                                             quantization_node_name)
     self.assertEqual(add_quant.type, quantization_node_name)
 
+  def testInsertQuantOpForAddAfterSeparableConv2d(self):
+    graph = ops.Graph()
+    with graph.as_default():
+      batch_size, height, width, depth = 5, 128, 128, 3
+      input1 = array_ops.zeros((batch_size, height, width, depth))
+      input2 = array_ops.zeros((batch_size, height / 2, width / 2, depth))
+      conv = separable_conv2d(input1, None, [5, 5], stride=2,
+                              depth_multiplier=1.0, padding='SAME',
+                              weights_initializer=self._WeightInit(0.09),
+                              activation_fn=None, scope='test/test')
+      node = math_ops.add(conv, input2, name='test/add')
+      node = array_ops.identity(node, name='test/identity')
+      update_barrier = control_flow_ops.no_op(name='update_barrier')
+      with ops.control_dependencies([update_barrier]):
+        array_ops.identity(node, name='control_dependency')
+
+    quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True,
+                      activation_bits=8)
+
+    quantization_node_name = 'FakeQuantWithMinMaxVars'
+    add_quant = graph.get_operation_by_name('test/add_quant/' +
+                                            quantization_node_name)
+    self.assertEqual(add_quant.type, quantization_node_name)
+
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.
 

From ea33185cc154bb80741bf4a8a7321aae4b5396cd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 13:30:25 -0800
Subject: [PATCH 085/115] Fix GANEstimator docstring.

PiperOrigin-RevId: 175204075
---
 .../contrib/gan/python/estimator/python/gan_estimator_impl.py   | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index e89993991a3..0824ecf616c 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -76,7 +76,7 @@ class GANEstimator(estimator.Estimator):
         return logits
 
       # Create GAN estimator.
-      gan_estimator = estimator.GANEstimator(
+      gan_estimator = tfgan.estimator.GANEstimator(
           model_dir,
           generator_fn=generator_fn,
           discriminator_fn=discriminator_fn,

From 02dbaaa3a7063ea5ede4ddf47a6ef5df5a64518e Mon Sep 17 00:00:00 2001
From: Yifei Feng <fengyifei2026@gmail.com>
Date: Thu, 9 Nov 2017 14:51:56 -0800
Subject: [PATCH 086/115] Fix typo <Copybara Experiment DO NOT MERGE>

Fix typo in tensorflow/python/framework/function.py
---
 tensorflow/python/framework/function.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index cef3f8d4c42..29cf2237244 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -100,7 +100,7 @@ class Defun(object):
          grad_func - (optional).  A function implementing the gradient
           of the function-to-register.  This must be a
            `_DefinedFunction` object. The gradient
-           function must satisify the criterion defined in
+           function must satisfy the criterion defined in
            function.proto:GradientDef.
 
          python_grad_func - (optional).  A function implementing the

From 954e8d6b134288195f54b8871ee9fcc432bf0aba Mon Sep 17 00:00:00 2001
From: Asim Shankar <ashankar@google.com>
Date: Thu, 9 Nov 2017 13:42:15 -0800
Subject: [PATCH 087/115] eager: README title tweak.

PiperOrigin-RevId: 175205782
---
 tensorflow/contrib/eager/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md
index ae4b07799f5..dcc370cd00d 100644
--- a/tensorflow/contrib/eager/README.md
+++ b/tensorflow/contrib/eager/README.md
@@ -1,4 +1,4 @@
-# TensorFlow Eager Execution
+# Eager Execution
 
 > *WARNING*: This is a preview/pre-alpha version. The API and performance
 > characteristics are subject to change.

From e930f0e072b8d67d9bf29d77babf071a3569615c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 13:49:12 -0800
Subject: [PATCH 088/115] Fix bug reported in b/69059093 by skipping rewrites
 that we can determine have already been applied. Make sure rewrites are
 idempotent by running the optimizer twice in unit tests.

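The guard pattern, sketched in Python for brevity (the change itself is
C++; create_hoisted_mul is a hypothetical stand-in for the rewrite body):
each rewrite derives the name of the node it would create and bails out if
that node already exists, which is what makes a second optimizer pass a
no-op.

    def maybe_hoist_factor(node, node_map, create_hoisted_mul):
        hoisted_name = node.name + '_hoist'
        if node_map.get(hoisted_name) is not None:
            return None  # rewrite already applied; stay idempotent
        return create_hoisted_mul(node, hoisted_name)
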
PiperOrigin-RevId: 175206742
---
 .../optimizers/arithmetic_optimizer.cc        | 12 +--
 .../optimizers/arithmetic_optimizer_test.cc   | 94 ++++++++++++++++++-
 2 files changed, 96 insertions(+), 10 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 14df3caebbc..44d16e5a426 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -833,8 +833,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     }
   }
 
-  if (node->input_size() > 0 && IsAggregate(*node) &&
-      !node_map->GetOutputs(node->name()).empty()) {
+  if (node->input_size() > 0 && IsAggregate(*node)) {
     // Discard aggregate nodes with a single input.
     if (node->input_size() == 1) {
       return node->input(0);
@@ -855,7 +854,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
         break;
       }
     }
-    if (all_equal) {
+    if (all_equal && node_map->GetNode(node->name() + "_const") == nullptr) {
       // 1. Create constant node with value N.
       const int N = node->input_size();
       const auto type = GetDataTypeFromAttr(*node, "T");
@@ -898,7 +897,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
   // where all the inputs are Mul nodes. This pattern occurs frequently in
   // regularization terms for the gradients during training.
   if (node->input_size() > 1 && IsAggregate(*node) &&
-      !node_map->GetOutputs(node->name()).empty()) {
+      node_map->GetNode(node->name() + "_hoist") == nullptr) {
     // Determine the set of common factors if the input nodes are all Mul nodes.
     std::set<string> common_factors;
     int i = 0;
@@ -1011,8 +1010,9 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
   }
 
   // Fold Conj into Transpose or ConjugateTranspose.
-  if (node->op() == "Conj" || node->op() == "Transpose" ||
-      node->op() == "ConjugateTranspose") {
+  if ((node->op() == "Conj" || node->op() == "Transpose" ||
+       node->op() == "ConjugateTranspose") &&
+      node_map->GetNode(node->name() + "_fused") == nullptr) {
     const NodeDef* input = node_map->GetNode(node->input(0));
     const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
     const NodeDef* conj_op = node->op() == "Conj" ? node : input;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 9f471302c7f..60fb47f51aa 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -38,8 +38,8 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
 
   ArithmeticOptimizer optimizer;
   GraphDef output;
-  Status s = optimizer.Optimize(nullptr, item, &output);
-  TF_EXPECT_OK(s);
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(item.graph.node_size(), output.node_size());
   for (int i = 0; i < item.graph.node_size(); ++i) {
@@ -66,6 +66,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupping) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(2, output.node_size());
   const NodeDef& new_c1 = output.node(0);
@@ -91,6 +95,10 @@ TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(4, output.node_size());
   const NodeDef& new_c1 = output.node(0);
@@ -146,13 +154,17 @@ TEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(6, output.node_size());
   EXPECT_EQ("squeeze", output.node(5).input(0));
   EXPECT_EQ("c", output.node(2).input(0));
 }
 
-TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
+TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
   Output add = ops::Add(s.WithOpName("add"), x, x);
@@ -165,6 +177,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(5, output.node_size());
   const NodeDef& new_const = output.node(3);
@@ -178,7 +194,61 @@ TEST_F(ArithmeticOptimizerTest, SimplifyReplaceTrivialSums) {
   EXPECT_EQ("add_mul", new_id.input(0));
 }
 
-TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) {
+TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
+  // Test case from b/69059093.
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
+  Output add = ops::Add(s.WithOpName("Add"), p, p);
+  Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
+  Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
+  Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
+  Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
+  Output id = ops::Identity(s.WithOpName("id"), add6);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  ArithmeticOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  EXPECT_EQ(11, output.node_size());
+  const NodeDef& new_id = output.node(4);
+  EXPECT_EQ("id", new_id.name());
+  EXPECT_EQ("Add_6_mul", new_id.input(0));
+
+  // Add4 and add5 get deduped, and we rewrite each of the 3 remaining add nodes
+  // of the form Add(x,x) into Mul(Const(2), x).
+  const NodeDef& new_add_4_const = output.node(5);
+  EXPECT_EQ("Add_4_const", new_add_4_const.name());
+  EXPECT_EQ("^Add", new_add_4_const.input(0));
+  const NodeDef& new_add_4_mul = output.node(6);
+  EXPECT_EQ("Add_4_mul", new_add_4_mul.name());
+  EXPECT_EQ("Add_4_const", new_add_4_mul.input(0));
+  EXPECT_EQ("Add_mul", new_add_4_mul.input(1));
+
+  const NodeDef& new_add_6_const = output.node(7);
+  EXPECT_EQ("Add_6_const", new_add_6_const.name());
+  EXPECT_EQ("^Add_4_mul", new_add_6_const.input(0));
+  const NodeDef& new_add_6_mul = output.node(8);
+  EXPECT_EQ("Add_6_mul", new_add_6_mul.name());
+  EXPECT_EQ("Add_6_const", new_add_6_mul.input(0));
+  EXPECT_EQ("Add_4_mul", new_add_6_mul.input(1));
+
+  const NodeDef& new_add_const = output.node(9);
+  EXPECT_EQ("Add_const", new_add_const.name());
+  EXPECT_EQ("^Placeholder", new_add_const.input(0));
+  const NodeDef& new_add_mul = output.node(10);
+  EXPECT_EQ("Add_mul", new_add_mul.name());
+  EXPECT_EQ("Add_const", new_add_mul.input(0));
+  EXPECT_EQ("Placeholder", new_add_mul.input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, HoistFactor) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
   Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
@@ -195,6 +265,10 @@ TEST_F(ArithmeticOptimizerTest, SimplifyHoistFactor) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(9, output.node_size());
   const NodeDef& new_add = output.node(8);
@@ -225,6 +299,10 @@ TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(7, output.node_size());
   EXPECT_EQ("trans_fused", output.node(6).name());
@@ -272,6 +350,10 @@ TEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
+  // Run the optimizer twice to make sure the rewrite is idempotent.
+  item.graph.Swap(&output);
+  status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
 
   EXPECT_EQ(7, output.node_size());
   EXPECT_EQ("conj_fused", output.node(6).name());
@@ -304,6 +386,10 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
     GraphDef output;
     Status status = optimizer.Optimize(nullptr, item, &output);
     TF_EXPECT_OK(status);
+    // Run the optimizer twice to make sure the rewrite is idempotent.
+    item.graph.Swap(&output);
+    status = optimizer.Optimize(nullptr, item, &output);
+    TF_EXPECT_OK(status);
 
     EXPECT_EQ(7, output.node_size());
     EXPECT_EQ("matmul_fused", output.node(6).name());

From 9d5a6650ca2ad7068ee556c8dbce03b96ea22128 Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Thu, 9 Nov 2017 13:56:17 -0800
Subject: [PATCH 089/115] Instances per second in the eager microbenchmarks.

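The two reported numbers are reciprocal views of the same measurement
(illustrative values below; the benchmark derives them from time.time()):

    num_iters, start, end = 30000, 0.0, 3.0
    elapsed = end - start
    mean_us = elapsed * 1e6 / num_iters     # wall_time, in microseconds
    examples_per_sec = num_iters / elapsed  # the new "extras" entry
    assert abs(examples_per_sec - 1e6 / mean_us) < 1e-9
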
PiperOrigin-RevId: 175207829
---
 tensorflow/python/eager/benchmarks_test.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 26a70a617d5..b555f16f1d3 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -66,7 +66,8 @@ class MicroBenchmarks(test.Benchmark):
       func()
     end = time.time()
     mean_us = (end - start) * 1e6 / num_iters
-    self.report_benchmark(iters=num_iters, wall_time=mean_us)
+    self.report_benchmark(iters=num_iters, wall_time=mean_us,
+                          extras={"examples_per_sec": num_iters/(end-start)})
 
   def benchmark_create_np_array(self):
     func = lambda: np.array([3.0])

From 898b3486ab16fd2acc3d9f12f57a3be8d83d09ec Mon Sep 17 00:00:00 2001
From: Justin Lebar <jlebar@google.com>
Date: Thu, 9 Nov 2017 14:10:11 -0800
Subject: [PATCH 090/115] Limit internal fragmentation in BFCAllocator to 128mb
 per allocation.

Previously, a very large allocation would be rounded up to the next power
of 2, and then, if that didn't fit in your GPU's available memory, it would
eat all remaining memory on the device.

Now we waste at most 128 MB of memory on a large alloc.

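A worked example of the new split condition (plain Python, sizes
illustrative):

    MiB = 1 << 20
    K_MAX_INTERNAL_FRAGMENTATION = 128 * MiB

    def should_split(chunk_size, rounded_bytes):
        # Old rule: split only when the chunk is at least twice the request.
        # New rule: also split whenever padding would exceed 128 MiB.
        return (chunk_size >= rounded_bytes * 2 or
                chunk_size - rounded_bytes >= K_MAX_INTERNAL_FRAGMENTATION)

    # 900 MiB request from a 1000 MiB chunk: 100 MiB of padding, kept whole.
    assert not should_split(1000 * MiB, 900 * MiB)
    # 900 MiB request from a 1100 MiB chunk: 200 MiB would be wasted, split.
    assert should_split(1100 * MiB, 900 * MiB)
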
PiperOrigin-RevId: 175209995
---
 tensorflow/core/common_runtime/bfc_allocator.cc | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 38fe247521b..6399b8cf55b 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -296,12 +296,13 @@ void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
         // it from the free bin structure prior to using.
         RemoveFreeChunkIterFromBin(&b->free_chunks, citer);
 
-        // If we can break the size of the chunk into two reasonably
-        // large pieces, do so.
-        //
-        // TODO(vrv): What should be the criteria when deciding when
-        // to split?
-        if (chunk->size >= rounded_bytes * 2) {
+        // If we can break the size of the chunk into two reasonably large
+        // pieces, do so.  In any case don't waste more than
+        // kMaxInternalFragmentation bytes on padding this alloc.
+        const int64 kMaxInternalFragmentation = 128 << 20;  // 128mb
+        if (chunk->size >= rounded_bytes * 2 ||
+            static_cast<int64>(chunk->size) - rounded_bytes >=
+                kMaxInternalFragmentation) {
           SplitChunk(h, rounded_bytes);
           chunk = ChunkFromHandle(h);  // Update chunk pointer in case it moved
         }

From e4cbba18dd0c04e5490997bc04c09a5269ce19e8 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 14:14:30 -0800
Subject: [PATCH 091/115] Expose padded_batch_and_drop_remainder

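Hedged usage sketch, assuming the same Dataset.apply() calling convention
as the existing batch_and_drop_remainder transformation (the exact argument
list is an assumption, not confirmed by this patch):

    dataset = tf.contrib.data.Dataset.range(10)
    # Batches of 4; the final partial batch of 2 elements is dropped.
    dataset = dataset.apply(
        tf.contrib.data.padded_batch_and_drop_remainder(4))
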
PiperOrigin-RevId: 175210678
---
 tensorflow/contrib/data/__init__.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 824ac4298f8..6e43ae0e632 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -23,6 +23,7 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview.
 @@TextLineDataset
 
 @@batch_and_drop_remainder
+@@padded_batch_and_drop_remainder
 @@dense_to_sparse_batch
 @@enumerate_dataset
 @@group_by_window
@@ -45,6 +46,7 @@ from __future__ import print_function
 
 from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
 from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
+from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder
 from tensorflow.contrib.data.python.ops.batching import unbatch
 from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
 from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element

From 3de7349955b839edbd61fef7bac3db9e140ffd3d Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Thu, 9 Nov 2017 14:21:25 -0800
Subject: [PATCH 092/115] Add tf.nn.softmax_cross_entropy_with_logits_v2 which
 enables backprop wrt the labels.

Clarify current backprop behavior.

Original bugfix by Alexandre Passos.

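Migration sketch for callers who relied on labels being held constant (uses
only APIs named in this patch; the tensors are illustrative):

    logits = tf.constant([[0.1, 0.2, 0.7]])
    soft_labels = tf.constant([[0.0, 0.5, 0.5]])
    # The _v2 op backprops into labels by default; stop_gradient restores
    # the old behavior, which is exactly what the deprecated wrapper does.
    loss = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=tf.stop_gradient(soft_labels), logits=logits)
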
PiperOrigin-RevId: 175211803
---
 .../python/kernel_tests/xent_op_test.py       |  18 ++
 tensorflow/python/ops/nn.py                   |   1 +
 tensorflow/python/ops/nn_grad.py              |   5 +-
 tensorflow/python/ops/nn_ops.py               | 180 ++++++++++++------
 .../tools/api/golden/tensorflow.nn.pbtxt      |   4 +
 5 files changed, 151 insertions(+), 57 deletions(-)

diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 4b3dadc1128..43be08f8a14 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -181,6 +181,24 @@ class XentTest(test.TestCase):
     print("cross entropy gradient err = ", err)
     self.assertLess(err, 5e-8)
 
+  def testGradientLabelWithV2(self):
+    with self.test_session():
+      l = constant_op.constant(
+          [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.5],
+          shape=[3, 4],
+          dtype=dtypes.float64,
+          name="l")
+      f = constant_op.constant(
+          [0.1, 0.2, 0.3, 0.4, 0.1, 0.4, 0.9, 1.6, 0.1, 0.8, 2.7, 6.4],
+          shape=[3, 4],
+          dtype=dtypes.float64,
+          name="f")
+      x = nn_ops.softmax_cross_entropy_with_logits_v2(labels=l, logits=f,
+                                                      name="xent")
+      err = gradient_checker.compute_gradient_error(l, [3, 4], x, [3])
+
+    self.assertLess(err, 5e-8)
+
   def testSecondGradient(self):
     with self.test_session() as sess:
       l = constant_op.constant([0.0, 0.0, 1.0/3, 0.0,
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 79af3ac1172..ee1a00623a7 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -74,6 +74,7 @@ See the @{$python/nn} guide.
 @@softmax
 @@log_softmax
 @@softmax_cross_entropy_with_logits
+@@softmax_cross_entropy_with_logits_v2
 @@sparse_softmax_cross_entropy_with_logits
 @@weighted_cross_entropy_with_logits
 @@embedding_lookup
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 557f39fb42e..4b406ba8404 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -420,7 +420,6 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
   # grad_loss is the backprop for cost, and we multiply it with the gradients
   # (which is output[1])
   # grad_grad is the backprop for softmax gradient.
-  # There is no gradient for the labels
   #
   # Second derivative is just softmax derivative w.r.t. logits.
   softmax_grad = op.outputs[1]
@@ -436,15 +435,15 @@ def _SoftmaxCrossEntropyWithLogitsGrad(op, grad_loss, grad_grad):
     const_fill_value = tensor_util.constant_value(g)
     return const_fill_value is not None and (const_fill_value == 0).all()
 
+  logits = op.inputs[0]
   if grad_grad is not None and not IsZero(grad_grad):
-    logits = op.inputs[0]
     softmax = nn_ops.softmax(logits)
 
     grad += ((grad_grad - array_ops.squeeze(
         math_ops.matmul(grad_grad[:, None, :],
                         softmax[:, :, None]), axis=1)) * softmax)
 
-  return grad, None
+  return grad, _BroadcastMul(grad_loss, -nn_ops.log_softmax(logits))
 
 
 @ops.RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index a37b68c6fa7..bdaac65904a 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -32,11 +32,13 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_nn_ops import *
 # pylint: enable=wildcard-import
 
+from tensorflow.python.util import deprecation
 
 # Aliases for some automatically-generated names.
 local_response_normalization = gen_nn_ops.lrn
@@ -1700,6 +1702,121 @@ def _ensure_xent_args(name, sentinel, labels, logits):
     raise ValueError("Both labels and logits must be provided.")
 
 
+def softmax_cross_entropy_with_logits_v2(_sentinel=None,  # pylint: disable=invalid-name
+                                         labels=None, logits=None,
+                                         dim=-1, name=None):
+  """Computes softmax cross entropy between `logits` and `labels`.
+
+  Measures the probability error in discrete classification tasks in which the
+  classes are mutually exclusive (each entry is in exactly one class).  For
+  example, each CIFAR-10 image is labeled with one and only one label: an image
+  can be a dog or a truck, but not both.
+
+  **NOTE:**  While the classes are mutually exclusive, their probabilities
+  need not be.  All that is required is that each row of `labels` is
+  a valid probability distribution.  If they are not, the computation of the
+  gradient will be incorrect.
+
+  If using exclusive `labels` (wherein one and only
+  one class is true at a time), see `sparse_softmax_cross_entropy_with_logits`.
+
+  **WARNING:** This op expects unscaled logits, since it performs a `softmax`
+  on `logits` internally for efficiency.  Do not call this op with the
+  output of `softmax`, as it will produce incorrect results.
+
+  `logits` and `labels` must have the same shape, e.g.
+  `[batch_size, num_classes]` and the same dtype (either `float16`, `float32`,
+  or `float64`).
+
+  Backpropagation will happen into both `logits` and `labels`.  To disallow
+  backpropagation into `labels`, pass label tensors through `tf.stop_gradient`
+  before feeding them to this function.
+
+  **Note that to avoid confusion, it is required to pass only named arguments to
+  this function.**
+
+  Args:
+    _sentinel: Used to prevent positional parameters. Internal, do not use.
+    labels: Each row `labels[i]` must be a valid probability distribution.
+    logits: Unscaled log probabilities.
+    dim: The class dimension. Defaulted to -1 which is the last dimension.
+    name: A name for the operation (optional).
+
+  Returns:
+    A 1-D `Tensor` of length `batch_size` of the same type as `logits` with the
+    softmax cross entropy loss.
+  """
+  _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
+                    labels, logits)
+
+  # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
+  # could break users who call this with bad labels, but disregard the bad
+  # results.
+
+  with ops.name_scope(
+      name, "softmax_cross_entropy_with_logits", [logits, labels]) as name:
+    logits = ops.convert_to_tensor(logits, name="logits")
+    labels = ops.convert_to_tensor(labels, name="labels")
+    precise_logits = math_ops.cast(logits, dtypes.float32) if (
+        logits.dtype == dtypes.float16) else logits
+    # labels and logits must be of the same type
+    labels = math_ops.cast(labels, precise_logits.dtype)
+    input_rank = array_ops.rank(precise_logits)
+    # For shape inference.
+    shape = logits.get_shape()
+
+    # Move the dim to the end if dim is not the last dimension.
+    if dim != -1:
+      def _move_dim_to_end(tensor, dim_index, rank):
+        return array_ops.transpose(tensor,
+                                   array_ops.concat([
+                                       math_ops.range(dim_index),
+                                       math_ops.range(dim_index + 1, rank),
+                                       [dim_index]
+                                   ], 0))
+
+      precise_logits = _move_dim_to_end(precise_logits, dim, input_rank)
+      labels = _move_dim_to_end(labels, dim, input_rank)
+
+    input_shape = array_ops.shape(precise_logits)
+
+    # Make precise_logits and labels into matrices.
+    precise_logits = _flatten_outer_dims(precise_logits)
+    labels = _flatten_outer_dims(labels)
+
+    # Do the actual op computation.
+    # The second output tensor contains the gradients.  We use it in
+    # _CrossEntropyGrad() in nn_grad but not here.
+    cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
+        precise_logits, labels, name=name)
+
+    # The output cost shape should be the input minus dim.
+    output_shape = array_ops.slice(input_shape, [0],
+                                   [math_ops.subtract(input_rank, 1)])
+    cost = array_ops.reshape(cost, output_shape)
+
+    # Make shape inference work since reshape and transpose may erase its static
+    # shape.
+    if context.in_graph_mode() and shape is not None and shape.dims is not None:
+      shape = shape.as_list()
+      del shape[dim]
+      cost.set_shape(shape)
+
+    if logits.dtype == dtypes.float16:
+      return math_ops.cast(cost, dtypes.float16)
+    else:
+      return cost
+
+
+_XENT_DEPRECATION = """
+Future major versions of TensorFlow will allow gradients to flow
+into the labels input on backprop by default.
+
+See tf.nn.softmax_cross_entropy_with_logits_v2.
+"""
+
+
+@deprecation.deprecated(date=None, instructions=_XENT_DEPRECATION)
 def softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid-name
                                       labels=None, logits=None,
                                       dim=-1, name=None):
@@ -1726,6 +1843,10 @@ def softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid
   `[batch_size, num_classes]` and the same dtype (either `float16`, `float32`,
   or `float64`).
 
+  Backpropagation will happen only into `logits`.  To calculate a cross entropy
+  loss that allows backpropagation into both `logits` and `labels`, see
+  @{tf.nn.softmax_cross_entropy_with_logits_v2}.
+
   **Note that to avoid confusion, it is required to pass only named arguments to
   this function.**
 
@@ -1743,61 +1864,12 @@ def softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid
   _ensure_xent_args("softmax_cross_entropy_with_logits", _sentinel,
                     labels, logits)
 
-  # TODO(pcmurray) Raise an error when the labels do not sum to 1. Note: This
-  # could break users who call this with bad labels, but disregard the bad
-  # results.
+  with ops.name_scope(
+      name, "softmax_cross_entropy_with_logits_sg", [logits, labels]) as name:
+    labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
 
-  logits = ops.convert_to_tensor(logits)
-  labels = ops.convert_to_tensor(labels)
-  precise_logits = math_ops.cast(logits, dtypes.float32) if (
-      logits.dtype == dtypes.float16) else logits
-  # labels and logits must be of the same type
-  labels = math_ops.cast(labels, precise_logits.dtype)
-  input_rank = array_ops.rank(precise_logits)
-  # For shape inference.
-  shape = logits.get_shape()
-
-  # Move the dim to the end if dim is not the last dimension.
-  if dim is not -1:
-    def _move_dim_to_end(tensor, dim_index, rank):
-      return array_ops.transpose(tensor,
-                                 array_ops.concat([
-                                     math_ops.range(dim_index),
-                                     math_ops.range(dim_index + 1, rank),
-                                     [dim_index]
-                                 ], 0))
-
-    precise_logits = _move_dim_to_end(precise_logits, dim, input_rank)
-    labels = _move_dim_to_end(labels, dim, input_rank)
-
-  input_shape = array_ops.shape(precise_logits)
-
-  # Make precise_logits and labels into matrices.
-  precise_logits = _flatten_outer_dims(precise_logits)
-  labels = _flatten_outer_dims(labels)
-
-  # Do the actual op computation.
-  # The second output tensor contains the gradients.  We use it in
-  # _CrossEntropyGrad() in nn_grad but not here.
-  cost, unused_backprop = gen_nn_ops._softmax_cross_entropy_with_logits(
-      precise_logits, labels, name=name)
-
-  # The output cost shape should be the input minus dim.
-  output_shape = array_ops.slice(input_shape, [0],
-                                 [math_ops.subtract(input_rank, 1)])
-  cost = array_ops.reshape(cost, output_shape)
-
-  # Make shape inference work since reshape and transpose may erase its static
-  # shape.
-  if context.in_graph_mode() and shape is not None and shape.dims is not None:
-    shape = shape.as_list()
-    del shape[dim]
-    cost.set_shape(shape)
-
-  if logits.dtype == dtypes.float16:
-    return math_ops.cast(cost, dtypes.float16)
-  else:
-    return cost
+  return softmax_cross_entropy_with_logits_v2(
+      labels=labels, logits=logits, dim=dim, name=name)
 
 
 def sparse_softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid-name
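
A minimal sketch of the behavioral difference introduced above (hypothetical
usage of the 1.x API in this patch; `softmax_cross_entropy_with_logits_v2` is
the newly added symbol): v1 now wraps `labels` in `stop_gradient`, so no
gradient flows into them, while v2 keeps the labels path differentiable.

    import tensorflow as tf

    labels = tf.constant([[0.1, 0.2, 0.7]])
    logits = tf.constant([[0.5, 1.5, 0.1]])

    loss_v1 = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                      logits=logits)
    loss_v2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels,
                                                         logits=logits)

    print(tf.gradients(loss_v1, labels))  # [None]: labels are stopped in v1.
    print(tf.gradients(loss_v2, labels))  # [<Tensor>]: gradients flow in v2.

Under the hood, v1 is now just `stop_gradient` on the labels followed by a call
to v2, as the hunk above shows.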
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index 11637814a6e..24c0448deae 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -288,6 +288,10 @@ tf_module {
     name: "softmax_cross_entropy_with_logits"
     argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], "
   }
+  member_method {
+    name: "softmax_cross_entropy_with_logits_v2"
+    argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'dim\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'-1\', \'None\'], "
+  }
   member_method {
     name: "softplus"
     argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

From a0e9c52921aef9eecbd358fa5f129328f0024ab9 Mon Sep 17 00:00:00 2001
From: Anna R <annarev@google.com>
Date: Thu, 9 Nov 2017 14:31:06 -0800
Subject: [PATCH 093/115] Internal change.

PiperOrigin-RevId: 175213336
---
 tensorflow/contrib/nccl/BUILD | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index ed9fb64b954..df9dbb457ac 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -48,8 +48,8 @@ tf_cuda_cc_test(
     # Disabled on jenkins until errors finding nvmlShutdown are found.
     tags = [
         "manual",
+        "multi_gpu",
         "no_oss",
-        "noguitar",  # note: is run manually there
         "notap",
     ],
     deps = if_cuda(
@@ -138,8 +138,8 @@ cuda_py_test(
     # Disabled on jenkins until errors finding nvmlShutdown are found.
     tags = [
         "manual",
+        "multi_gpu",
         "no_oss",
-        "noguitar",  # note: is run manually there
         "notap",
     ],
 )

From f3f85e9aa0f6f26740d1da77e5bcc58ff70aa71c Mon Sep 17 00:00:00 2001
From: HyoukJoong Lee <hyouklee@google.com>
Date: Thu, 9 Nov 2017 14:48:37 -0800
Subject: [PATCH 094/115] Change for asynchronous Send and Recv by splitting
 Send into {Send, SendDone} and Recv into {Recv, RecvDone}. See
 operation_semantics.md for the updated semantics.

PiperOrigin-RevId: 175216012
---
 .../compiler/xla/service/buffer_assignment.cc | 11 ---
 .../compiler/xla/service/cpu/ir_emitter.cc    | 10 +++
 .../compiler/xla/service/cpu/ir_emitter.h     |  2 +
 .../compiler/xla/service/dfs_hlo_visitor.h    |  6 +-
 .../service/dfs_hlo_visitor_with_default.h    | 10 ++-
 .../compiler/xla/service/gpu/ir_emitter.cc    |  8 ++
 .../compiler/xla/service/gpu/ir_emitter.h     |  2 +
 .../compiler/xla/service/hlo_cost_analysis.cc |  8 ++
 .../compiler/xla/service/hlo_cost_analysis.h  |  2 +
 .../xla/service/hlo_dataflow_analysis.cc      | 65 ++++++++++++++
 .../xla/service/hlo_dataflow_analysis.h       |  2 +
 .../xla/service/hlo_dataflow_analysis_test.cc | 48 ++++++++++
 .../compiler/xla/service/hlo_graph_dumper.cc  |  4 +
 .../compiler/xla/service/hlo_instruction.cc   | 57 ++++++++++--
 .../compiler/xla/service/hlo_instruction.h    | 22 +++--
 .../compiler/xla/service/hlo_matchers.h       |  2 +
 tensorflow/compiler/xla/service/hlo_opcode.h  |  2 +
 .../xla/service/hlo_rematerialization.cc      |  2 +
 .../compiler/xla/service/hlo_verifier.cc      | 49 +++++++++-
 .../xla/service/instruction_fusion.cc         |  2 +
 .../xla/service/logical_buffer_analysis.cc    | 15 ++++
 .../xla/service/logical_buffer_analysis.h     |  2 +
 .../xla/service/tuple_points_to_analysis.cc   | 58 ++++++++++++
 .../xla/service/tuple_points_to_analysis.h    |  2 +
 .../service/tuple_points_to_analysis_test.cc  | 45 ++++++++++
 .../compiler/xla/service/user_computation.cc  |  6 +-
 .../xla/service/while_loop_simplifier.cc      |  4 +-
 .../xla/service/while_loop_simplifier_test.cc |  6 +-
 .../compiler/xla/tools/parser/hlo_parser.cc   | 30 ++++++-
 .../xla/tools/parser/hlo_parser_test.cc       | 26 ++++--
 .../performance/xla/operation_semantics.md    | 89 +++++++++++++++++++
 31 files changed, 550 insertions(+), 47 deletions(-)
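
In HLO text form (adapted from the parser tests updated below; channel ids are
illustrative), a synchronous send such as

    %send = () send(f32[] %constant), channel_id=16

becomes an asynchronous pair in which the first half returns a {data, U32
context} tuple and the second half waits on the transfer:

    %send = (f32[], u32[]) send(f32[] %constant), channel_id=16
    %send-done = () send-done((f32[], u32[]) %send), channel_id=16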

diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index c74f050f775..3c5b360c8ef 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -819,17 +819,6 @@ Status BufferAssigner::AssignBuffersForComputation(
       continue;
     }
 
-    if (instruction->opcode() == HloOpcode::kRecv) {
-      // Make sure that recv operations get a new unique allocation so that
-      // don't share their buffer with any other operations.
-      BufferAllocation* allocation = assignment->NewAllocation(
-          *buffer, buffer_size, is_thread_local, /*is_reusable=*/false);
-      allocation_indices.push_back(allocation->index());
-      VLOG(3) << "New allocation #" << allocation->index()
-              << " for recv: " << *buffer;
-      continue;
-    }
-
     if (ShapeUtil::IsTuple(buffer->shape())) {
       // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
       // assumes longer buffer liveness than indicated by the analysis.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a20ce6826ca..e547f291b8e 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1983,6 +1983,11 @@ Status IrEmitter::HandleSend(HloInstruction* send) {
   return Unimplemented("Send is not implemented on CPU. See b/33942983.");
 }
 
+Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
+  // TODO(b/33942983): Support Send/Recv on CPU.
+  return Unimplemented("Send-done is not implemented on CPU. See b/33942983.");
+}
+
 Status IrEmitter::HandleSlice(HloInstruction* slice) {
   VLOG(2) << "HandleSlice: " << slice->ToString();
   auto operand = slice->operand(0);
@@ -2148,6 +2153,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) {
   return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
 }
 
+Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) {
+  // TODO(b/33942983): Support Send/Recv on CPU.
+  return Unimplemented("Recv-done is not implemented on CPU. See b/33942983.");
+}
+
 Status IrEmitter::HandlePad(HloInstruction* pad) {
   // CPU backend does not properly handle negative padding but this is ok
   // because negative padding should be removed by the algebraic simplifier.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 5d061e11e3c..83eded5ad8c 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -171,11 +171,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
   Status HandleReduceWindow(HloInstruction* reduce_window) override;
   Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override;
   Status HandleSend(HloInstruction* send) override;
+  Status HandleSendDone(HloInstruction* send_done) override;
   Status HandleSlice(HloInstruction* slice) override;
   Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
   Status HandleDynamicUpdateSlice(
       HloInstruction* dynamic_update_slice) override;
   Status HandleRecv(HloInstruction* recv) override;
+  Status HandleRecvDone(HloInstruction* recv_done) override;
   Status HandlePad(HloInstruction* pad) override;
   Status HandleTuple(HloInstruction* tuple) override;
   Status HandleMap(HloInstruction* map) override;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index de3cd154408..bc73839a88d 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -211,9 +211,11 @@ class DfsHloVisitorBase {
 
   virtual Status HandlePad(HloInstructionPtr hlo) = 0;
 
-  virtual Status HandleSend(HloInstructionPtr hlo) = 0;
+  virtual Status HandleSend(HloInstructionPtr send) = 0;
+  virtual Status HandleSendDone(HloInstructionPtr send_done) = 0;
 
-  virtual Status HandleRecv(HloInstructionPtr hlo) = 0;
+  virtual Status HandleRecv(HloInstructionPtr recv) = 0;
+  virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0;
 
   virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0;
 
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 7ce88be89df..5415bab5b35 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -167,11 +167,17 @@ class DfsHloVisitorWithDefaultBase
   Status HandleWhile(HloInstructionPtr xla_while) override {
     return DefaultAction(xla_while);
   }
+  Status HandleRecv(HloInstructionPtr recv) override {
+    return DefaultAction(recv);
+  }
+  Status HandleRecvDone(HloInstructionPtr recv_done) override {
+    return DefaultAction(recv_done);
+  }
   Status HandleSend(HloInstructionPtr send) override {
     return DefaultAction(send);
   }
-  Status HandleRecv(HloInstructionPtr recv) override {
-    return DefaultAction(recv);
+  Status HandleSendDone(HloInstructionPtr send_done) override {
+    return DefaultAction(send_done);
   }
 
   // Invoked to inform the visitor that the traversal has completed, and that
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 57a3f713e35..9d55c7859df 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -128,10 +128,18 @@ Status IrEmitter::HandleSend(HloInstruction*) {
   return Unimplemented("Send is not implemented on GPU");
 }
 
+Status IrEmitter::HandleSendDone(HloInstruction*) {
+  return Unimplemented("Send-Done is not implemented on GPU");
+}
+
 Status IrEmitter::HandleRecv(HloInstruction*) {
   return Unimplemented("Recv is not implemented on GPU");
 }
 
+Status IrEmitter::HandleRecvDone(HloInstruction*) {
+  return Unimplemented("Recv-done is not implemented on GPU");
+}
+
 Status IrEmitter::HandleTuple(HloInstruction* tuple) {
   std::vector<llvm::Value*> base_ptrs;
   for (const HloInstruction* operand : tuple->operands()) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 263992d9254..61fdeaa0ee7 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault {
   Status HandleOutfeed(HloInstruction* outfeed) override;
   Status HandleSort(HloInstruction* sort) override;
   Status HandleSend(HloInstruction* send) override;
+  Status HandleSendDone(HloInstruction* send_done) override;
   Status HandleRecv(HloInstruction* recv) override;
+  Status HandleRecvDone(HloInstruction* recv_done) override;
   Status HandleParameter(HloInstruction* parameter) override;
   Status HandleReduce(HloInstruction* reduce) override;
   Status HandleTuple(HloInstruction* tuple) override;
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 17ba2b673ac..1877065f672 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -337,10 +337,18 @@ Status HloCostAnalysis::HandleSend(const HloInstruction*) {
   return Status::OK();
 }
 
+Status HloCostAnalysis::HandleSendDone(const HloInstruction*) {
+  return Status::OK();
+}
+
 Status HloCostAnalysis::HandleRecv(const HloInstruction*) {
   return Status::OK();
 }
 
+Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) {
+  return Status::OK();
+}
+
 Status HloCostAnalysis::HandleReshape(const HloInstruction*) {
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 8074868e375..0f447753788 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -60,7 +60,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
   Status HandleReducePrecision(const HloInstruction* hlo) override;
   Status HandleConcatenate(const HloInstruction* concatenate) override;
   Status HandleSend(const HloInstruction* send) override;
+  Status HandleSendDone(const HloInstruction* send_done) override;
   Status HandleRecv(const HloInstruction* recv) override;
+  Status HandleRecvDone(const HloInstruction* recv_done) override;
   Status HandleConvert(const HloInstruction* convert) override;
   Status HandleCopy(const HloInstruction* copy) override;
   Status HandleDot(const HloInstruction* dot) override;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 92261bce627..ff80f18bb56 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -242,6 +242,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
   return false;
 }
 
+bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
+  CHECK_EQ(send->opcode(), HloOpcode::kSend);
+  bool changed = false;
+  // Send forwards the operand value to the output tuple at {0}.
+  for (auto& pair : GetInstructionValueSet(send->operand(0))) {
+    const ShapeIndex& operand_index = pair.first;
+    const HloValueSet& operand_value_set = pair.second;
+
+    ShapeIndex index = {0};
+    for (int64 i : operand_index) {
+      index.push_back(i);
+    }
+
+    HloValueSet& value_set = GetValueSet(send, index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
+  CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
+  bool changed = false;
+  // RecvDone forwards the operand value at {0} to the output.
+  for (auto& pair : GetInstructionValueSet(recv_done)) {
+    ShapeIndex& index = pair.first;
+    HloValueSet& value_set = pair.second;
+
+    ShapeIndex operand_index = {0};
+    for (int64 i : index) {
+      operand_index.push_back(i);
+    }
+
+    const HloValueSet& operand_value_set =
+        GetValueSet(recv_done->operand(0), operand_index);
+    if (value_set != operand_value_set) {
+      value_set = operand_value_set;
+      changed = true;
+    }
+  }
+  return changed;
+}
+
 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
   CHECK_EQ(call->opcode(), HloOpcode::kCall);
   InstructionValueSet& value_set = GetInstructionValueSet(call);
@@ -429,6 +474,10 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
       return UpdateCallValueSet(instruction);
     case HloOpcode::kWhile:
       return UpdateWhileValueSet(instruction);
+    case HloOpcode::kSend:
+      return UpdateSendValueSet(instruction);
+    case HloOpcode::kRecvDone:
+      return UpdateRecvDoneValueSet(instruction);
     default:
       // Instruction does not forward HloValues (it defines all values in its
       // output). No update is necessary.
@@ -537,6 +586,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
         GetValueSet(instruction, /*index=*/{}).AddValue(value);
       };
 
+      // Lambda to set the value set at the given index of the output.
+      auto define_value_at = [this, &instruction](const ShapeIndex& index) {
+        HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
+        GetValueSet(instruction, index).AddValue(value);
+      };
+
       switch (instruction->opcode()) {
         case HloOpcode::kBitcast:
           if (bitcast_defines_value_) {
@@ -577,6 +632,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
           // values flow from their operands.
           define_top_level_only();
           break;
+        case HloOpcode::kRecvDone:
+          // RecvDone aliases its input tuple element {0} and therefore does
+          // not define any values.
+          break;
+        case HloOpcode::kSend:
+          // Send produces a tuple of {aliased operand, U32 context}, so it
+          // defines only the top-level tuple and the tuple element at {1}.
+          define_value_at(/*index=*/{});
+          define_value_at(/*index=*/{1});
+          break;
         default:
           define_all_values();
           break;
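
The index manipulation above is a simple prepend: a value at operand
`ShapeIndex` `{i, ...}` reappears at Send output index `{0, i, ...}`, and
`RecvDone` strips the leading `0` in the opposite direction. A minimal C++
sketch of the forward mapping (illustrative only, mirroring
`UpdateSendValueSet`):

    ShapeIndex operand_index = {1};  // a value inside the operand
    ShapeIndex send_index = {0};     // Send wraps its operand in element {0}
    for (int64 i : operand_index) {
      send_index.push_back(i);
    }
    // send_index is now {0, 1}.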
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 207e553bf7f..63467f32060 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -146,7 +146,9 @@ class HloDataflowAnalysis {
   bool UpdateCopyValueSet(HloInstruction* copy);
   bool UpdateGetTupleElementValueSet(HloInstruction* gte);
   bool UpdateParameterValueSet(HloInstruction* parameter);
+  bool UpdateRecvDoneValueSet(HloInstruction* recv_done);
   bool UpdateSelectValueSet(HloInstruction* select);
+  bool UpdateSendValueSet(HloInstruction* send);
   bool UpdateTupleValueSet(HloInstruction* tuple);
   bool UpdateWhileValueSet(HloInstruction* xla_while);
 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4b8eb237a67..66a538fc519 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1139,6 +1139,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) {
       analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module());
 }
 
+TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
+  // Test that a Send forwards its operand to the output tuple at {0}.
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
+  auto send = builder.AddInstruction(
+      HloInstruction::CreateSend(param, /*channel_id=*/0));
+  auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+  module_->AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  EXPECT_EQ(analysis.values().size(), 4);
+
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(param));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done));
+  EXPECT_THAT(HloValuesAt(send, /*index=*/{0}),
+              UnorderedElementsAre(analysis.GetValueDefinedAt(param)));
+}
+
+TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
+  // Test that a RecvDone forwards its operand tuple element at {0} to the
+  // output.
+  auto builder = HloComputation::Builder(TestName());
+  auto recv = builder.AddInstruction(
+      HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0));
+  auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+  module_->AddEntryComputation(builder.Build());
+
+  bool ssa_form = GetParam();
+  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+
+  EXPECT_EQ(analysis.values().size(), 3);
+
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0}));
+  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1}));
+  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done));
+  EXPECT_THAT(HloValuesAt(recv_done),
+              UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0})));
+  EXPECT_TRUE(
+      analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module());
+}
+
 TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
   // A simple chain of elementwise operations. No values should interfere.
   //
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 1c063c973dc..67e0238c4af 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -943,7 +943,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
     case HloOpcode::kFusion:
       return kGray;
     case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
     case HloOpcode::kCrossReplicaSum:
@@ -1037,7 +1039,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
                    ? ""
                    : StrCat("stride=", VectorString(instr->slice_strides()));
       case HloOpcode::kSend:
+      case HloOpcode::kSendDone:
       case HloOpcode::kRecv:
+      case HloOpcode::kRecvDone:
         return StrCat("channel_id=", instr->channel_id());
       default:
         return "";
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ee98c3fabc5..ffb933155f7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -371,20 +371,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
     HloInstruction* operand, int64 channel_id) {
+  // Send instruction produces a tuple of {aliased operand, U32 context}.
+  Shape output_shape = ShapeUtil::MakeTupleShape(
+      {operand->shape(), ShapeUtil::MakeShape(U32, {})});
   auto instruction =
-      WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
+      WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
   instruction->AppendOperand(operand);
   instruction->channel_id_ = channel_id;
   return instruction;
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
+    HloInstruction* operand) {
+  CHECK(operand->opcode() == HloOpcode::kSend)
+      << "SendDone must take the context operand from Send";
+  auto instruction = WrapUnique(
+      new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
+  instruction->AppendOperand(operand);
+  instruction->channel_id_ = operand->channel_id();
+  return instruction;
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
     const Shape& shape, int64 channel_id) {
-  auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
+  // Recv instruction produces a tuple of {receive buffer, U32 context}.
+  Shape output_shape =
+      ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
+  auto instruction =
+      WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
   instruction->channel_id_ = channel_id;
   return instruction;
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
+    HloInstruction* operand) {
+  CHECK(operand->opcode() == HloOpcode::kRecv)
+      << "RecvDone must take the context operand from Recv";
+  Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
+  auto instruction =
+      WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
+  instruction->AppendOperand(operand);
+  instruction->channel_id_ = operand->channel_id();
+  return instruction;
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
     const Shape& shape, HloInstruction* operand,
     tensorflow::gtl::ArraySlice<int64> dimensions) {
@@ -908,7 +938,9 @@ RandomDistribution HloInstruction::random_distribution() const {
 bool HloInstruction::HasSideEffect() const {
   switch (opcode_) {
     case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
     case HloOpcode::kTrace:
@@ -1164,7 +1196,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
                                   new_operands[4], epsilon(), feature_index());
       break;
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
     case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
     case HloOpcode::kTrace:
       LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
   }
@@ -1557,8 +1591,10 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
     case HloOpcode::kSort:
-    case HloOpcode::kSend:
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
+    case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
       return false;
   }
 }
@@ -1891,7 +1927,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
                        })));
   }
 
-  if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
+  if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
+      opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
     extra.push_back(StrCat("channel_id=", channel_id_));
   }
 
@@ -2071,8 +2108,10 @@ bool HloInstruction::IsFusable() const {
     case HloOpcode::kOutfeed:
     case HloOpcode::kParameter:
     case HloOpcode::kTrace:
-    case HloOpcode::kSend:
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
+    case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
       return false;
     // Only fuse Rng if it is used once, otherwise the random numbers generated
     // will be different in each fusion. If it is the root (user count = 0)
@@ -2279,10 +2318,14 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
       return visitor->HandleCall(this);
     case HloOpcode::kCustomCall:
       return visitor->HandleCustomCall(this);
-    case HloOpcode::kSend:
-      return visitor->HandleSend(this);
     case HloOpcode::kRecv:
       return visitor->HandleRecv(this);
+    case HloOpcode::kRecvDone:
+      return visitor->HandleRecvDone(this);
+    case HloOpcode::kSend:
+      return visitor->HandleSend(this);
+    case HloOpcode::kSendDone:
+      return visitor->HandleSendDone(this);
 
     // These opcodes are not handled here.
     case HloOpcode::kTrace:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4d8fe6bc10c..974d43d89ee 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -181,18 +181,28 @@ class HloInstruction {
       const Shape& shape, HloInstruction* operand,
       tensorflow::StringPiece outfeed_config);
 
-  // Creates a send instruction with the given channel id, which sends the
-  // operand data to a unique receive instruction in another computation that
-  // has the same channel id.
+  // Creates an asynchronous send instruction with the given channel id, which
+  // initiates sending the operand data to a unique receive instruction in
+  // another computation that has the same channel id.
   static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
                                                     int64 channel_id);
 
-  // Creates a receive instruction with the given channel id, which receives
-  // data of the given shape from a unique send instruction in another
-  // computation that has the same channel id.
+  // Blocks until data transfer for the Send instruction (operand) is complete.
+  // The operand must be kSend.
+  static std::unique_ptr<HloInstruction> CreateSendDone(
+      HloInstruction* operand);
+
+  // Creates an asynchronous receive instruction with the given channel id,
+  // which allocates resources to receive data of the given shape from a unique
+  // send instruction in another computation that has the same channel id.
   static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
                                                     int64 channel_id);
 
+  // Blocks until data transfer for the Recv instruction (operand) is complete
+  // and returns the receive buffer. The operand must be kRecv.
+  static std::unique_ptr<HloInstruction> CreateRecvDone(
+      HloInstruction* operand);
+
   // Creates a slice instruction, where the operand is sliced by the given
   // start/limit indices.
   static std::unique_ptr<HloInstruction> CreateSlice(
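
A minimal usage sketch of the new factory methods (hypothetical, following the
pattern of the dataflow tests in this patch): every `kSend` must be consumed by
a `kSendDone`, and every `kRecv` by a `kRecvDone`.

    auto builder = HloComputation::Builder("send_recv_demo");
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p"));
    // Send returns (operand, u32 context); SendDone waits on that context.
    auto send = builder.AddInstruction(
        HloInstruction::CreateSend(param, /*channel_id=*/1));
    builder.AddInstruction(HloInstruction::CreateSendDone(send));
    // Recv returns (buffer, u32 context); RecvDone extracts the buffer.
    auto recv = builder.AddInstruction(
        HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {}),
                                   /*channel_id=*/2));
    builder.AddInstruction(HloInstruction::CreateRecvDone(recv));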
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 4d4010b0253..268fa0f632d 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -121,6 +121,7 @@ HLO_MATCHER(Outfeed);
 HLO_MATCHER(Pad);
 HLO_MATCHER(Power);
 HLO_MATCHER(Recv);
+HLO_MATCHER(RecvDone);
 HLO_MATCHER(Reduce);
 HLO_MATCHER(ReducePrecision);
 HLO_MATCHER(ReduceWindow);
@@ -131,6 +132,7 @@ HLO_MATCHER(Rng);
 HLO_MATCHER(Select);
 HLO_MATCHER(SelectAndScatter);
 HLO_MATCHER(Send);
+HLO_MATCHER(SendDone);
 HLO_MATCHER(ShiftLeft);
 HLO_MATCHER(ShiftRightLogical);
 HLO_MATCHER(ShiftRightArithmetic);
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index d68fc203211..e0d02e0665c 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -97,6 +97,7 @@ namespace xla {
   V(kPower, "power")                                         \
   V(kReal, "real")                                           \
   V(kRecv, "recv")                                           \
+  V(kRecvDone, "recv-done")                                  \
   V(kReduce, "reduce")                                       \
   V(kReducePrecision, "reduce-precision")                    \
   V(kReduceWindow, "reduce-window")                          \
@@ -108,6 +109,7 @@ namespace xla {
   V(kSelect, "select")                                       \
   V(kSelectAndScatter, "select-and-scatter")                 \
   V(kSend, "send")                                           \
+  V(kSendDone, "send-done")                                  \
   V(kShiftLeft, "shift-left")                                \
   V(kShiftRightArithmetic, "shift-right-arithmetic")         \
   V(kShiftRightLogical, "shift-right-logical")               \
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index c96df50e79a..828be8490c9 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -66,7 +66,9 @@ bool IsRematerializable(const HloInstruction* instruction) {
     case HloOpcode::kInfeed:
     case HloOpcode::kParameter:
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
     case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
     case HloOpcode::kTrace:
     case HloOpcode::kWhile:
       return false;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index c1aa655401a..c938450891a 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -270,12 +270,40 @@ class ShapeVerifier : public DfsHloVisitor {
                                                     pad->padding_config()));
   }
 
-  Status HandleSend(HloInstruction*) override {
-    return tensorflow::Status::OK();
+  Status HandleSend(HloInstruction* send) override {
+    TF_RET_CHECK(send->users().size() == 1);
+    const HloInstruction* send_done = send->users()[0];
+    TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+    TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+    return CheckShape(
+        send, ShapeUtil::MakeTupleShape(
+                  {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})}));
   }
 
-  Status HandleRecv(HloInstruction*) override {
-    return tensorflow::Status::OK();
+  Status HandleSendDone(HloInstruction* send_done) override {
+    TF_RET_CHECK(send_done->operands().size() == 1);
+    const HloInstruction* send = send_done->operand(0);
+    TF_RET_CHECK(send->opcode() == HloOpcode::kSend);
+    TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done));
+    return CheckShape(send_done, ShapeUtil::MakeNil());
+  }
+
+  Status HandleRecv(HloInstruction* recv) override {
+    TF_RET_CHECK(recv->users().size() == 1);
+    const HloInstruction* recv_done = recv->users()[0];
+    TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+    TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+    return CheckShape(recv,
+                      ShapeUtil::MakeTupleShape(
+                          {recv_done->shape(), ShapeUtil::MakeShape(U32, {})}));
+  }
+
+  Status HandleRecvDone(HloInstruction* recv_done) override {
+    TF_RET_CHECK(recv_done->operands().size() == 1);
+    const HloInstruction* recv = recv_done->operand(0);
+    TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv);
+    TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done));
+    return CheckShape(recv_done, recv->shape().tuple_shapes(0));
   }
 
   Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override {
@@ -365,6 +393,19 @@ class ShapeVerifier : public DfsHloVisitor {
                           instruction->opcode(), instruction->operands()));
   }
 
+  // Checks whether the given two instructions share the same channel id.
+  Status CheckSameChannel(const HloInstruction* instr1,
+                          const HloInstruction* instr2) {
+    if (instr1->channel_id() != instr2->channel_id()) {
+      return FailedPrecondition(
+          "Expected to have the same channel id, actual channel ids are: %s "
+          "(%lld), %s (%lld)",
+          instr1->ToString().c_str(), instr1->channel_id(),
+          instr2->ToString().c_str(), instr2->channel_id());
+    }
+    return tensorflow::Status::OK();
+  }
+
   // Returns the size of a Shape in bytes.
   const std::function<int64(const Shape&)> shape_size_fn_;
 };
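
Concretely, the new checks make the verifier reject a pair whose channel ids
disagree (a hypothetical module; channel ids are illustrative):

    %send = (f32[], u32[]) send(f32[] %constant), channel_id=16
    %send-done = () send-done((f32[], u32[]) %send), channel_id=17

`CheckSameChannel` turns this into a `FailedPrecondition` error, and
`HandleSend`/`HandleRecv` additionally require that the sole user of a
`Send`/`Recv` is its matching done instruction.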
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 0d1b7bc109c..dea47b1fd7b 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -113,7 +113,9 @@ namespace xla {
     case HloOpcode::kTrace:
     case HloOpcode::kWhile:
     case HloOpcode::kSend:
+    case HloOpcode::kSendDone:
     case HloOpcode::kRecv:
+    case HloOpcode::kRecvDone:
       return true;
   }
 
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index b92017c6cbc..02dc49e78c7 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -104,6 +104,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
   return Status::OK();
 }
 
+Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) {
+  // RecvDone doesn't create a new buffer but rather aliases its input (Recv)
+  // tuple element at {0} to its output.
+  return Status::OK();
+}
+
+Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
+  // Send creates new buffers for the top-level tuple and the context (tuple
+  // element at {1}). Tuple element at {0} is an alias of the Send operand, so
+  // we don't need to create a new Logical Buffer for that.
+  NewLogicalBuffer(send, /*index=*/{});
+  NewLogicalBuffer(send, /*index=*/{1});
+  return Status::OK();
+}
+
 Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
   // A Tuple instruction only creates the top-level buffer.
   NewLogicalBuffer(tuple, /*index=*/{});
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
index a82e83ec5c3..598d08b7203 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h
@@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault {
   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
   Status HandleBitcast(HloInstruction* bitcast) override;
   Status HandleCopy(HloInstruction* copy) override;
+  Status HandleRecvDone(HloInstruction* recv_done) override;
+  Status HandleSend(HloInstruction* send) override;
   Status HandleSelect(HloInstruction* select) override;
 
   // A map from the buffer ID to the logical buffer
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index df537bd7c15..a1f9451dd48 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -253,6 +253,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
   return Status::OK();
 }
 
+Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) {
+  // RecvDone aliases its input (Recv) tuple element {0} to its output.
+  PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done);
+  const PointsToSet& operand_points_to_set =
+      GetPointsToSet(recv_done->operand(0));
+
+  // Recursively copy the points-to set of the operand's tuple element {0}.
+  points_to_set.ForEachMutableElement(
+      [this, &points_to_set, &operand_points_to_set](
+          const ShapeIndex& index, PointsToSet::BufferList* buffers) {
+        ShapeIndex src_index({0});
+        for (auto element : index) {
+          src_index.push_back(element);
+        }
+        *buffers = operand_points_to_set.element(src_index);
+        for (auto& tuple_source :
+             operand_points_to_set.tuple_sources(src_index)) {
+          points_to_set.add_tuple_source(index, tuple_source);
+        }
+      });
+  return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) {
+  // Send creates a tuple of {aliased operand, U32 context}.
+  PointsToSet& points_to_set = CreateEmptyPointsToSet(send);
+
+  // Create the points-to set for the top-level tuple and its element at {1}.
+  auto top_buffer = points_to_set.mutable_element(ShapeIndex({}));
+  top_buffer->push_back(
+      &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({})));
+  points_to_set.add_tuple_source({}, send);
+
+  auto context_buffer = points_to_set.mutable_element(ShapeIndex({1}));
+  context_buffer->push_back(
+      &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1})));
+
+  // Recursively copy the points-to set of the operand into tuple element {0}.
+  const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0));
+  operand_points_to_set.ForEachElement(
+      [&points_to_set, &operand_points_to_set](
+          const ShapeIndex& src_index,
+          const PointsToSet::BufferList& points_to) {
+        ShapeIndex target_index({0});
+        for (auto element : src_index) {
+          target_index.push_back(element);
+        }
+        *points_to_set.mutable_element(target_index) = points_to;
+
+        for (HloInstruction* tuple :
+             operand_points_to_set.tuple_sources(src_index)) {
+          points_to_set.add_tuple_source(target_index, tuple);
+        }
+      });
+
+  return Status::OK();
+}
+
 Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) {
   tensorflow::gtl::ArraySlice<HloInstruction*> operands(tuple->operands());
   PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index e6157a1ed11..8928de107ee 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
   Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
   Status HandleBitcast(HloInstruction* bitcast) override;
   Status HandleCopy(HloInstruction* copy) override;
+  Status HandleRecvDone(HloInstruction* recv_done) override;
+  Status HandleSend(HloInstruction* send) override;
   Status HandleSelect(HloInstruction* select) override;
 
   string ToString() const;
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index 694ed57fa24..dec446d4dac 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -313,6 +313,51 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
       {constant1, constant2, copy});
 }
 
+TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) {
+  // Send forwards its operand to the output tuple at {0}.
+  auto builder = HloComputation::Builder(TestName());
+  auto constant = builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+  auto send = builder.AddInstruction(
+      HloInstruction::CreateSend(constant, /*channel_id=*/0));
+  auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous());
+  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct());
+  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous());
+  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct());
+
+  ExpectHasTopLevelBuffers(
+      points_to_analysis_->GetPointsToSet(send).element({}), {send});
+  ExpectHasTopLevelBuffers(
+      points_to_analysis_->GetPointsToSet(send).element({0}), {constant});
+  ExpectHasTopLevelBuffers(
+      points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(),
+      {send_done});
+  ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}});
+}
+
+TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) {
+  // RecvDone forwards its operand tuple element at {0} to the output.
+  auto builder = HloComputation::Builder(TestName());
+  auto recv = builder.AddInstruction(HloInstruction::CreateRecv(
+      ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0));
+  auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+
+  BuildModuleAndRunAnalysis(builder.Build());
+
+  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous());
+  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct());
+  EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous());
+  EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct());
+
+  ExpectHasTopLevelBuffers(
+      points_to_analysis_->GetPointsToSet(recv).element({}), {recv});
+  ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}});
+}
+
 TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
   // Select from two different tuples. This should create an ambiguous points to
   // set containing the union of both sides.
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index e9d182509b5..8d5bb08e518 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -2927,8 +2927,9 @@ void ComputationLowerer::Visit(
 
     case OpRequest::kRecvRequest: {
       const RecvRequest& recv_request = request.request().recv_request();
-      hlo_instruction = add_instruction(HloInstruction::CreateRecv(
+      HloInstruction* recv = add_instruction(HloInstruction::CreateRecv(
           request.output_shape(), recv_request.channel_handle().handle()));
+      hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv));
       break;
     }
 
@@ -3120,8 +3121,9 @@ void ComputationLowerer::Visit(
     case OpRequest::kSendRequest: {
       const SendRequest& send_request = request.request().send_request();
       HloInstruction* operand = lookup_instruction(send_request.operand());
-      hlo_instruction = add_instruction(HloInstruction::CreateSend(
+      HloInstruction* send = add_instruction(HloInstruction::CreateSend(
           operand, send_request.channel_handle().handle()));
+      hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send));
       break;
     }
 
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 65734f91bc6..2fac914892e 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -58,7 +58,9 @@ static bool ContainsSendOrRecv(const HloComputation* comp) {
 
 static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
   if (instr->opcode() == HloOpcode::kSend ||
-      instr->opcode() == HloOpcode::kRecv) {
+      instr->opcode() == HloOpcode::kSendDone ||
+      instr->opcode() == HloOpcode::kRecv ||
+      instr->opcode() == HloOpcode::kRecvDone) {
     return true;
   }
   for (const auto& subcomp : instr->called_computations()) {
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
index 8e1a2dcde12..d99b31dc003 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc
@@ -144,10 +144,11 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) {
   auto* while_op = computation->root_instruction();
   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
   auto* while_body = while_op->while_body();
-  while_body->AddInstruction(HloInstruction::CreateSend(
+  auto* send = while_body->AddInstruction(HloInstruction::CreateSend(
       while_body->AddInstruction(
           HloInstruction::CreateConstant(Literal::CreateR0<bool>(true))),
       /*channel_id=*/0));
+  while_body->AddInstruction(HloInstruction::CreateSendDone(send));
   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
 }
 
@@ -156,9 +157,10 @@ TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) {
   auto* while_op = computation->root_instruction();
   ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile);
   auto* while_body = while_op->while_body();
-  while_body->AddInstruction(
+  auto* recv = while_body->AddInstruction(
       HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}),
                                  /*channel_id=*/0));
+  while_body->AddInstruction(HloInstruction::CreateRecvDone(recv));
   EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
 }
 
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index df07e069a04..3741c3daac7 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -442,7 +442,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
         return false;
       }
       instruction = builder->AddInstruction(
-          HloInstruction::CreateRecv(shape, *channel_id));
+          HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
+      break;
+    }
+    case HloOpcode::kRecvDone: {
+      optional<int64> channel_id;
+      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      if (channel_id != operands[0]->channel_id()) {
+        return false;
+      }
+      instruction =
+          builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
       break;
     }
     case HloOpcode::kSend: {
@@ -456,6 +470,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
           HloInstruction::CreateSend(operands[0], *channel_id));
       break;
     }
+    case HloOpcode::kSendDone: {
+      optional<int64> channel_id;
+      attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      if (channel_id != operands[0]->channel_id()) {
+        return false;
+      }
+      instruction =
+          builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
+      break;
+    }
     case HloOpcode::kGetTupleElement: {
       optional<int64> index;
       attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index a9dc3609784..ca476a4bb77 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -226,9 +226,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
 R"(HloModule TwoSendRecvBothWayRecvFist_module:
 
 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
-  %recv = f32[] recv(), channel_id=15, sharding={maximal device=1}
-  ROOT %constant = f32[] constant(2.1), sharding={maximal device=0}
-  %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+  %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1}
+  ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1}
+  %constant = f32[] constant(2.1), sharding={maximal device=0}
+  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
+  %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0}
 }
 
 )"
@@ -522,9 +524,11 @@ TEST_F(HloParserTest, UnexpectedAttribute) {
   const string original = R"(HloModule unexpected_attr_module:
 
 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
-  %recv = f32[] recv(), channel_id=15
+  %recv = (f32[], u32[]) recv(), channel_id=15
+  %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
   ROOT %constant = f32[] constant(2.1)
-  %send = () send(f32[] %constant), channel_id=16, calls=%recv
+  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv
+  %send-done = () send-done((f32[], u32[]) %send), channel_id=16
 }
 
 )";
@@ -536,9 +540,11 @@ TEST_F(HloParserTest, MissingAttribute) {
   const string original = R"(HloModule missing_attr_module:
 
 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
-  %recv = f32[] recv(), channel_id=15
+  %recv = (f32[], u32[]) recv(), channel_id=15
+  %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
   ROOT %constant = f32[] constant(-2.1)
-  %send = () send(f32[] %constant)
+  %send = (f32[], u32[]) send(f32[] %constant)
+  %send-done = () send-done((f32[], u32[]) %send), channel_id=16
 }
 
 )";
@@ -550,9 +556,11 @@ TEST_F(HloParserTest, PredecessorUndefined) {
   const string original = R"(HloModule pre_not_found_module:
 
 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
-  %recv = f32[] recv(), channel_id=15
+  %recv = (f32[], u32[]) recv(), channel_id=15
+  %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15
   ROOT %constant = f32[] constant(2.1)
-  %send = () send(f32[] %constant), channel_id=16, control-predecessors={%done}
+  %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done}
+  %send-done = () send-done((f32[], u32[]) %send), channel_id=16
 }
 
 )";
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 3ca3b51a5ef..ccced8792ef 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -901,6 +901,95 @@ are all 0. Figure below shows examples of different `edge_padding` and
   <img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
 </div>
 
+## Recv
+
+See also
+[`ComputationBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+
+<b> `Recv(shape, channel_handle)` </b>
+
+| Arguments        | Type            | Semantics                            |
+| ---------------- | --------------- | ------------------------------------ |
+| `shape`          | `Shape`         | shape of the data to receive         |
+| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
+
+Receives data of the given shape from a `Send` instruction in another
+computation that shares the same channel handle. Returns a
+`ComputationDataHandle` for the received data.
+
+The client API of the `Recv` operation represents synchronous communication.
+However, the instruction is internally decomposed into 2 HLO instructions
+(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
+[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
+
+<b>`Recv(const Shape& shape, int64 channel_id)`</b>
+
+Allocates resources required to receive data from a `Send` instruction with the
+same channel_id. Returns a context for the allocated resources, which is used
+by a following `RecvDone` instruction to wait for the completion of the data
+transfer. The context is a tuple of {receive buffer (shape), request identifier
+(U32)} and it can only be used by a `RecvDone` instruction.
+
+<b> `RecvDone(HloInstruction context)` </b>
+
+Given a context created by a `Recv` instruction, waits for the data transfer to
+complete and returns the received data.
+
+## Send
+
+See also
+[`ComputationBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+
+<b> `Send(operand, channel_handle)` </b>
+
+| Arguments        | Type                    | Semantics                        |
+| ---------------- | ----------------------- | -------------------------------- |
+| `operand`        | `ComputationDataHandle` | data to send (array of type T)   |
+| `channel_handle` | `ChannelHandle`         | unique identifier for each send/recv pair |
+
+Sends the given operand data to a `Recv` instruction in another computation
+that shares the same channel handle. Does not return any data.
+
+As with `Recv`, the client API of the `Send` operation represents synchronous
+communication, and it is internally decomposed into 2 HLO instructions
+(`Send` and `SendDone`) to enable asynchronous data transfers. See also
+[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
+
+<b>`Send(HloInstruction operand, int64 channel_id)`</b>
+
+Initiates an asynchronous transfer of the operand to the resources allocated by
+the `Recv` instruction with the same channel id. Returns a context, which is
+used by a following `SendDone` instruction to wait for the completion of the
+data transfer. The context is a tuple of {operand (shape), request identifier
+(U32)} and it can only be used by a `SendDone` instruction.
+
+<b> `SendDone(HloInstruction context)` </b>
+
+Given a context created by a `Send` instruction, waits for the data transfer to
+complete.  The instruction does not return any data.
+
+<b> Scheduling of channel instructions </b>
+
+The required execution order of the 4 instructions for each channel (`Recv`,
+`RecvDone`, `Send`, `SendDone`) is shown below.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+  <img style="width:70%" src="../../images/send_recv_order.png">
+</div>
+
+* `Recv` happens before `Send`
+* `Send` happens before `RecvDone`
+* `Recv` happens before `RecvDone`
+* `Send` happens before `SendDone`
+
+When the backend compilers generate a linear schedule for each computation that
+communicates via channel instructions, there must not be cycles across the
+computations. For example, the schedules below lead to deadlocks.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+  <img style="width:100%" src="../../images/send_recv_schedule.png">
+</div>
+
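+As a sketch only (none of this is XLA API; instruction names of the form
+`<op><channel>` are hypothetical labels), the constraint can be checked by
+looking for a cycle in the union of each computation's linear order and the
+four happens-before edges listed above:
+
+```python
+def has_deadlock(schedules, channel_edges):
+  """True iff the combined ordering constraints contain a cycle."""
+  succ = {}  # instruction -> instructions that must execute strictly later
+  for sched in schedules:
+    for a, b in zip(sched, sched[1:]):
+      succ.setdefault(a, set()).add(b)
+  for a, b in channel_edges:
+    succ.setdefault(a, set()).add(b)
+  color = {}  # DFS cycle detection: None=unvisited, gray=on stack, black=done
+  def visit(n):
+    color[n] = 'gray'
+    for m in succ.get(n, ()):
+      if color.get(m) == 'gray' or (color.get(m) is None and visit(m)):
+        return True
+    color[n] = 'black'
+    return False
+  return any(color.get(n) is None and visit(n) for n in list(succ))
+
+# Two channels between computations A and B. Each side schedules its RecvDone
+# before its Send, so the two Sends transitively wait on each other.
+edges = [('Recv1', 'Send1'), ('Send1', 'RecvDone1'), ('Recv1', 'RecvDone1'),
+         ('Send1', 'SendDone1'), ('Recv2', 'Send2'), ('Send2', 'RecvDone2'),
+         ('Recv2', 'RecvDone2'), ('Send2', 'SendDone2')]
+a = ['Recv2', 'RecvDone2', 'Send1', 'SendDone1']  # computation A's schedule
+b = ['Recv1', 'RecvDone1', 'Send2', 'SendDone2']  # computation B's schedule
+print(has_deadlock([a, b], edges))  # True
+```
+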
 ## Reduce
 
 See also

From b11a79032856722d0e51ce421455af8a8610d965 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 14:55:09 -0800
Subject: [PATCH 095/115] Adds streaming_dynamic_auc to Tensorflow contrib
 metrics. This metric differs from streaming_auc because it uses every
 prediction as a threshold rather than linearly spaced fixed thresholds.

PiperOrigin-RevId: 175217002
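
A minimal usage sketch (mine, not part of this change; assumes a TF 1.x graph
session and the `tf.contrib.metrics` export added below):

```python
import numpy as np
import tensorflow as tf

labels = tf.placeholder(tf.int64, shape=[None])
predictions = tf.placeholder(tf.float64, shape=[None])
# `auc` reads the value accumulated so far; `update_op` folds in a batch.
auc, update_op = tf.contrib.metrics.streaming_dynamic_auc(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())
  for _ in range(3):  # stream a few batches
    y = np.random.randint(0, 2, size=32)
    p = np.clip(0.5 * y + 0.5 * np.random.rand(32), 0.0, 1.0)
    sess.run(update_op, {labels: y, predictions: p})
  print(sess.run(auc))  # AUC over all accumulated batches
```

Because every distinct prediction is used as a threshold, the result does not
depend on a fixed threshold grid the way `streaming_auc` does.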
---
 tensorflow/contrib/metrics/__init__.py        |   2 +
 .../contrib/metrics/python/ops/metric_ops.py  | 149 ++++++++++
 .../metrics/python/ops/metric_ops_test.py     | 262 ++++++++++++++++--
 3 files changed, 385 insertions(+), 28 deletions(-)

diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py
index 302042c4dd6..8eed45c4b38 100644
--- a/tensorflow/contrib/metrics/__init__.py
+++ b/tensorflow/contrib/metrics/__init__.py
@@ -27,6 +27,7 @@ See the @{$python/contrib.metrics} guide.
 @@streaming_false_negative_rate
 @@streaming_false_negative_rate_at_thresholds
 @@streaming_auc
+@@streaming_dynamic_auc
 @@streaming_curve_points
 @@streaming_recall_at_k
 @@streaming_mean_absolute_error
@@ -88,6 +89,7 @@ from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points
+from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_dynamic_auc
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds
 from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 3dd1f1a6277..24692ff12fb 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -1178,6 +1178,154 @@ def streaming_auc(predictions,
       name=name)
 
 
+def _compute_dynamic_auc(labels, predictions, curve='ROC'):
+  """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
+
+  Computes the area under the ROC or PR curve using each prediction as a
+  threshold. This could be slow for large batches, but has the advantage of not
+  having its results degrade depending on the distribution of predictions.
+
+  Args:
+    labels: A `Tensor` of ground truth labels with the same shape as
+      `predictions` with values of 0 or 1 and type `int64`.
+    predictions: A 1-D `Tensor` of predictions whose values are `float64`.
+    curve: The name of the curve to be computed, 'ROC' for the Receiver
+      Operating Characteristic or 'PR' for the Precision-Recall curve.
+
+  Returns:
+    A scalar `Tensor` containing the area-under-curve value for the input.
+  """
+  # Count the total number of positive and negative labels in the input.
+  size = array_ops.size(predictions)
+  total_positive = math_ops.cast(math_ops.reduce_sum(labels), dtypes.int32)
+
+  def continue_computing_dynamic_auc():
+    """Continues dynamic auc computation, entered if labels are not all equal.
+
+    Returns:
+      A scalar `Tensor` containing the area-under-curve value.
+    """
+    # Sort the predictions descending, and the corresponding labels as well.
+    ordered_predictions, indices = nn.top_k(predictions, k=size)
+    ordered_labels = array_ops.gather(labels, indices)
+
+    # Get the counts of the unique ordered predictions.
+    _, _, counts = array_ops.unique_with_counts(ordered_predictions)
+
+    # Compute the indices of the split points between different predictions.
+    splits = math_ops.cast(
+        array_ops.pad(math_ops.cumsum(counts), paddings=[[1, 0]]), dtypes.int32)
+
+    # Count the positives to the left of the split indices.
+    positives = math_ops.cast(
+        array_ops.pad(math_ops.cumsum(ordered_labels), paddings=[[1, 0]]),
+        dtypes.int32)
+    true_positives = array_ops.gather(positives, splits)
+    if curve == 'ROC':
+      # Count the negatives to the left of every split point and the total
+      # number of negatives for computing the FPR.
+      false_positives = math_ops.subtract(splits, true_positives)
+      total_negative = size - total_positive
+      x_axis_values = math_ops.truediv(false_positives, total_negative)
+      y_axis_values = math_ops.truediv(true_positives, total_positive)
+    elif curve == 'PR':
+      x_axis_values = math_ops.truediv(true_positives, total_positive)
+      # For conformance, set precision to 1 when the number of positive
+      # classifications is 0.
+      y_axis_values = array_ops.where(
+          math_ops.greater(splits, 0),
+          math_ops.truediv(true_positives, splits),
+          array_ops.ones_like(true_positives, dtype=dtypes.float64))
+
+    # Calculate trapezoid areas.
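+    # (a Riemann sum over the curve points induced by the distinct prediction
+    # values: sum over i of |x_{i+1} - x_i| * (y_i + y_{i+1}) / 2).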
+    heights = math_ops.add(y_axis_values[1:], y_axis_values[:-1]) / 2.0
+    widths = math_ops.abs(
+        math_ops.subtract(x_axis_values[1:], x_axis_values[:-1]))
+    return math_ops.reduce_sum(math_ops.multiply(heights, widths))
+
+  # If all the labels are the same, AUC isn't well-defined (but raising an
+  # exception seems excessive), so we return 0; otherwise we finish computing.
+  return control_flow_ops.cond(
+      math_ops.logical_or(
+          math_ops.equal(total_positive, 0),
+          math_ops.equal(total_positive, size)
+      ),
+      true_fn=lambda: array_ops.constant(0, dtypes.float64),
+      false_fn=continue_computing_dynamic_auc)
+
+
+def streaming_dynamic_auc(labels,
+                          predictions,
+                          curve='ROC',
+                          metrics_collections=(),
+                          updates_collections=(),
+                          name=None):
+  """Computes the apporixmate AUC by a Riemann sum with data-derived thresholds.
+
+  USAGE NOTE: this approach requires storing all of the predictions and labels
+  for a single evaluation in memory, so it may not be usable when the evaluation
+  batch size and/or the number of evaluation steps is very large.
+
+  Computes the area under the ROC or PR curve using each prediction as a
+  threshold. This has the advantage of being resilient to the distribution of
+  predictions by aggregating across batches, accumulating labels and predictions
+  and performing the final calculation using all of the concatenated values.
+
+  Args:
+    labels: A `Tensor` of ground truth labels with the same shape as
+      `predictions`, with values of 0 or 1 that are castable to `int64`.
+    predictions: A `Tensor` of predictions whose values are castable to
+      `float64`. Will be flattened into a 1-D `Tensor`.
+    curve: The name of the curve for which to compute AUC, 'ROC' for the
+      Receiver Operating Characteristic or 'PR' for the Precision-Recall curve.
+    metrics_collections: An optional iterable of collections that `auc` should
+      be added to.
+    updates_collections: An optional iterable of collections that `update_op`
+      should be added to.
+    name: An optional name for the variable_scope that contains the metric
+      variables.
+
+  Returns:
+    auc: A scalar `Tensor` containing the current area-under-curve value.
+    update_op: An operation that concatenates the input labels and predictions
+      to the accumulated values.
+
+  Raises:
+    ValueError: If `labels` and `predictions` have mismatched shapes or if
+      `curve` isn't a recognized curve type.
+  """
+
+  if curve not in ['PR', 'ROC']:
+    raise ValueError('curve must be either ROC or PR, %s unknown' % curve)
+
+  with variable_scope.variable_scope(name, default_name='dynamic_auc'):
+    labels.get_shape().assert_is_compatible_with(predictions.get_shape())
+    predictions = array_ops.reshape(
+        math_ops.cast(predictions, dtypes.float64), [-1])
+    labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1])
+    with ops.control_dependencies([
+        check_ops.assert_greater_equal(
+            labels,
+            array_ops.zeros_like(labels, dtypes.int64),
+            message='labels must be 0 or 1, at least one is <0'),
+        check_ops.assert_less_equal(
+            labels,
+            array_ops.ones_like(labels, dtypes.int64),
+            message='labels must be 0 or 1, at least one is >1')
+    ]):
+      preds_accum, update_preds = streaming_concat(predictions,
+                                                   name='concat_preds')
+      labels_accum, update_labels = streaming_concat(labels,
+                                                     name='concat_labels')
+      update_op = control_flow_ops.group(update_labels, update_preds)
+      auc = _compute_dynamic_auc(labels_accum, preds_accum, curve=curve)
+      if updates_collections:
+        ops.add_to_collections(updates_collections, update_op)
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, auc)
+      return auc, update_op
+
+
 def streaming_precision_recall_at_equal_thresholds(predictions,
                                                    labels,
                                                    num_thresholds=None,
@@ -3285,6 +3433,7 @@ __all__ = [
     'streaming_accuracy',
     'streaming_auc',
     'streaming_curve_points',
+    'streaming_dynamic_auc',
     'streaming_false_negative_rate',
     'streaming_false_negative_rate_at_thresholds',
     'streaming_false_negatives',
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 6a8e58b4daf..5d0463e1f74 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -1708,6 +1708,34 @@ class StreamingCurvePointsTest(test.TestCase):
                    [[1.0, 4.0 / 6.0], [0.75, 1.0], [0.0, 1.0]])
 
 
+def _np_auc(predictions, labels, weights=None):
+  """Computes the AUC explicitly using Numpy.
+
+  Args:
+    predictions: an ndarray with shape [N].
+    labels: an ndarray with shape [N].
+    weights: an ndarray with shape [N].
+
+  Returns:
+    the area under the ROC curve.
+  """
+  if weights is None:
+    weights = np.ones(np.size(predictions))
+  is_positive = labels > 0
+  num_positives = np.sum(weights[is_positive])
+  num_negatives = np.sum(weights[~is_positive])
+
+  # Sort descending:
+  inds = np.argsort(-predictions)
+
+  sorted_labels = labels[inds]
+  sorted_weights = weights[inds]
+  is_positive = sorted_labels > 0
+
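+  # tp[i] is the weighted true-positive rate after the top i+1 predictions;
+  # averaging it over the negatives' positions yields the weighted AUC.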
+  tp = np.cumsum(sorted_weights * is_positive) / num_positives
+  return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
+
+
 class StreamingAUCTest(test.TestCase):
 
   def setUp(self):
@@ -1896,33 +1924,6 @@ class StreamingAUCTest(test.TestCase):
 
       self.assertAlmostEqual(1, auc.eval(), 6)
 
-  def np_auc(self, predictions, labels, weights):
-    """Computes the AUC explicitly using Numpy.
-
-    Args:
-      predictions: an ndarray with shape [N].
-      labels: an ndarray with shape [N].
-      weights: an ndarray with shape [N].
-
-    Returns:
-      the area under the ROC curve.
-    """
-    if weights is None:
-      weights = np.ones(np.size(predictions))
-    is_positive = labels > 0
-    num_positives = np.sum(weights[is_positive])
-    num_negatives = np.sum(weights[~is_positive])
-
-    # Sort descending:
-    inds = np.argsort(-predictions)
-
-    sorted_labels = labels[inds]
-    sorted_weights = weights[inds]
-    is_positive = sorted_labels > 0
-
-    tp = np.cumsum(sorted_weights * is_positive) / num_positives
-    return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
-
   def testWithMultipleUpdates(self):
     num_samples = 1000
     batch_size = 10
@@ -1945,7 +1946,7 @@ class StreamingAUCTest(test.TestCase):
 
     for weights in (None, np.ones(num_samples), np.random.exponential(
         scale=1.0, size=num_samples)):
-      expected_auc = self.np_auc(predictions, labels, weights)
+      expected_auc = _np_auc(predictions, labels, weights)
 
       with self.test_session() as sess:
         enqueue_ops = [[] for i in range(num_batches)]
@@ -1974,6 +1975,211 @@ class StreamingAUCTest(test.TestCase):
         self.assertAlmostEqual(expected_auc, auc.eval(), 2)
 
 
+class StreamingDynamicAUCTest(test.TestCase):
+
+  def setUp(self):
+    super(StreamingDynamicAUCTest, self).setUp()
+    np.random.seed(1)
+    ops.reset_default_graph()
+
+  def testUnknownCurve(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'curve must be either ROC or PR, TEST_CURVE unknown'):
+      metrics.streaming_dynamic_auc(labels=array_ops.ones((10, 1)),
+                                    predictions=array_ops.ones((10, 1)),
+                                    curve='TEST_CURVE')
+
+  def testVars(self):
+    metrics.streaming_dynamic_auc(
+        labels=array_ops.ones((10, 1)), predictions=array_ops.ones((10, 1)))
+    _assert_metric_variables(self, ['dynamic_auc/concat_labels/array:0',
+                                    'dynamic_auc/concat_labels/size:0',
+                                    'dynamic_auc/concat_preds/array:0',
+                                    'dynamic_auc/concat_preds/size:0'])
+
+  def testMetricsCollection(self):
+    my_collection_name = '__metrics__'
+    auc, _ = metrics.streaming_dynamic_auc(
+        labels=array_ops.ones((10, 1)),
+        predictions=array_ops.ones((10, 1)),
+        metrics_collections=[my_collection_name])
+    self.assertEqual(ops.get_collection(my_collection_name), [auc])
+
+  def testUpdatesCollection(self):
+    my_collection_name = '__updates__'
+    _, update_op = metrics.streaming_dynamic_auc(
+        labels=array_ops.ones((10, 1)),
+        predictions=array_ops.ones((10, 1)),
+        updates_collections=[my_collection_name])
+    self.assertEqual(ops.get_collection(my_collection_name), [update_op])
+
+  def testValueTensorIsIdempotent(self):
+    predictions = random_ops.random_uniform(
+        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
+    labels = random_ops.random_uniform(
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
+    auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      # Run several updates.
+      for _ in xrange(10):
+        sess.run(update_op)
+      # Then verify idempotency.
+      initial_auc = auc.eval()
+      for _ in xrange(10):
+        self.assertAlmostEqual(initial_auc, auc.eval(), 5)
+
+  def testAllLabelsOnes(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1., 1., 1.])
+      labels = constant_op.constant([1, 1, 1])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertEqual(0, auc.eval())
+
+  def testAllLabelsZeros(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1., 1., 1.])
+      labels = constant_op.constant([0, 0, 0])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertEqual(0, auc.eval())
+
+  def testNonZeroOnePredictions(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([2.5, -2.5, 2.5, -2.5],
+                                         dtype=dtypes_lib.float32)
+      labels = constant_op.constant([1, 0, 1, 0])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(auc.eval(), 1.0)
+
+  def testAllCorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs)
+      labels = constant_op.constant(inputs)
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertEqual(1, auc.eval())
+
+  def testSomeCorrect(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1, 0, 1, 0])
+      labels = constant_op.constant([0, 1, 1, 0])
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.5, auc.eval())
+
+  def testAllIncorrect(self):
+    inputs = np.random.randint(0, 2, size=(100, 1))
+    with self.test_session() as sess:
+      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
+      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
+      auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0, auc.eval())
+
+  def testExceptionOnIncompatibleShapes(self):
+    with self.test_session() as sess:
+      predictions = array_ops.ones([5])
+      labels = array_ops.zeros([6])
+      with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
+        _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+        sess.run(variables.local_variables_initializer())
+        sess.run(update_op)
+
+  def testExceptionOnGreaterThanOneLabel(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
+      labels = constant_op.constant([2, 1, 0])
+      _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          '.*labels must be 0 or 1, at least one is >1.*'):
+        sess.run(update_op)
+
+  def testExceptionOnNegativeLabel(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
+      labels = constant_op.constant([1, 0, -1])
+      _, update_op = metrics.streaming_dynamic_auc(labels, predictions)
+      sess.run(variables.local_variables_initializer())
+      with self.assertRaisesRegexp(
+          errors_impl.InvalidArgumentError,
+          '.*labels must be 0 or 1, at least one is <0.*'):
+        sess.run(update_op)
+
+  def testWithMultipleUpdates(self):
+    batch_size = 10
+    num_batches = 100
+    labels = np.array([])
+    predictions = np.array([])
+    tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32),
+                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
+                                   dtype=dtypes_lib.int32)
+    tf_predictions = variables.Variable(
+        array_ops.ones(batch_size),
+        collections=[ops.GraphKeys.LOCAL_VARIABLES],
+        dtype=dtypes_lib.float32)
+    auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
+    with self.test_session() as sess:
+      sess.run(variables.local_variables_initializer())
+      for _ in xrange(num_batches):
+        new_labels = np.random.randint(0, 2, size=batch_size)
+        noise = np.random.normal(0.0, scale=0.2, size=batch_size)
+        new_predictions = 0.4 + 0.2 * new_labels + noise
+        labels = np.concatenate([labels, new_labels])
+        predictions = np.concatenate([predictions, new_predictions])
+        sess.run(tf_labels.assign(new_labels))
+        sess.run(tf_predictions.assign(new_predictions))
+        sess.run(update_op)
+        expected_auc = _np_auc(predictions, labels)
+        self.assertAlmostEqual(expected_auc, auc.eval())
+
+  def testAUCPRReverseIncreasingPredictions(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 1, 1])
+      auc, update_op = metrics.streaming_dynamic_auc(
+          labels, predictions, curve='PR')
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
+
+  def testAUCPRJumbledPredictions(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
+      auc, update_op = metrics.streaming_dynamic_auc(
+          labels, predictions, curve='PR')
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
+
+  def testAUCPRPredictionsLessThanHalf(self):
+    with self.test_session() as sess:
+      predictions = constant_op.constant(
+          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
+          shape=(1, 7),
+          dtype=dtypes_lib.float32)
+      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
+      auc, update_op = metrics.streaming_dynamic_auc(
+          labels, predictions, curve='PR')
+      sess.run(variables.local_variables_initializer())
+      sess.run(update_op)
+      self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5)
+
+
 class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
 
   def setUp(self):

From 9abe08570ffe5e4aaa9bbd1f977455e8b0dd4491 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 15:00:15 -0800
Subject: [PATCH 096/115] [TF:XLA] Clean up unused XLA options and functions.

PiperOrigin-RevId: 175217850
---
 .../compiler/jit/kernels/xla_launch_op.cc     |  1 -
 .../compiler/jit/xla_compilation_cache.cc     |  3 ---
 tensorflow/compiler/tf2xla/xla_compiler.h     |  6 ------
 .../compiler/xla/client/local_client.cc       | 20 -------------------
 tensorflow/compiler/xla/client/local_client.h | 16 ---------------
 .../compiler/xla/service/hlo_module_config.cc |  4 ++--
 .../compiler/xla/service/hlo_module_config.h  | 10 ----------
 .../compiler/xla/service/local_service.cc     | 20 -------------------
 tensorflow/compiler/xla/service/service.h     |  2 --
 9 files changed, 2 insertions(+), 80 deletions(-)

diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 27c5da08c11..e481796d9e6 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -257,7 +257,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
   options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
-  options.local_executable_has_hybrid_result = true;
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 23368b6c76a..bc2eccd2779 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -227,10 +227,7 @@ Status XlaCompilationCache::BuildExecutable(
   }
   xla::ExecutableBuildOptions build_options;
   build_options.set_device_ordinal(client_->default_device_ordinal());
-  build_options.set_platform(client_->platform());
   build_options.set_result_layout(result.xla_output_shape);
-  build_options.set_has_hybrid_result(
-      options.local_executable_has_hybrid_result);
 
   auto compile_result =
       client_->Compile(*result.computation, argument_layouts, build_options);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 4d40ca5825a..ac7d4cfb127 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -236,12 +236,6 @@ class XlaCompiler {
     // to the computation.
     bool allow_cpu_custom_calls = false;
 
-    // If 'local_executable_has_hybrid_result', the top-level pointers of the
-    // result tuple of compiled programs are stored in host memory and the
-    // nested buffers in device memory, otherwise the whole result tuple is
-    // stored in device memory.
-    bool local_executable_has_hybrid_result = false;
-
     // If not nullptr, populate_resource_manager is called with the
     // compilation device's resource manager when the compilation
     // device is created, and can be used to create metadata objects
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 15c744ecd34..b50425a09c7 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -27,16 +27,6 @@ namespace se = ::perftools::gputools;
 
 namespace xla {
 
-ExecutableBuildOptions& ExecutableBuildOptions::set_platform(
-    perftools::gputools::Platform* platform) {
-  platform_ = platform;
-  return *this;
-}
-
-perftools::gputools::Platform* ExecutableBuildOptions::platform() const {
-  return platform_;
-}
-
 ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
     int device_ordinal) {
   device_ordinal_ = device_ordinal;
@@ -56,16 +46,6 @@ const Shape* ExecutableBuildOptions::result_layout() const {
   return result_layout_set_ ? &result_layout_ : nullptr;
 }
 
-ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result(
-    bool has_hybrid_result) {
-  has_hybrid_result_ = has_hybrid_result;
-  return *this;
-}
-
-bool ExecutableBuildOptions::has_hybrid_result() const {
-  return has_hybrid_result_;
-}
-
 namespace {
 StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
                                                    Backend* backend) {
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 9f985ed5275..e9eeaa0aa22 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -37,14 +37,6 @@ namespace xla {
 // LocalClient::Compile.
 class ExecutableBuildOptions {
  public:
-  // If set, this is the platform to build the computation for. This must match
-  // the underlying platform of the service. A value of nullptr indicates the
-  // option has not been set.
-  //
-  // TODO(b/28616830): Support multiple platforms.
-  ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform);
-  perftools::gputools::Platform* platform() const;
-
   // If set, this is the device to build the computation for. Valid
   // device_ordinal values are: 0 to # of devices - 1. These values are
   // identical to the device ordinal values used by StreamExecutor. The built
@@ -61,18 +53,10 @@ class ExecutableBuildOptions {
   ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
   const Shape* result_layout() const;
 
-  // If set, the executable will be built to output a hybrid
-  // ShapedBuffer with top-level tuple pointers in host memory and
-  // result buffers in device memory.
-  ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result);
-  bool has_hybrid_result() const;
-
  private:
-  perftools::gputools::Platform* platform_ = nullptr;
   int device_ordinal_ = -1;
   Shape result_layout_;
   bool result_layout_set_ = false;
-  bool has_hybrid_result_ = true;
 };
 
 class LocalExecutable {
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
index 8974deb530c..822e2f1f53e 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -39,8 +39,8 @@ void HloModuleConfig::SetDefaultComputationLayout(
 }
 
 string HloModuleConfig::compilation_cache_key() const {
-  string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_,
-                                           "::hybrid=", has_hybrid_result_);
+  string key =
+      tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_);
   StrAppend(&key, "::(");
   std::vector<string> params;
   for (const ShapeLayout& param_layout :
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
index 4a7ead9c104..a5ee895e484 100644
--- a/tensorflow/compiler/xla/service/hlo_module_config.h
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -104,16 +104,6 @@ class HloModuleConfig {
   // Whether to enable HLO-level profiling.
   bool hlo_profiling_enabled_ = false;
 
-  // If this flag is true, the generated executable will return a ShapedBuffer
-  // holding the result of the computation. In a ShapedBuffer, tuples have their
-  // structure held in host memory and the element arrays (leaves of the tuple
-  // structure) stored in device memory. The ShapedBuffer is considered "hybrid"
-  // because its leaves are on device but its structure is stored on
-  // host. Otherwise, if this flag is false, the generated executable will
-  // return a DeviceMemoryBase where the result is held entirely in device
-  // memory.
-  bool has_hybrid_result_ = false;
-
   // Module/graph-level seed handle.
   uint64 seed_ = 0;
 
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index d4d35da9d63..06f43bd3cb2 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -68,26 +68,6 @@ LocalService::LocalService(const ServiceOptions& options,
                            std::unique_ptr<Backend> execute_backend)
     : Service(options, std::move(execute_backend)) {}
 
-namespace {
-// Returns the space required to allocate a shape. If
-// allocate_space_for_deep_copy the space includes all sub-buffers of
-// a tuple.
-int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy,
-                    TransferManager* transfer_manager) {
-  int64 size = 0;
-  // TODO(b/33492279) remove once no devices represent result tuples as
-  // contiguous buffers.
-  if (allocate_space_for_deep_copy) {
-    ShapeUtil::ForEachSubshape(
-        shape, [&size, transfer_manager](const Shape& subshape,
-                                         const ShapeIndex& /*index*/) {
-          size += transfer_manager->GetByteSizeRequirement(subshape);
-        });
-  }
-  return size;
-}
-}  // namespace
-
 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ComputationHandle& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
index 6646be2e9aa..47f4f0ade59 100644
--- a/tensorflow/compiler/xla/service/service.h
+++ b/tensorflow/compiler/xla/service/service.h
@@ -272,8 +272,6 @@ class Service : public ServiceInterface {
 
   // Create a Hlo module config for the given program shape and arguments.
   // execution_options is optional; if not given a default is used.
-  // has_hybrid_result is used to initialize the same-named field in
-  // HloModuleConfig -- see that class for documentation.
   StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
       const ProgramShape& program_shape,
       tensorflow::gtl::ArraySlice<const Shape*> argument_shapes,

From b57e6aaa330a2354f2f4cd26f3ffc1fd11103ff0 Mon Sep 17 00:00:00 2001
From: Sanjoy Das <sanjoy@google.com>
Date: Thu, 9 Nov 2017 15:04:22 -0800
Subject: [PATCH 097/115] Make LLVMCompilerTest less stateful.

Instead of assigning the pre- and post-optimization hooks to a singleton
xla::Compiler object, prefer creating a short-lived CpuCompiler or GpuCompiler
instance on the stack. Without this change, adding a second test case on the
(Cpu|Gpu)Compiler in the same process triggers a use-after-free.

(Btw, LLVMCompiler should really be spelled LlvmCompiler per Google C++ style,
I'll do that rename shortly).

PiperOrigin-RevId: 175218617
---
 tensorflow/compiler/xla/tests/BUILD           |  21 +--
 .../compiler/xla/tests/llvm_compiler_test.cc  | 129 ++++++++++++++----
 2 files changed, 110 insertions(+), 40 deletions(-)

diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 2333a30ad58..3e62481629a 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1344,22 +1344,23 @@ xla_test(
     ],
 )
 
-xla_test(
+tf_cc_test(
     name = "llvm_compiler_test",
     srcs = ["llvm_compiler_test.cc"],
-    backends = [
-        "cpu",
-        "gpu",
-        "cpu_parallel",
-    ],
+    tags = ["requires-gpu-sm35"],
     deps = [
-        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:cpu_plugin",
+        "//tensorflow/compiler/xla/service:gpu_plugin",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:llvm_compiler",
-        "//tensorflow/compiler/xla/tests:hlo_test_base",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/compiler/xla/service:platform_util",
+        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+        "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
         "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/stream_executor",
         "@llvm//:core",
     ],
 )
diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
index 458258e7ee1..70d8b764a33 100644
--- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
+++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc
@@ -14,49 +14,118 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/core/platform/test.h"
+#include "tensorflow/stream_executor/stream_executor.h"
 
 namespace xla {
 namespace {
 
-class LLVMCompilerTest : public HloTestBase {};
+class LLVMCompilerTest : public ::testing::Test {
+ public:
+  void SetUp() override {
+    Platform *platform = FindPlatform();
+    ASSERT_NE(platform, nullptr);
 
-XLA_TEST_F(LLVMCompilerTest, CompilerHooks) {
-  int pre_opt_hook_call_count = 0;
-  int post_opt_hook_call_count = 0;
+    BackendOptions backend_options;
+    backend_options.set_platform(platform);
+    StatusOr<std::unique_ptr<Backend>> backend_or_status =
+        Backend::CreateBackend(backend_options);
+    ASSERT_IS_OK(backend_or_status.status());
+    backend_ = backend_or_status.ConsumeValueOrDie();
+  }
 
-  auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
-    ++pre_opt_hook_call_count;
-    return Status::OK();
-  };
-  auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
-    ++post_opt_hook_call_count;
-    return Status::OK();
-  };
+  ~LLVMCompilerTest() override {}
 
-  // Create HLO module, and run the compiler.
-  auto builder = HloComputation::Builder(TestName());
-  builder.AddInstruction(
-      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+ protected:
+  using Platform = ::perftools::gputools::Platform;
 
-  auto hlo_module = CreateNewModule();
-  hlo_module->AddEntryComputation(builder.Build());
+  explicit LLVMCompilerTest(string platform_name)
+      : platform_name_(std::move(platform_name)) {}
 
-  auto compiler = static_cast<LLVMCompiler *>(backend().compiler());
-  compiler->SetPreOptimizationHook(pre_opt_hook);
-  compiler->SetPostOptimizationHook(post_opt_hook);
+  void TestCompilerHooks(LLVMCompiler *compiler) {
+    int pre_opt_hook_call_count = 0;
+    int post_opt_hook_call_count = 0;
 
-  ASSERT_TRUE(
-      compiler
-          ->Compile(std::move(hlo_module), backend().default_stream_executor())
-          .ok());
+    auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) {
+      ++pre_opt_hook_call_count;
+      return Status::OK();
+    };
+    auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) {
+      ++post_opt_hook_call_count;
+      return Status::OK();
+    };
 
-  // Test that hooks were called.
-  EXPECT_EQ(1, pre_opt_hook_call_count);
-  EXPECT_EQ(1, post_opt_hook_call_count);
+    // Create HLO module, and run the compiler.
+    auto builder = HloComputation::Builder(TestName());
+    builder.AddInstruction(
+        HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0)));
+
+    auto hlo_module = CreateNewModule();
+    hlo_module->AddEntryComputation(builder.Build());
+
+    compiler->SetPreOptimizationHook(pre_opt_hook);
+    compiler->SetPostOptimizationHook(post_opt_hook);
+
+    ASSERT_TRUE(compiler
+                    ->Compile(std::move(hlo_module),
+                              backend_->default_stream_executor())
+                    .ok());
+
+    // Test that hooks were called.
+    EXPECT_EQ(1, pre_opt_hook_call_count);
+    EXPECT_EQ(1, post_opt_hook_call_count);
+  }
+
+ private:
+  Platform *FindPlatform() {
+    for (Platform *platform :
+         PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) {
+      if (platform->Name() == platform_name_) {
+        return platform;
+      }
+    }
+    return nullptr;
+  }
+
+  string platform_name_;
+  std::unique_ptr<Backend> backend_;
+
+  static string TestName() {
+    return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+  }
+
+  static std::unique_ptr<HloModule> CreateNewModule() {
+    HloModuleConfig config;
+    config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
+    return MakeUnique<HloModule>(TestName(), VersionedComputationHandle(),
+                                 config);
+  }
+};
+
+class CpuCompilerTest : public LLVMCompilerTest {
+ public:
+  CpuCompilerTest() : LLVMCompilerTest("Host") {}
+};
+
+class GpuCompilerTest : public LLVMCompilerTest {
+ public:
+  GpuCompilerTest() : LLVMCompilerTest("CUDA") {}
+};
+
+TEST_F(CpuCompilerTest, HooksTest) {
+  cpu::CpuCompiler compiler;
+  TestCompilerHooks(&compiler);
+}
+
+TEST_F(GpuCompilerTest, HooksTest) {
+  gpu::GpuCompiler compiler;
+  TestCompilerHooks(&compiler);
 }
 
 }  // namespace

From 67c3d9f7242df74492943c769719ffb863ca1af0 Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Thu, 9 Nov 2017 15:55:07 -0800
Subject: [PATCH 098/115] Tensor template argument to GradientTape was
 unnecessary.

PiperOrigin-RevId: 175225805
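
A sketch (mine) of the narrowed contract on the Python side: the tape now asks
its VSpace for ones by (shape, dtype) rather than by an example tensor, which
is why backprop.py can plug in `array_ops.ones` directly:

```python
import tensorflow as tf

x = tf.constant([[1.0, 2.0, 3.0]])
old_style = tf.ones_like(x)               # needed the tensor itself
new_style = tf.ones((1, 3), tf.float32)   # needs only shape and dtype
```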
---
 tensorflow/c/eager/tape.h                  | 112 +++++++++++----------
 tensorflow/python/eager/backprop.py        |   2 +-
 tensorflow/python/eager/imperative_grad.py |   2 +-
 tensorflow/python/eager/pywrap_tfe_src.cc  |  65 +++++++-----
 4 files changed, 104 insertions(+), 77 deletions(-)

diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 654ceb7bec4..29d73c5ca43 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -65,9 +65,6 @@ using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
 // adding gradients, getting zeroes, etc. Currently cannot be implemented
 // without using tensorflow python code, hence left unspecified here.
 //
-// Tensor is a representation of a tensor. We need to take its ID, and it needs
-// to match IDs in the tape.
-//
 // Gradient is the type returned by gradient functions. In Python TF it's either
 // Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need
 // to allow their size to be computed and they need to be passable to a backward
@@ -82,7 +79,7 @@ using OpTape = std::unordered_map<int64, OpTapeEntry<BackwardFunction>>;
 // TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
 // specialization, which is blocked by quite a few things needing to loop back
 // into python now.
-template <typename Tensor, typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction>
 class VSpace {
  public:
   virtual ~VSpace() {}
@@ -99,11 +96,7 @@ class VSpace {
   virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
 
   // Returns a Tensor which is filled with ones and like the input.
-  virtual Gradient* OnesLike(Tensor*) const = 0;
-
-  // Returns an integer which is a unique-to-within-this-program handle for this
-  // tensor.
-  virtual int64 TensorId(Tensor* tensor) const = 0;
+  virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
 
   // Calls the passed-in backward function.
   virtual Status CallBackwardFunction(
@@ -117,7 +110,7 @@ class VSpace {
 
 // Traces the execution of operations, doing eager garbage collection, and
 // exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Tensor, typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction>
 class GradientTape {
  public:
   GradientTape() {}
@@ -143,11 +136,11 @@ class GradientTape {
   // once) and produces the gradient of the target tensors with respect to the
   // source tensors. The output gradients are used if not empty and not
   // null. The result is populated with one tensor per target element.
-  Status ComputeGradient(
-      const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
-      gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources,
-      gtl::ArraySlice<Gradient*> output_gradients,
-      std::vector<Gradient*>* result);
+  Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
+                         gtl::ArraySlice<int64> target_tensor_ids,
+                         gtl::ArraySlice<int64> source_tensor_id,
+                         gtl::ArraySlice<Gradient*> output_gradients,
+                         std::vector<Gradient*>* result);
 
  private:
   TensorTape tensor_tape_;
@@ -161,8 +154,8 @@ class GradientTape {
 
 // Template instantiations here
 
-template <typename Tensor, typename Gradient, typename BackwardFunction>
-bool GradientTape<Tensor, Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction>
+bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
     gtl::ArraySlice<int64> tensor_ids) {
   for (int64 i : tensor_ids) {
     if (tensor_tape_.find(i) != tensor_tape_.end()) {
@@ -172,13 +165,13 @@ bool GradientTape<Tensor, Gradient, BackwardFunction>::ShouldRecord(
   return false;
 }
 
-template <typename Tensor, typename Gradient, typename BackwardFunction>
-void GradientTape<Tensor, Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction>
+void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
   tensor_tape_.emplace(tensor_id, -1);
 }
 
-template <typename Tensor, typename Gradient, typename BackwardFunction>
-void GradientTape<Tensor, Gradient, BackwardFunction>::RecordOperation(
+template <typename Gradient, typename BackwardFunction>
+void GradientTape<Gradient, BackwardFunction>::RecordOperation(
     const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
     gtl::ArraySlice<int64> input_tensor_id, BackwardFunction* backward_function,
     const std::function<void()>& backward_function_deleter) {
@@ -206,9 +199,8 @@ void GradientTape<Tensor, Gradient, BackwardFunction>::RecordOperation(
       op_type, tensors, ids, backward_function, backward_function_deleter};
 }
 
-template <typename Tensor, typename Gradient, typename BackwardFunction>
-void GradientTape<Tensor, Gradient, BackwardFunction>::DeleteTrace(
-    int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction>
+void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
   auto it = tensor_usage_.find(tensor_id);
   if (it == tensor_usage_.end()) {
     return;
@@ -353,15 +345,16 @@ std::vector<int64> InitialStack(
   return result;
 }
 
-template <typename Tensor, typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction>
 Status InitialGradients(
-    const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
-    gtl::ArraySlice<Tensor*> target,
-    gtl::ArraySlice<Gradient*> output_gradients,
-    std::unordered_map<int64, int64> tensor_usage_counts,
+    const VSpace<Gradient, BackwardFunction>& vspace,
+    gtl::ArraySlice<int64> target_tensor_ids,
+    gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+    const OpTape<BackwardFunction>& op_tape,
+    const std::unordered_map<int64, int64>& tensor_usage_counts,
     std::unordered_map<int64, std::vector<Gradient*>>* result) {
-  for (int i = 0; i < target.size(); ++i) {
-    int64 id = vspace.TensorId(target[i]);
+  for (int i = 0; i < target_tensor_ids.size(); ++i) {
+    const int64 id = target_tensor_ids[i];
     if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) {
       if (!output_gradients.empty() && output_gradients[i] != nullptr) {
         // TODO(apassos) figure out how to print debugging information here.
@@ -371,7 +364,31 @@ Status InitialGradients(
       }
     } else {
       if (output_gradients.empty() || output_gradients[i] == nullptr) {
-        (*result)[id].push_back(vspace.OnesLike(target[i]));
+        auto tensor_it = tensor_tape.find(id);
+        if (tensor_it != tensor_tape.end() && tensor_it->second != -1) {
+          auto op_it = op_tape.find(tensor_it->second);
+          if (op_it == op_tape.end()) {
+            return errors::Internal(
+                "Internal state of the gradient tape is invalid.");
+          }
+          bool found = false;
+          for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
+            if (op_it->second.output_tensor_info[j].id == id) {
+              found = true;
+              (*result)[id].push_back(
+                  vspace.Ones(op_it->second.output_tensor_info[j].shape,
+                              op_it->second.output_tensor_info[j].dtype));
+              break;
+            }
+          }
+          if (!found) {
+            return errors::Internal(
+                "Internal state of the gradient tape is invalid.");
+          }
+        } else {
+          // No record of the target tensor found on the tape, so no gradient
+          // needs to be computed from it. Do nothing.
+        }
       } else {
         (*result)[id].push_back(output_gradients[i]);
       }
@@ -388,29 +405,22 @@ Status InitialGradients(
 constexpr int kMinAggregateCount = 4;
 constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
 
-template <typename Tensor, typename Gradient, typename BackwardFunction>
-Status GradientTape<Tensor, Gradient, BackwardFunction>::ComputeGradient(
-    const VSpace<Tensor, Gradient, BackwardFunction>& vspace,
-    gtl::ArraySlice<Tensor*> target, gtl::ArraySlice<Tensor*> sources,
+template <typename Gradient, typename BackwardFunction>
+Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
+    const VSpace<Gradient, BackwardFunction>& vspace,
+    gtl::ArraySlice<int64> target_tensor_ids,
+    gtl::ArraySlice<int64> source_tensor_ids,
     gtl::ArraySlice<Gradient*> output_gradients,
     std::vector<Gradient*>* result) {
-  std::vector<int64> id_sources;
-  id_sources.reserve(sources.size());
-  for (Tensor* s : sources) {
-    id_sources.push_back(vspace.TensorId(s));
-  }
-  std::unordered_set<int64> sources_set(id_sources.begin(), id_sources.end());
-  std::vector<int64> id_targets;
-  id_sources.reserve(target.size());
-  for (Tensor* t : target) {
-    id_targets.push_back(vspace.TensorId(t));
-  }
+  std::unordered_set<int64> sources_set(source_tensor_ids.begin(),
+                                        source_tensor_ids.end());
   BackpropInitialState<BackwardFunction> state = PrepareBackprop(
-      id_targets, tensor_tape_, std::move(op_tape_), sources_set);
+      target_tensor_ids, tensor_tape_, std::move(op_tape_), sources_set);
   std::vector<int64> op_stack =
       InitialStack(state.op_tape, state.op_missing_tensor);
   std::unordered_map<int64, std::vector<Gradient*>> gradients;
-  Status s = InitialGradients(vspace, target, output_gradients,
+  Status s = InitialGradients(vspace, target_tensor_ids, output_gradients,
+                              tensor_tape_, state.op_tape,
                               state.tensor_usage_counts, &gradients);
   auto cleanup = [&state]() {
     // Release all backprop functions
@@ -542,8 +552,8 @@ Status GradientTape<Tensor, Gradient, BackwardFunction>::ComputeGradient(
     }
   }
   CHECK(state.op_tape.empty());
-  result->reserve(sources.size());
-  for (auto is : id_sources) {
+  result->reserve(source_tensor_ids.size());
+  for (auto is : source_tensor_ids) {
     auto grad_it = gradients.find(is);
     if (grad_it == gradients.end()) {
       result->push_back(nullptr);
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 111d7cef56a..0a92ab38a83 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -744,7 +744,7 @@ _default_vspace = imperative_grad.VSpace(
     aggregate_fn=_aggregate_grads,
     tensor_id=ops.tensor_id,
     zeros=_zeros,
-    ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))
+    ones=array_ops.ones)
 
 
 class GradientTape(object):
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 8932b7157b2..837cad974ac 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -26,7 +26,7 @@ from tensorflow.python.framework import errors
 
 VSpace = collections.namedtuple(
     "VSpace",
-    ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones_like"])
+    ["aggregate_fn", "num_elements_fn", "tensor_id", "zeros", "ones"])
 
 
 def imperative_grad(
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index d67c3b18f7b..77b49be8f88 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -443,8 +443,7 @@ void TFE_DeleteContextCapsule(PyObject* context) {
   TF_DeleteStatus(status);
 }
 
-using GradientTape =
-    tensorflow::eager::GradientTape<PyObject, PyObject, PyObject>;
+using GradientTape = tensorflow::eager::GradientTape<PyObject, PyObject>;
 
 typedef struct {
   PyObject_HEAD
@@ -630,8 +629,7 @@ void TFE_Py_TapeDeleteTrace(PyObject* tape, tensorflow::int64 tensor_id) {
   reinterpret_cast<TFE_Py_Tape*>(tape)->tape->DeleteTrace(tensor_id);
 }
 
-class PyVSpace
-    : public tensorflow::eager::VSpace<PyObject, PyObject, PyObject> {
+class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyObject> {
  public:
   explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {}
 
@@ -648,9 +646,9 @@ class PyVSpace
     if (zeros_ == nullptr) {
       return tensorflow::errors::InvalidArgument("invalid vspace");
     }
-    ones_like_ = PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_),
-                                        "ones_like");
-    if (ones_like_ == nullptr) {
+    ones_ =
+        PyObject_GetAttrString(reinterpret_cast<PyObject*>(py_vspace_), "ones");
+    if (ones_ == nullptr) {
       return tensorflow::errors::InvalidArgument("invalid vspace");
     }
     return tensorflow::Status::OK();
@@ -660,7 +658,7 @@ class PyVSpace
     Py_XDECREF(num_elements_);
     Py_XDECREF(aggregate_fn_);
     Py_XDECREF(zeros_);
-    Py_XDECREF(ones_like_);
+    Py_XDECREF(ones_);
   }
 
   tensorflow::int64 NumElements(PyObject* tensor) const final {
@@ -706,24 +704,21 @@ class PyVSpace
     return reinterpret_cast<PyObject*>(result);
   }
 
-  PyObject* OnesLike(PyObject* tensor) const final {
-    PyObject* arg_list = Py_BuildValue("(O)", tensor);
-    PyObject* result = PyEval_CallObject(ones_like_, arg_list);
-    if (result == nullptr) {
-      VLOG(1) << "Call to ones_like failed";
+  PyObject* Ones(tensorflow::TensorShape shape,
+                 tensorflow::DataType dtype) const final {
+    PyObject* py_shape = PyTuple_New(shape.dims());
+    for (int i = 0; i < shape.dims(); ++i) {
+      PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
     }
+    PyObject* py_dtype = PyLong_FromLong(static_cast<int>(dtype));
+    PyObject* arg_list = Py_BuildValue("OO", py_shape, py_dtype);
+    PyObject* result = PyEval_CallObject(ones_, arg_list);
     Py_DECREF(arg_list);
+    Py_DECREF(py_dtype);
+    Py_DECREF(py_shape);
     return result;
   }
 
-  tensorflow::int64 TensorId(PyObject* tensor) const final {
-    PyObject* py_tensor = reinterpret_cast<PyObject*>(tensor);
-    PyObject* id_field = PyObject_GetAttrString(py_tensor, "_id");
-    tensorflow::int64 id = MakeInt(id_field);
-    Py_DECREF(id_field);
-    return id;
-  }
-
   tensorflow::Status CallBackwardFunction(
       PyObject* backward_function,
       tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
@@ -781,7 +776,7 @@ class PyVSpace
   PyObject* num_elements_;
   PyObject* aggregate_fn_;
   PyObject* zeros_;
-  PyObject* ones_like_;
+  PyObject* ones_;
 };
 
 std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
@@ -799,6 +794,28 @@ std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
   return list;
 }
 
+std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
+  PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
+  if (seq == nullptr) {
+    return {};
+  }
+  int len = PySequence_Fast_GET_SIZE(seq);
+  std::vector<tensorflow::int64> list;
+  list.reserve(len);
+  for (int i = 0; i < len; ++i) {
+    PyObject* tensor = PySequence_Fast_GET_ITEM(seq, i);
+    if (EagerTensor_CheckExact(tensor)) {
+      list.push_back(EagerTensor_id(tensor));
+    } else {
+      PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
+      list.push_back(MakeInt(id_field));
+      Py_DECREF(id_field);
+    }
+  }
+  Py_DECREF(seq);
+  return list;
+}
+
 PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
                               PyObject* target, PyObject* sources,
                               PyObject* output_gradients, TF_Status* status) {
@@ -807,11 +824,11 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace,
     return nullptr;
   }
 
-  std::vector<PyObject*> target_vec = MakeTensorList(target);
+  std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
   if (PyErr_Occurred()) {
     return nullptr;
   }
-  std::vector<PyObject*> sources_vec = MakeTensorList(sources);
+  std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
   if (PyErr_Occurred()) {
     return nullptr;
   }

From d7b22fbfdf707d6c6fc8df553242da36dab20e47 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <bsteiner@google.com>
Date: Thu, 9 Nov 2017 16:11:06 -0800
Subject: [PATCH 099/115] Materialize constants in more cases.

PiperOrigin-RevId: 175228264
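
A toy sketch (mine, in plain Python rather than Grappler's C++ API) of the
rule the new ExtractShape helper implements: a shape vector is recoverable
either from shape inference on a "Shape" node's input or from a "Const"
node's value, and unknown dimensions (-1) are kept for later handling.

```python
def extract_shape(node, inferred_shapes):
    """Returns the shape vector `node` produces, or None if unknowable."""
    if node["op"] == "Shape":
        shp = inferred_shapes.get(node["input"])
        if shp is None:  # unknown rank: give up on this node
            return None
        return list(shp)  # may still contain -1 for unknown dimensions
    if node["op"] == "Const":
        return list(node["value"])  # the constant's value is the shape itself
    return None

inferred = {"x": (4, -1, 3)}
print(extract_shape({"op": "Shape", "input": "x"}, inferred))   # [4, -1, 3]
print(extract_shape({"op": "Const", "value": (4, 1, 3)}, inferred))  # [4, 1, 3]
```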
---
 .../grappler/optimizers/constant_folding.cc   | 140 +++++++++++++-----
 .../optimizers/constant_folding_test.cc       |  29 +++-
 tensorflow/core/util/bcast.cc                 |   2 -
 3 files changed, 131 insertions(+), 40 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index a364ca487ea..02a732b0923 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/bcast.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -301,6 +302,44 @@ bool ShapesEqual(const TensorShapeProto& shape1,
   return true;
 }
 
+namespace {
+bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
+                  BCast::Vec* shape, int64* min_id) {
+  if (shape_node.op() == "Shape") {
+    const std::vector<OpInfo::TensorProperties>& prop1 =
+        properties.GetInputProperties(shape_node.name());
+    if (prop1.size() != 1) {
+      return false;
+    }
+    const TensorShapeProto& shp = prop1[0].shape();
+    if (shp.unknown_rank()) {
+      return false;
+    }
+    for (const auto& dim : shp.dim()) {
+      shape->push_back(dim.size());
+      *min_id = std::min<int64>(*min_id, dim.size());
+    }
+  } else {
+    const TensorProto& raw_val = shape_node.attr().at("value").tensor();
+    if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
+      return false;
+    }
+    Tensor value(raw_val.dtype(), raw_val.tensor_shape());
+    if (!value.FromProto(raw_val)) {
+      return false;
+    }
+    for (int j = 0; j < value.NumElements(); ++j) {
+      if (raw_val.dtype() == DT_INT64) {
+        shape->push_back(value.vec<int64>()(j));
+      } else {
+        shape->push_back(value.vec<int>()(j));
+      }
+    }
+  }
+  return true;
+}
+}  // namespace
+
 Status ConstantFolding::MaterializeConstants(
     const GrapplerItem& item, const GraphProperties& properties) {
   const int node_count = graph_.node_size();
@@ -312,49 +351,76 @@ Status ConstantFolding::MaterializeConstants(
     }
     const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
     const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
-    if (shape_node1 == nullptr || shape_node1->op() != "Shape" ||
-        shape_node2 == nullptr || shape_node2->op() != "Shape") {
+    if (shape_node1 == nullptr ||
+        (shape_node1->op() != "Shape" && shape_node1->op() != "Const") ||
+        shape_node2 == nullptr ||
+        (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) {
       continue;
     }
-    const std::vector<OpInfo::TensorProperties>& prop1 =
-        properties.GetInputProperties(shape_node1->name());
-    const std::vector<OpInfo::TensorProperties>& prop2 =
-        properties.GetInputProperties(shape_node2->name());
-    if (prop1.size() != 1 || prop2.size() != 1) {
+    int64 min_id = 0;
+    BCast::Vec shape1;
+    if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
       continue;
     }
-    const TensorShapeProto& shape1 = prop1[0].shape();
-    const TensorShapeProto& shape2 = prop2[0].shape();
-    if (ShapesEqual(shape1, shape2)) {
-      DataType type = node.attr().at("T").type();
-      Tensor empty(type, TensorShape());
-      NodeDef* out[2];
-      for (int i = 0; i < 2; ++i) {
+    BCast::Vec shape2;
+    if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
+      continue;
+    }
+    // A value of -1 means we don't know anything about the dimension. Replace
+    // the -1 values with unique dimension ids since we don't want two '-1'
+    // dimensions to be considered equal.
+    for (auto& id : shape1) {
+      if (id == -1) {
+        id = --min_id;
+      }
+    }
+    for (auto& id : shape2) {
+      if (id == -1) {
+        id = --min_id;
+      }
+    }
+    BCast bcast(shape1, shape2);
+    if (!bcast.IsValid()) {
+      continue;
+    }
+    BCast::Vec reduce_dims[2];
+    reduce_dims[0] = bcast.grad_x_reduce_idx();
+    reduce_dims[1] = bcast.grad_y_reduce_idx();
+
+    const DataType type = node.attr().at("T").type();
+    NodeDef* out[2];
+    for (int j = 0; j < 2; ++j) {
+      if (!reduce_dims[j].empty()) {
+        // This is the case when a tensor dimension of 1 is matched against an
+        // unknown dimension. The unknown dimension could also be equal to 1, in
+        // which case there would be no reduction.
+        out[j] = nullptr;
+      } else {
+        Tensor value(type, TensorShape({0}));
         string const_name = AddPrefixToNodeName(
-            strings::StrCat(node.name(), "-", i), kConstantFoldingConst);
-        out[i] = node_map_->GetNode(const_name);
-        if (!out[i]) {
-          out[i] = graph_.add_node();
-          *out[i] = CreateNodeDef(const_name, TensorValue(&empty));
-          out[i]->set_device(node.device());
-          node_map_->AddNode(const_name, out[i]);
+            strings::StrCat(node.name(), "-", j), kConstantFoldingConst);
+        out[j] = node_map_->GetNode(const_name);
+        if (!out[j]) {
+          out[j] = graph_.add_node();
+          *out[j] = CreateNodeDef(const_name, TensorValue(&value));
+          out[j]->set_device(node.device());
+          node_map_->AddNode(const_name, out[j]);
           string ctrl_dep =
               AddControlDependency(node.name(), &graph_, node_map_.get());
-          *out[i]->add_input() = ctrl_dep;
+          *out[j]->add_input() = ctrl_dep;
           node_map_->AddOutput(NodeName(ctrl_dep), const_name);
         }
       }
+    }
 
-      auto outputs = node_map_->GetOutputs(node.name());
-      for (const auto& output : outputs) {
-        for (int k = 0; k < output->input_size(); ++k) {
-          int port;
-          string node_name = ParseNodeName(output->input(k), &port);
-          if (node_name == node.name() && port >= 0 && port < 2) {
-            *output->mutable_input(k) = out[port]->name();
-            node_map_->UpdateInput(output->name(), node_name,
-                                   out[port]->name());
-          }
+    auto outputs = node_map_->GetOutputs(node.name());
+    for (const auto& output : outputs) {
+      for (int k = 0; k < output->input_size(); ++k) {
+        int port;
+        string node_name = ParseNodeName(output->input(k), &port);
+        if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
+          *output->mutable_input(k) = out[port]->name();
+          node_map_->UpdateInput(output->name(), node_name, out[port]->name());
         }
       }
     }
@@ -1005,15 +1071,13 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   GraphProperties properties(item);
   Status s = properties.InferStatically();
   bool has_feed = !item.feed.empty();
-  // bool has_feed = false;
+
   if (!has_feed && s.ok()) {
     // Only use static shape information when there is no feed in the
     // graph. That's because it's possible to feed a placeholder with a tensor
     // of any shape, which could make the static information inconsistent with
     // the shapes actually fed.
-    if (s.ok()) {
-      TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
-    }
+    TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
   }
   if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
     TF_RETURN_IF_ERROR(MaterializeConstants(item, properties));
@@ -1040,12 +1104,14 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
 
   GrapplerItem item_to_optimize = item;
   *output = item.graph;
+  int64 node_count;
   do {
     graph_.Swap(output);
     item_to_optimize.graph = graph_;
     *output = GraphDef();
+    node_count = graph_.node_size();
     TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
-  } while (output->node_size() < graph_.node_size());
+  } while (output->node_size() != node_count);
 
   *output->mutable_library() = item.graph.library();
   *output->mutable_versions() = item.graph.versions();
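
The new do-while condition matters because constant materialization can temporarily grow the graph: the old output->node_size() < graph_.node_size() test would stop as soon as a pass added nodes, even when further folding was still possible. A minimal sketch of the fixed-point driver, with a toy pass that grows before it shrinks (the names are illustrative, not Grappler APIs):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // One pass: folds adjacent equal values into one, and expands any -1
    // marker into two zeros (a one-time "materialization"), so the node
    // count can rise before it falls.
    std::vector<int> RunPass(const std::vector<int>& in) {
      std::vector<int> out;
      for (std::size_t i = 0; i < in.size(); ++i) {
        if (in[i] == -1) {
          out.push_back(0);
          out.push_back(0);
        } else if (i + 1 < in.size() && in[i] == in[i + 1]) {
          out.push_back(in[i]);
          ++i;  // fold the pair
        } else {
          out.push_back(in[i]);
        }
      }
      return out;
    }

    int main() {
      std::vector<int> graph = {7, -1, 7};
      std::size_t node_count;
      do {
        node_count = graph.size();  // size before this pass
        graph = RunPass(graph);     // first pass grows to 4, then shrinks
      } while (graph.size() != node_count);  // stop only at a fixed point
      std::cout << "converged at " << graph.size() << " nodes\n";  // 3
    }
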
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 17f9854b599..43f84b1ddfd 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -849,10 +849,18 @@ TEST_F(ConstantFoldingTest, ConstantMaterialization) {
   Output c = ops::Mul(s.WithOpName("c"), a, b);
   Output d = ops::Shape(s.WithOpName("d"), a);
   Output e = ops::Shape(s.WithOpName("e"), b);
+
   auto f = ops::internal::BroadcastGradientArgs(s.WithOpName("f"), d, e);
   Output o1 = ops::Identity(s.WithOpName("o1"), f.r0);
   Output o2 = ops::Identity(s.WithOpName("o2"), f.r1);
 
+  Output g = ops::Placeholder(s.WithOpName("g"), DT_FLOAT,
+                              ops::Placeholder::Shape(PartialTensorShape({1})));
+  Output h = ops::Shape(s.WithOpName("h"), g);
+  auto i = ops::internal::BroadcastGradientArgs(s.WithOpName("i"), d, h);
+  Output p1 = ops::Identity(s.WithOpName("p1"), i.r0);
+  Output p2 = ops::Identity(s.WithOpName("p2"), i.r1);
+
   GrapplerItem item;
   TF_CHECK_OK(s.ToGraphDef(&item.graph));
 
@@ -881,14 +889,33 @@ TEST_F(ConstantFoldingTest, ConstantMaterialization) {
       EXPECT_EQ("Const", node.op());
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("^f", node.input(0));
+      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
+                       .num_elements());
     } else if (node.name() == "ConstantFolding/f-1") {
       ++found;
       EXPECT_EQ("Const", node.op());
       EXPECT_EQ(1, node.input_size());
       EXPECT_EQ("^f", node.input(0));
+      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
+                       .num_elements());
+    } else if (node.name() == "p1") {
+      ++found;
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("ConstantFolding/i-0", node.input(0));
+    } else if (node.name() == "p2") {
+      ++found;
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("i:1", node.input(0));
+    } else if (node.name() == "ConstantFolding/i-0") {
+      ++found;
+      EXPECT_EQ("Const", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("^i", node.input(0));
+      EXPECT_EQ(0, TensorShape(node.attr().at("value").tensor().tensor_shape())
+                       .num_elements());
     }
   }
-  EXPECT_EQ(4, found);
+  EXPECT_EQ(7, found);
 }
 
 }  // namespace
diff --git a/tensorflow/core/util/bcast.cc b/tensorflow/core/util/bcast.cc
index 47e6ddb3d82..1eab7e3d024 100644
--- a/tensorflow/core/util/bcast.cc
+++ b/tensorflow/core/util/bcast.cc
@@ -68,9 +68,7 @@ BCast::BCast(const Vec& sx, const Vec& sy, const bool fewer_dims_optimization) {
       // Output shape.
       State curr = UNKNOWN;
       const int64 x_i = x[i];  // i-th dimension of x.
-      CHECK_GE(x_i, 0);
       const int64 y_i = y[i];  // i-th dimension of y.
-      CHECK_GE(y_i, 0);
       int64 o_i;   // i-th dimension of the output.
       int64 bx_i;  // i-th broadcast for x.
       int64 by_i;  // i-th broadcast for y.

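Replacing each -1 with a fresh negative ID is what lets BCast treat two independently-unknown dimensions as unequal while still letting a literal 1 broadcast; it is also why the CHECK_GE(x_i, 0) assertions had to come out of bcast.cc. A minimal standalone sketch of the trick (the compatibility check is simplified; the real BCast also aligns ranks and computes the gradient reduction indices):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Give every unknown (-1) dimension a fresh, unique negative ID so
    // two distinct unknowns never compare equal.
    void AssignUniqueIds(std::vector<int64_t>* shape, int64_t* min_id) {
      for (auto& d : *shape) {
        if (d == -1) d = --(*min_id);  // never collides with a real size
      }
    }

    // Two dims are compatible if equal, or if either is exactly 1.
    bool Broadcastable(const std::vector<int64_t>& a,
                       const std::vector<int64_t>& b) {
      if (a.size() != b.size()) return false;  // ranks assumed aligned here
      for (std::size_t i = 0; i < a.size(); ++i) {
        if (a[i] != b[i] && a[i] != 1 && b[i] != 1) return false;
      }
      return true;
    }

    int main() {
      int64_t min_id = 0;
      std::vector<int64_t> x = {2, -1}, y = {2, -1};
      AssignUniqueIds(&x, &min_id);  // x's unknown becomes -1
      AssignUniqueIds(&y, &min_id);  // y's unknown becomes -2: a distinct ID
      std::cout << std::boolalpha << Broadcastable(x, y) << "\n";  // false
    }
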
From 9a5f4814bafdd53d574d3c8aabc859d8a06ba39d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 16:11:25 -0800
Subject: [PATCH 100/115] Added some additional documentation to the swish()
 function

PiperOrigin-RevId: 175228315
---
 tensorflow/python/ops/nn_impl.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 2c83e4e29f3..431ea1186a7 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -275,9 +275,6 @@ def _swish_shape(op):
   return [op.inputs[0].shape]
 
 
-# Set noinline=True so that sigmoid(features) is re-computed during
-# backprop, and we can free the sigmoid(features) expression immediately
-# after use during the forward pass.
 @function.Defun(shape_func=_swish_shape, func_name="swish_grad", noinline=True)
 def _swish_grad(features, grad):
   """Gradient of Swish function defined below."""
@@ -287,6 +284,11 @@ def _swish_grad(features, grad):
   return grad * activation_grad
 
 
+# Naively, x * tf.nn.sigmoid(x) requires keeping both x and sigmoid(x) around
+# for backprop, effectively doubling the tensor's memory consumption. We use a
+# @Defun decorator with noinline=True so that sigmoid(features) is re-computed
+# during backprop, and we can free the sigmoid(features) expression immediately
+# after use during the forward pass.
 @function.Defun(
     grad_func=_swish_grad,
     shape_func=_swish_shape,
@@ -296,7 +298,7 @@ def swish(features):
   # pylint: disable=g-doc-args
   """Computes the Swish activation function: `x * sigmoid(x)`.
 
-  Source: "Swish: a Self-Gated Activation Function" (Ramachandran et al. 2017)
+  Source: "Searching for Activation Functions" (Ramachandran et al. 2017)
   https://arxiv.org/abs/1710.05941
 
   Args:

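For reference, the closed-form gradient the recomputation relies on is d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))). A small standalone numeric check of that identity (not TensorFlow code; the evaluation point and finite-difference step are arbitrary):

    #include <cmath>
    #include <cstdio>

    double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }
    double Swish(double x) { return x * Sigmoid(x); }

    double SwishGrad(double x) {
      double s = Sigmoid(x);             // recomputed here, as in the Defun
      return s * (1.0 + x * (1.0 - s));  // closed-form gradient
    }

    int main() {
      const double x = 0.7, eps = 1e-6;
      // Central finite difference as an independent check.
      double numeric = (Swish(x + eps) - Swish(x - eps)) / (2 * eps);
      std::printf("analytic=%.9f numeric=%.9f\n", SwishGrad(x), numeric);
    }
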
From 39cee098f01a56bd67be41648342f4008870b988 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 16:23:43 -0800
Subject: [PATCH 101/115] Always push updated nodes to the queue for possible
 further optimization.

PiperOrigin-RevId: 175229944
---
 .../optimizers/arithmetic_optimizer.cc         | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 44d16e5a426..f2277a9b79d 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -703,7 +703,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
           node_map->AddOutput(new_transpose->name(), new_cast->name());
 
           new_nodes->push_back(new_transpose);
-          new_nodes->push_back(new_cast);
           //  Add frame dependencies that the original node might have had.
           AddFrameControlDeps(node, {new_transpose, new_cast},
                               new_transpose->input(0), {new_transpose},
@@ -880,7 +879,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
       new_mul_node->set_device(node->device());
       SetDataTypeToAttr(type, "T", new_mul_node);
       node_map->AddNode(new_mul_node->name(), new_mul_node);
-      new_nodes->push_back(new_mul_node);
       new_mul_node->add_input(new_const_node->name());
       node_map->AddOutput(new_const_node->name(), new_mul_node->name());
       new_mul_node->add_input(node->input(0));
@@ -945,7 +943,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
           new_mul_node->set_name(new_mul_node->name() + "_hoist");
           new_mul_node->set_input(0, common_factor);
           new_mul_node->set_input(1, new_add_node->name());
-          new_nodes->push_back(new_mul_node);
           node_map->AddNode(new_mul_node->name(), new_mul_node);
         }
       }
@@ -1045,10 +1042,14 @@ namespace {
 template <class T>
 class SetVector {
  public:
-  void PushBack(const T& value) {
-    CHECK(!Exists(value)) << "Value " << value << " is already in the set.";
-    set_.insert(value);
+  // Returns false if value already existed in the set, true otherwise.
+  bool PushBack(const T& value) {
+    if (!set_.insert(value).second) {
+      VLOG(2) << "Value " << value << " is already in the set.";
+      return false;
+    }
     vector_.push_back(value);
+    return true;
   }
 
   T PopBack() {
@@ -1089,6 +1090,11 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(
     }
 
     if (NodeName(simplified_tensor) != node->name()) {
+      // Always consider simplified_tensor for further optimizations.
+      const NodeDef* simplified_node = node_map.GetNode(simplified_tensor);
+      if (simplified_node != nullptr) {
+        nodes_to_simplify.PushBack(simplified_node);
+      }
       // When `node` is simplified to another node rather than in-place, the
       // consumers of `node` are already redirected to `simplified_tensor`.
       // Re-push the consumers into `nodes_to_simplify` for further

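The SetVector change is what makes the unconditional PushBack above safe: re-pushing a node that is already queued becomes a no-op instead of a CHECK failure. A standalone sketch of the pattern, assuming (as the diff suggests but does not fully show) that PopBack also erases from the set so a processed node can be re-queued later:

    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    // A work list with set-backed membership: inserts of elements already
    // present are silently ignored.
    template <class T>
    class SetVector {
     public:
      bool PushBack(const T& v) {
        if (!set_.insert(v).second) return false;  // already queued
        vector_.push_back(v);
        return true;
      }
      T PopBack() {
        T v = vector_.back();
        vector_.pop_back();
        set_.erase(v);  // allows the element to be re-pushed later
        return v;
      }
      bool Empty() const { return vector_.empty(); }

     private:
      std::set<T> set_;
      std::vector<T> vector_;
    };

    int main() {
      SetVector<std::string> q;
      q.PushBack("mul");
      q.PushBack("mul");  // ignored: already present
      q.PushBack("add");
      while (!q.Empty()) std::cout << q.PopBack() << "\n";  // add, mul
    }
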
From 0719d26b1e61d13af1754b28ae855ba094d944ea Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy <gunan@google.com>
Date: Thu, 9 Nov 2017 16:26:07 -0800
Subject: [PATCH 102/115] Increase tolerance in flaky multinomial test.

PiperOrigin-RevId: 175230217
---
 .../python/kernel_tests/distributions/multinomial_test.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index 614a34f077b..ebc89f15c58 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -283,10 +283,10 @@ class MultinomialTest(test.TestCase):
           dist.variance(),
           dist.stddev(),
       ])
-      self.assertAllClose(sample_mean_, analytic_mean, atol=0., rtol=0.01)
-      self.assertAllClose(sample_cov_, analytic_cov, atol=0., rtol=0.01)
-      self.assertAllClose(sample_var_, analytic_var, atol=0., rtol=0.01)
-      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0., rtol=0.01)
+      self.assertAllClose(sample_mean_, analytic_mean, atol=0.01, rtol=0.01)
+      self.assertAllClose(sample_cov_, analytic_cov, atol=0.01, rtol=0.01)
+      self.assertAllClose(sample_var_, analytic_var, atol=0.01, rtol=0.01)
+      self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)
 
   def testSampleUnbiasedNonScalarBatch(self):
     with self.test_session() as sess:

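For readers unfamiliar with the two tolerances: assertAllClose follows the NumPy-style test |actual - expected| <= atol + rtol * |expected|, elementwise, so with atol=0 any statistic whose expected value is near zero can essentially never pass under sampling noise. A minimal sketch of that check:

    #include <cmath>
    #include <cstdio>

    // NumPy-style closeness test: absolute plus relative tolerance.
    bool AllClose(double actual, double expected, double atol, double rtol) {
      return std::fabs(actual - expected) <= atol + rtol * std::fabs(expected);
    }

    int main() {
      double expected = 0.0, actual = 0.004;  // tiny sampling noise around 0
      std::printf("atol=0:    %d\n", AllClose(actual, expected, 0.0, 0.01));   // 0
      std::printf("atol=0.01: %d\n", AllClose(actual, expected, 0.01, 0.01));  // 1
    }
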
From 47f8f08f0db5bb668d73993624691e8e9d064af4 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 16:35:57 -0800
Subject: [PATCH 103/115] Support more instructions in Hlo parser: Broadcast,
 Concatenate, Map, Reduce, SelectAndScatter, Reverse, Slice, DynamicSlice,
 DynamicUpdateSlice, Transpose, BatchNormTraining, BatchNormInference,
 BatchNormGrad.

PiperOrigin-RevId: 175231463
---
 .../compiler/xla/service/hlo_instruction.cc   |  12 +-
 .../compiler/xla/tools/parser/hlo_parser.cc   | 314 +++++++++++++++++-
 .../xla/tools/parser/hlo_parser_test.cc       | 231 ++++++++++++-
 3 files changed, 540 insertions(+), 17 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ffb933155f7..1b2161fc2e8 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1891,7 +1891,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
   if (padding_config_ != nullptr) {
     extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
   }
-  if (!slice_starts_.empty() && !slice_limits_.empty()) {
+  if (opcode() == HloOpcode::kSlice) {
     std::vector<string> bounds;
     bounds.reserve(slice_starts_.size());
     const bool omit_stride =
@@ -1904,6 +1904,16 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
     }
     extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
   }
+  if (opcode() == HloOpcode::kDynamicSlice) {
+    extra.push_back(
+        StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
+  }
+  if (opcode() == HloOpcode::kBatchNormTraining ||
+      opcode() == HloOpcode::kBatchNormInference ||
+      opcode() == HloOpcode::kBatchNormGrad) {
+    extra.push_back(StrCat("epsilon=", epsilon()));
+    extra.push_back(StrCat("feature_index=", feature_index()));
+  }
 
   if (convolution_dimension_numbers_ != nullptr) {
     extra.push_back(ConvolutionDimensionNumbersToString());
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 3741c3daac7..710877b4e04 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -80,14 +80,25 @@ class HloParser {
   bool ParseOperands(std::vector<HloInstruction*>* operands,
                      const int expected_size);
 
+  // Describes the start, limit, and stride on every dimension of the operand
+  // being sliced.
+  struct SliceRanges {
+    std::vector<int64> starts;
+    std::vector<int64> limits;
+    std::vector<int64> strides;
+  };
+
   // Types of attributes.
   enum class AttrTy {
     kInt64,
+    kFloat,
+    kBracedInt64List,
     kHloComputation,
     kWindow,
     kConvolutionDimensionNumbers,
     kSharding,
     kInstructionList,
+    kSliceRanges,
   };
 
   struct AttrConfig {
@@ -131,6 +142,10 @@ class HloParser {
   // Parses window's pad sub-attribute, e.g., pad=0_0x3_3.
   bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
 
+  bool ParseSliceRanges(SliceRanges* result);
+  bool ParseInt64List(const TokKind start, const TokKind end,
+                      const TokKind delim, std::vector<int64>* result);
+
   bool ParseParamList();
   bool ParseName(string* result);
   bool ParseAttributeName(string* result);
@@ -535,26 +550,190 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
           shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
       break;
     }
-    case HloOpcode::kBroadcast:
+    case HloOpcode::kBroadcast: {
+      optional<std::vector<int64>> broadcast_dimensions;
+      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+                             &broadcast_dimensions};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
+          shape, operands[0], *broadcast_dimensions));
+      break;
+    }
+    case HloOpcode::kConcatenate: {
+      optional<std::vector<int64>> dimensions;
+      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+                             &dimensions};
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
+          dimensions->size() != 1) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
+          shape, operands, dimensions->at(0)));
+      break;
+    }
+    case HloOpcode::kMap: {
+      optional<HloComputation*> to_apply;
+      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+                           &to_apply};
+      if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(
+          HloInstruction::CreateMap(shape, operands, *to_apply));
+      break;
+    }
+    case HloOpcode::kReduce: {
+      optional<HloComputation*> reduce_computation;
+      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+                           &reduce_computation};
+      optional<std::vector<int64>> dimensions_to_reduce;
+      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+                             &dimensions_to_reduce};
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateReduce(
+          shape, /*operand=*/operands[0], /*init_value=*/operands[1],
+          *dimensions_to_reduce, *reduce_computation));
+      break;
+    }
+    case HloOpcode::kReverse: {
+      optional<std::vector<int64>> dimensions;
+      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+                             &dimensions};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(
+          HloInstruction::CreateReverse(shape, operands[0], *dimensions));
+      break;
+    }
+    case HloOpcode::kSelectAndScatter: {
+      optional<HloComputation*> select;
+      attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
+      optional<HloComputation*> scatter;
+      attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
+      optional<Window> window;
+      attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
+      if (!ParseOperands(&operands, /*expected_size=*/3) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction =
+          builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
+              shape, /*operand=*/operands[0], *select, *window,
+              /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
+      break;
+    }
+    case HloOpcode::kSlice: {
+      optional<SliceRanges> slice_ranges;
+      attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateSlice(
+          shape, operands[0], slice_ranges->starts, slice_ranges->limits,
+          slice_ranges->strides));
+      break;
+    }
+    case HloOpcode::kDynamicSlice: {
+      optional<std::vector<int64>> dynamic_slice_sizes;
+      attrs["dynamic_slice_sizes"] = {
+          /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
+          shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
+          *dynamic_slice_sizes));
+      break;
+    }
+    case HloOpcode::kDynamicUpdateSlice: {
+      if (!ParseOperands(&operands, /*expected_size=*/3) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction =
+          builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+              shape, /*operand=*/operands[0], /*update=*/operands[1],
+              /*start_indices=*/operands[2]));
+      break;
+    }
+    case HloOpcode::kTranspose: {
+      optional<std::vector<int64>> dimensions;
+      attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
+                             &dimensions};
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(
+          HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
+      break;
+    }
+    case HloOpcode::kBatchNormTraining: {
+      optional<float> epsilon;
+      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
+      optional<int64> feature_index;
+      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
+                                &feature_index};
+      if (!ParseOperands(&operands, /*expected_size=*/3) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction =
+          builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
+              shape, /*operand=*/operands[0], /*scale=*/operands[1],
+              /*offset=*/operands[2], *epsilon, *feature_index));
+      break;
+    }
+    case HloOpcode::kBatchNormInference: {
+      optional<float> epsilon;
+      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
+      optional<int64> feature_index;
+      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
+                                &feature_index};
+      if (!ParseOperands(&operands, /*expected_size=*/5) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction =
+          builder->AddInstruction(HloInstruction::CreateBatchNormInference(
+              shape, /*operand=*/operands[0], /*scale=*/operands[1],
+              /*offset=*/operands[2], /*mean=*/operands[3],
+              /*variance=*/operands[4], *epsilon, *feature_index));
+      break;
+    }
+    case HloOpcode::kBatchNormGrad: {
+      optional<float> epsilon;
+      attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
+      optional<int64> feature_index;
+      attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
+                                &feature_index};
+      if (!ParseOperands(&operands, /*expected_size=*/5) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
+          shape, /*operand=*/operands[0], /*scale=*/operands[1],
+          /*mean=*/operands[2], /*variance=*/operands[3],
+          /*grad_output=*/operands[4], *epsilon, *feature_index));
+      break;
+    }
     case HloOpcode::kCustomCall:
-    case HloOpcode::kConcatenate:
     case HloOpcode::kReducePrecision:
-    case HloOpcode::kMap:
     case HloOpcode::kPad:
-    case HloOpcode::kReduce:
-    case HloOpcode::kSelectAndScatter:
-    case HloOpcode::kReverse:
     case HloOpcode::kRng:
-    case HloOpcode::kSlice:
-    case HloOpcode::kDynamicSlice:
-    case HloOpcode::kDynamicUpdateSlice:
-    case HloOpcode::kTranspose:
     case HloOpcode::kFusion:
-    case HloOpcode::kBatchNormTraining:
-    case HloOpcode::kBatchNormInference:
     case HloOpcode::kInfeed:
     case HloOpcode::kOutfeed:
-    case HloOpcode::kBatchNormGrad:
     case HloOpcode::kTrace:
       return TokenError(StrCat("parsing not yet implemented for op: ",
                                HloOpcodeString(opcode)));
@@ -1121,6 +1300,19 @@ bool HloParser::ParseAttributes(
           static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
           return true;
         }
+        case AttrTy::kFloat: {
+          double result;
+          if (!ParseDouble(&result)) {
+            return false;
+          }
+          if (result > std::numeric_limits<float>::max() ||
+              result < std::numeric_limits<float>::lowest()) {
+            return TokenError("value out of range for float");
+          }
+          static_cast<optional<float>*>(attr_out_ptr)
+              ->emplace(static_cast<float>(result));
+          return true;
+        }
         case AttrTy::kHloComputation: {
           HloComputation* result;
           if (!ParseComputationName(&result)) {
@@ -1164,6 +1356,24 @@ bool HloParser::ParseAttributes(
               ->emplace(result);
           return true;
         }
+        case AttrTy::kBracedInt64List: {
+          std::vector<int64> result;
+          if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace,
+                              TokKind::kComma, &result)) {
+            return false;
+          }
+          static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
+              ->emplace(result);
+          return true;
+        }
+        case AttrTy::kSliceRanges: {
+          SliceRanges result;
+          if (!ParseSliceRanges(&result)) {
+            return false;
+          }
+          static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
+          return true;
+        }
       }
     }();
     if (!success) {
@@ -1380,6 +1590,84 @@ bool HloParser::ParseConvolutionDimensionNumbers(
   return true;
 }
 
+// slice_ranges ::= '{' ranges '}'
+// ranges
+//   ::= /*empty*/
+//   ::= range (',' range)*
+// range ::= '[' start ':' limit (':' stride)? ']'
+//
+// The slice ranges are printed as:
+//
+//  {[dim0_start:dim0_limit:dim0_stride], [dim1_start:dim1_limit], ...}
+//
+// This function extracts the starts, limits, and strides as three vectors into
+// the result. If a stride is omitted, it defaults to 1. For example, if the
+// slice ranges are printed as:
+//
+//  {[2:3:4], [5:6:7], [8:9]}
+//
+// Then the parsed result will be:
+//
+//  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
+//
+bool HloParser::ParseSliceRanges(SliceRanges* result) {
+  if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
+    return false;
+  }
+  std::vector<std::vector<int64>> ranges;
+  if (lexer_.GetKind() == TokKind::kRbrace) {
+    // empty
+    return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
+  }
+  do {
+    ranges.emplace_back();
+    if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
+                        &ranges.back())) {
+      return false;
+    }
+  } while (EatIfPresent(TokKind::kComma));
+
+  for (const auto& range : ranges) {
+    if (range.size() != 2 && range.size() != 3) {
+      return TokenError(Printf(
+          "expects [start:limit:step] or [start:limit], but sees %ld elements.",
+          range.size()));
+    }
+  }
+
+  for (const auto& range : ranges) {
+    result->starts.push_back(range[0]);
+    result->limits.push_back(range[1]);
+    result->strides.push_back(range.size() == 3 ? range[2] : 1);
+  }
+  return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
+}
+
+// int64list ::= start int64_elements end
+// int64_elements
+//   ::= /*empty*/
+//   ::= int64_val (delim int64_val)*
+bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
+                               const TokKind delim,
+                               std::vector<int64>* result) {
+  if (!ParseToken(start, StrCat("expects an int64 list starting with ",
+                                TokKindToString(start)))) {
+    return false;
+  }
+  if (lexer_.GetKind() == end) {
+    // empty
+  } else {
+    do {
+      int64 i;
+      if (!ParseInt64(&i)) {
+        return false;
+      }
+      result->push_back(i);
+    } while (EatIfPresent(delim));
+  }
+  return ParseToken(
+      end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
+}
+
 // param_list ::= '(' param_list1 ')'
 // param_list1
 //   ::= /*empty*/
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index ca476a4bb77..fbe0409e3d1 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -44,10 +44,11 @@ std::vector<TestData> CreateTestCases() {
 "AxpyParam",
 R"(HloModule axpy_module:
 
-ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
-  %alpha = f32[2,4]{1,0} parameter(0)
+ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
+  %alpha = f32[] parameter(0)
+  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
   %x = f32[2,4]{1,0} parameter(1)
-  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
+  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
   %y = f32[2,4]{1,0} parameter(2)
   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
 }
@@ -296,6 +297,218 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f
 }
 
+)"
+},
+// reverse(constant)
+{
+"Reverse4D",
+R"(HloModule Reverse4DFloatArrayOnDim01_module:
+
+ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
+  %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
+  ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
+}
+
+)"
+},
+// concat
+{
+"Concat",
+R"(HloModule Concat2x3With2x5_module:
+
+ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
+  %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } })
+  %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
+  ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
+}
+
+)"
+},
+// map
+{
+"Map",
+R"(HloModule MapBinaryAdder_module:
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] {
+  %param0 = f32[4]{0} parameter(0)
+  %param1 = f32[4]{0} parameter(1)
+  ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3
+}
+
+)"
+},
+// reduce
+{
+"Reduce",
+R"(HloModule ReduceR3ToR2_module:
+
+%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
+}
+
+ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] {
+  %input = f32[8,16,256]{2,1,0} parameter(0)
+  %constant = f32[] constant(0)
+  ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3
+}
+
+)"
+},
+// select and scatter
+{
+"SelectAndScatter",
+R"(HloModule R4F32OverlapSmall_module:
+
+%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
+  %lhs = f32[] parameter(0)
+  %rhs = f32[] parameter(1)
+  ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs)
+}
+
+%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
+  %lhs.1 = f32[] parameter(0)
+  %rhs.1 = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
+}
+
+ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
+  %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
+  %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
+  %constant.2 = f32[] constant(0)
+  ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
+}
+
+)"
+},
+// slice
+{
+"Slice",
+R"(HloModule slice_module:
+
+ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
+  %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
+  ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
+}
+
+)"
+},
+// slice, no stride
+{
+"SliceNoStride",
+R"(HloModule Slice3x3x3_To_1x3x3_F32_module:
+
+ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
+  %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
+  ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
+}
+
+)"
+},
+// slice R0
+{
+"SliceR0",
+R"(HloModule SliceR0_module:
+
+ENTRY %SliceR0.v2 () -> s32[] {
+  %constant = s32[] constant(1)
+  ROOT %slice = s32[] slice(s32[] %constant), slice={}
+}
+
+)"
+},
+// transpose
+{
+"Transpose",
+R"(HloModule Transpose_module:
+
+ENTRY %Transpose.v2 () -> s32[1,2,3] {
+  %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } })
+  ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
+}
+
+)"
+},
+// Dynamic slice
+{
+"DynamicSlice",
+R"(HloModule DynamicSlice_module:
+
+ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
+  %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
+  %constant = s32[1]{0} constant({0})
+  %start_index = s32[1]{0} parameter(1)
+  %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
+  ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
+}
+
+)"
+},
+// Dynamic update slice
+{
+"DynamicUpdateSlice",
+R"(HloModule DynamicUpdateSlice_module:
+
+ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
+  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
+  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
+  %start_indices = s32[4]{0} parameter(2)
+  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
+}
+
+)"
+},
+// batch norm training
+{
+"BatchNormTraining",
+R"(HloModule BasicTraining_module:
+
+ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
+  %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } })
+  %constant.1 = f32[2]{0} constant({2, 3})
+  %constant.2 = f32[2]{0} constant({1, 2})
+  ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
+}
+
+)"
+},
+// batch norm inference
+{
+"BatchNormInference",
+R"(HloModule BatchNormInference_module:
+
+ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
+  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
+  %offset = f32[2]{0} parameter(1)
+  %scale = f32[2]{0} parameter(2)
+  %mean = f32[2]{0} parameter(3)
+  %variance = f32[2]{0} parameter(4)
+  ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
+}
+
+)"
+},
+// batch norm grad
+{
+"BatchNormGrad",
+R"(HloModule BatchNormGrad_module:
+
+ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
+  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
+  %scale = f32[2]{0} parameter(1)
+  %mean = f32[2]{0} parameter(2)
+  %variance = f32[2]{0} parameter(3)
+  %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
+  ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
+}
+
 )"
 }
   });
@@ -568,6 +781,18 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
                   "'done' is not defined");
 }
 
+TEST_F(HloParserTest, SliceAllowOmitStride1) {
+  const string original = R"(HloModule slice_module:
+
+ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
+  %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
+  ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
+}
+
+)";
+  TF_EXPECT_OK(Parse(original).status());
+}
+
 }  // namespace
 }  // namespace tools
 }  // namespace xla

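A minimal standalone sketch of the slice-range format ParseSliceRanges accepts, outside the lexer/parser machinery: each bracketed range contributes one entry to three parallel vectors, with the stride defaulting to 1 when omitted (sscanf here is just for illustration):

    #include <cstdio>
    #include <vector>

    struct SliceRanges {
      std::vector<long> starts, limits, strides;
    };

    // Accepts "[start:limit:stride]" or "[start:limit]".
    bool ParseRange(const char* text, SliceRanges* out) {
      long s, l, st;
      if (std::sscanf(text, "[%ld:%ld:%ld]", &s, &l, &st) == 3) {
        // all three components present
      } else if (std::sscanf(text, "[%ld:%ld]", &s, &l) == 2) {
        st = 1;  // stride omitted: defaults to 1
      } else {
        return false;
      }
      out->starts.push_back(s);
      out->limits.push_back(l);
      out->strides.push_back(st);
      return true;
    }

    int main() {
      SliceRanges r;
      ParseRange("[2:3:4]", &r);
      ParseRange("[8:9]", &r);
      std::printf("starts={%ld,%ld} limits={%ld,%ld} strides={%ld,%ld}\n",
                  r.starts[0], r.starts[1], r.limits[0], r.limits[1],
                  r.strides[0], r.strides[1]);
    }
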
From 2397537748552d8a7850b89d1f39dc1fc0b2a9f8 Mon Sep 17 00:00:00 2001
From: Igor Saprykin <isaprykin@google.com>
Date: Thu, 9 Nov 2017 16:40:22 -0800
Subject: [PATCH 104/115] De-flakify and re-enable tests in
 replicate_model_fn_test.py.

"Reduce metric variables" operation is a single operation across all metric variables, which means it is across all eval metrics.  Previously, an update op for every eval metric was conditioned on a copy of overall "reduce metric variables" op.  The latter was meant to be idempotent and thus the end result was supposed to be correct.

However, "reduce metric variables" op consists of a number of variable assignments and thus is not atomic.  If execution of two "reduce metric variables" ops interleaves, then the end result might come out to be incorrect.  This caused flakiness in replicate_model_fn_test.py.  To fix the problem, there is now a single copy of the "reduce metric variables" and every eval metric is associated with that single instance.

PiperOrigin-RevId: 175232016
---
 .../python/estimator/replicate_model_fn.py    |  18 +--
 .../estimator/replicate_model_fn_test.py      | 108 +++++++++---------
 2 files changed, 58 insertions(+), 68 deletions(-)

diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index 421bf18c45d..0848c5f62f3 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -357,25 +357,17 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'):
       [spec.loss for spec in tower_specs], aggregation_device,
       aggregated_loss_name)
 
-  eval_metric_ops_lists = {}
+  update_ops = []
   for tower_spec in tower_specs:
-    metrics = tower_spec.eval_metric_ops or {}
-    for name, (_, update_op) in six.iteritems(metrics):
-      update_ops = eval_metric_ops_lists.setdefault(name, ([]))
+    for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops):
       update_ops.append(update_op)
 
+  with ops_lib.control_dependencies(update_ops):
+    reduced_update_op = _reduce_metric_variables(len(tower_specs))
+
   eval_metric_ops = {}
   for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops):
-    with ops_lib.control_dependencies(eval_metric_ops_lists[name]):
-      # This operation reduces local variables across all metrics, yet is
-      # called for every metric.  This is redundant and it's done because
-      # it is hard to know what local variables correspond to what metric.
-      # Estimator is going to execute all `reduced_update_op`s as part of
-      # a group inside a single `Session.run()` call, which will avoid duplicate
-      # computation.
-      reduced_update_op = _reduce_metric_variables(len(tower_specs))
     eval_metric_ops[name] = (metric_tensor, reduced_update_op)
-
   estimator_spec['eval_metric_ops'] = eval_metric_ops
   return model_fn_lib.EstimatorSpec(**estimator_spec)
 
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index bb06700160d..21d5a9c327f 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -223,34 +223,34 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
                                            features, labels, self.params)
       del estimator_spec
 
-# TODO(isaprykin):  Resolve the source of flakinness.
-#   def test_eval(self):
-#     features = np.array([[0.01], [0.002]])
-#     labels = np.array([[0.01], [0.02]])
-#
-#     with self.test_session() as session:
-#       replicated_model_fn = replicate_model_fn.replicate_model_fn(
-#           self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
-#     estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
-#                                            labels, self.params)
-#       session.run(variables.local_variables_initializer())
-#       session.run(variables.global_variables_initializer())
-#
-#       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-#       auc, b = estimator_spec.eval_metric_ops['auc']
-#
-#       session.run([a, b])
-#       accuracy = session.run(accuracy)
-#       auc = session.run(auc)
-#
-#       # Accuracy is 0.0 (no match) in the first tower.
-#       # Accuracy is 1.0 (match) in the second tower, since the feature
-#       # times weight "c" happened to be equal to the label.
-#       total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
-#
-#       self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
-#       self.assertEqual(0, auc)
-#       self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
+  def test_eval(self):
+    features = np.array([[0.01], [0.002]])
+    labels = np.array([[0.01], [0.02]])
+
+    with self.test_session() as session:
+      replicated_model_fn = replicate_model_fn.replicate_model_fn(
+          self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1'])
+      estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features,
+                                           labels, self.params)
+      session.run(variables.local_variables_initializer())
+      session.run(variables.global_variables_initializer())
+
+      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+      auc, b = estimator_spec.eval_metric_ops['auc']
+
+      session.run([a, b])
+      accuracy = session.run(accuracy)
+      auc = session.run(auc)
+
+      # loss[i] = features[i] * 10 - labels[i].
+      # Accuracy is 0.0 (no match) in the first tower.
+      # Accuracy is 1.0 (match) in the second tower, since the feature
+      # times weight "c" happened to be equal to the label.
+      total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02))
+
+      self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01)
+      self.assertEqual(0, auc)
+      self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01)
 
   def test_predict(self):
     features = np.array([[0.01], [0.002]])
@@ -524,33 +524,31 @@ class EvalSpecTest(test_util.TensorFlowTestCase):
     }
     return metrics
 
-# TODO(isaprykin):  Resolve the source of flakinness.
-#   def test_example(self):
-#     with self.test_session() as session:
-#       tower_losses = map(self.create_constant_loss, [2, 4, 6])
-#       tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
-#       tower_specs = [
-#           self.create_estimator_spec(l, m)
-#           for l, m in zip(tower_losses, tower_metrics)
-#       ]
-#       session.run(variables.local_variables_initializer())
-#
-#       estimator_spec = replicate_model_fn._eval_spec(
-#           tower_specs, aggregation_device='/device:GPU:0')
-#
-#       accuracy, a = estimator_spec.eval_metric_ops['accuracy']
-#       auc, b = estimator_spec.eval_metric_ops['auc']
-#
-#       self.assertEqual('/device:CPU:0', accuracy.device)
-#       self.assertEqual('/device:CPU:0', auc.device)
-#
-#       session.run([a, b])
-#       accuracy = session.run(accuracy)
-#       auc = session.run(auc)
-#
-#       self.assertNear((12 - 2) / 12, accuracy, 0.01)
-#       self.assertEqual(0, auc)
-#       self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
+  def test_example(self):
+    with self.test_session() as session:
+      tower_losses = map(self.create_constant_loss, [2, 4, 6])
+      tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3])
+      tower_specs = [
+          self.create_estimator_spec(l, m)
+          for l, m in zip(tower_losses, tower_metrics)
+      ]
+      session.run(variables.local_variables_initializer())
+
+      estimator_spec = replicate_model_fn._eval_spec(
+          tower_specs, aggregation_device='/device:GPU:0')
+
+      accuracy, a = estimator_spec.eval_metric_ops['accuracy']
+      auc, b = estimator_spec.eval_metric_ops['auc']
+
+      self.assertEqual('/device:CPU:0', accuracy.device)
+      self.assertEqual('/device:CPU:0', auc.device)
+
+      session.run([a, b])
+      accuracy, auc = session.run([accuracy, auc])
+
+      self.assertNear((12 - 2) / 12, accuracy, 0.01)
+      self.assertEqual(0, auc)
+      self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss))
 
   def test_handles_single_tower(self):
     with self.test_session() as session:

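The failure mode described in the commit message can be reproduced deterministically without TensorFlow: a "reduce" composed of separate read and write steps is idempotent when run alone, yet loses that property when two copies interleave. A minimal sketch (illustrative names, not the replicate_model_fn code):

    #include <iostream>
    #include <vector>

    // A two-step "reduce": read the tower value, then fold it into slot 0
    // and zero the tower. Run alone (even twice), it is idempotent.
    struct Reduce {
      int tmp = 0;
      void Read(const std::vector<int>& v) { tmp = v[1]; }
      void Write(std::vector<int>* v) { (*v)[0] += tmp; (*v)[1] = 0; }
    };

    int main() {
      std::vector<int> v = {3, 5};  // correct reduced total: 8

      Reduce a, b;
      a.Read(v);    // copy A observes tower value 5
      b.Read(v);    // copy B observes the same 5 before A writes
      a.Write(&v);  // v = {8, 0}
      b.Write(&v);  // v = {13, 0}: the 5 was folded in twice
      std::cout << "interleaved copies: " << v[0] << "\n";  // 13, not 8

      v = {3, 5};
      Reduce c;
      c.Read(v);
      c.Write(&v);  // a single shared copy, run once
      std::cout << "single copy: " << v[0] << "\n";  // 8
    }
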
From 70e5cf1c486b28579d960d191f957f869a160e34 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <bsteiner@google.com>
Date: Thu, 9 Nov 2017 16:45:14 -0800
Subject: [PATCH 105/115] Improved the reporting of dimensions

PiperOrigin-RevId: 175232587
---
 tensorflow/python/grappler/model_analyzer.cc | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc
index 4ec7620bce9..7d365c3be92 100644
--- a/tensorflow/python/grappler/model_analyzer.cc
+++ b/tensorflow/python/grappler/model_analyzer.cc
@@ -59,10 +59,15 @@ void ModelAnalyzer::PrintNodeInfo(const NodeDef* node,
           if (i > 0) {
             os << ", ";
           }
-          if (prop.shape().dim(i).size() < 0) {
+          if (prop.shape().dim(i).size() >= 0) {
+            // Print the actual dimension.
+            os << prop.shape().dim(i).size();
+          } else if (prop.shape().dim(i).size() == -1) {
+            // We don't know anything about the dimension.
             os << "?";
           } else {
-            os << prop.shape().dim(i).size();
+            // Symbolic dimension.
+            os << "x" << -prop.shape().dim(i).size();
           }
         }
         os << "]";

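A standalone sketch of the three-way formatting introduced above: non-negative sizes print as-is, -1 prints as "?", and any other negative value is treated as a symbolic dimension ID and printed as "xN" (simplified stand-in types, not the Grappler classes):

    #include <iostream>
    #include <string>
    #include <vector>

    std::string DimToString(long size) {
      if (size >= 0) return std::to_string(size);  // known size
      if (size == -1) return "?";                  // fully unknown
      return "x" + std::to_string(-size);          // symbolic dimension ID
    }

    int main() {
      std::vector<long> dims = {32, -1, -7};
      std::string out = "[";
      for (std::size_t i = 0; i < dims.size(); ++i) {
        if (i > 0) out += ", ";
        out += DimToString(dims[i]);
      }
      std::cout << out << "]\n";  // [32, ?, x7]
    }
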
From b31493889da917c9b78aeb00e23a00e398272c26 Mon Sep 17 00:00:00 2001
From: Justin Lebar <jlebar@google.com>
Date: Thu, 9 Nov 2017 17:10:33 -0800
Subject: [PATCH 106/115] [StreamExecutor] LOG(ERROR) the driver version when
 cudnnCreate fails.

Previously we LOG(INFO)'ed the driver version, which meant it wouldn't
be printed unless you passed --logtostderr.  But this information is
pretty important, especially since cudnnCreate failing is likely to be a
fatal error.

PiperOrigin-RevId: 175235628
---
 tensorflow/stream_executor/cuda/cuda_dnn.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index a20334e40a5..ad8164c7f98 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -390,8 +390,8 @@ port::Status CudnnSupport::Init() {
                  << DriverVersionStatusToString(result);
     } else {
       const auto& version = result.ValueOrDie();
-      LOG(INFO) << "possibly insufficient driver version: "
-                << DriverVersionToString(version);
+      LOG(ERROR) << "possibly insufficient driver version: "
+                 << DriverVersionToString(version);
       // OS X kernel driver does not report version accurately
 #if !defined(__APPLE__)
       if (std::get<0>(version) < 340) {

From 9268d1b471cc9f37011d145bc39d0b63d2125c1f Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 9 Nov 2017 17:55:08 -0800
Subject: [PATCH 107/115] Hlo parser: support padding.

Also, give PaddingConfig its own ToString format.

PiperOrigin-RevId: 175239832
---
 .../compiler/xla/service/hlo_instruction.cc   | 18 ++++++-
 .../compiler/xla/service/hlo_instruction.h    |  2 +
 .../compiler/xla/tools/parser/README.md       |  6 +--
 .../compiler/xla/tools/parser/hlo_lexer.cc    | 13 ++---
 .../compiler/xla/tools/parser/hlo_lexer.h     |  2 +-
 .../compiler/xla/tools/parser/hlo_parser.cc   | 51 +++++++++++++++++-
 .../xla/tools/parser/hlo_parser_test.cc       | 52 ++++++++++++++++++-
 .../compiler/xla/tools/parser/hlo_token.h     |  2 +-
 8 files changed, 130 insertions(+), 16 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1b2161fc2e8..674d3e3836a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1889,7 +1889,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
     extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
   }
   if (padding_config_ != nullptr) {
-    extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
+    extra.push_back(
+        StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
   }
   if (opcode() == HloOpcode::kSlice) {
     std::vector<string> bounds;
@@ -2894,6 +2895,21 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
   return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
 }
 
+string PaddingConfigToString(const PaddingConfig& padding) {
+  bool has_interior_padding =
+      std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
+                  [](const PaddingConfig::PaddingConfigDimension& dim) {
+                    return dim.interior_padding() != 0;
+                  });
+  return Join(
+      padding.dimensions(), "x",
+      [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
+        StrAppend(
+            out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
+            has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
+      });
+}
+
 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
   return os << ToString(kind);
 }
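
For illustration, a minimal standalone sketch (not part of the patch) of what the new PaddingConfigToString emits, assuming the usual xla_data.pb.h proto setters:

    // Sketch only: shows the string format emitted by PaddingConfigToString.
    #include <iostream>
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"
    #include "tensorflow/compiler/xla/xla_data.pb.h"

    int main() {
      xla::PaddingConfig padding;
      auto* d0 = padding.add_dimensions();
      d0->set_edge_padding_low(0);
      d0->set_edge_padding_high(0);
      auto* d1 = padding.add_dimensions();
      d1->set_edge_padding_low(3);
      d1->set_edge_padding_high(3);
      d1->set_interior_padding(1);
      // Any nonzero interior padding makes every dimension print three
      // fields, so this prints "0_0_0x3_3_1" rather than "0_0x3_3_1".
      std::cout << xla::PaddingConfigToString(padding) << "\n";
      return 0;
    }
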
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 974d43d89ee..64a88164a70 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1234,6 +1234,8 @@ string ToString(HloInstruction::FusionKind kind);
 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
     const string& kind_name);
 
+string PaddingConfigToString(const PaddingConfig& padding);
+
 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
 
 // Map classes that guarantee a deterministic iteration order when the key is
diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md
index 986041caf61..b768b94e770 100644
--- a/tensorflow/compiler/xla/tools/parser/README.md
+++ b/tensorflow/compiler/xla/tools/parser/README.md
@@ -54,9 +54,9 @@ attribute
 attribute_value
   : kInt
   | kName
-  | [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,} /*dim_labels_pattern*/
-  | [0-9]+(x[0-9]+)+                     /*dxd_pattern*/
-  | [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*       /*window_pad_pattern*/
+  | [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}                /*dim_labels_pattern*/
+  | [0-9]+(x[0-9]+)+                                    /*dxd_pattern*/
+  | [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*  /*pad_pattern*/
   | '{' sub_attributes '}'
   ;
 
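The new pad_pattern can be sanity-checked against the RE2 regex the lexer uses in the next file; a small sketch (not part of the patch, assuming the standard RE2 API):

    #include <cassert>
    #include "re2/re2.h"

    int main() {
      static const char kPadPattern[] =
          R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)";
      assert(RE2::FullMatch("3_1", kPadPattern));          // low_high only
      assert(RE2::FullMatch("0_0_0x3_3_1", kPadPattern));  // low_high_interior
      assert(!RE2::FullMatch("0_0x3", kPadPattern));       // incomplete pair
      return 0;
    }
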
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
index f70386411cf..b5befbf58ba 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc
@@ -254,13 +254,13 @@ TokKind HloLexer::LexPercent() {
 }
 
 // Lex integer and floating-point values, -inf, and patterns for dim labels,
-// dxd (e.g. 1x2x3), and window pad.
+// dxd (e.g. 1x2x3), and pad.
 //
 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
 // dim_labels_pattern ::= [0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,}
 // dxd_pattern ::= [0-9]+(x[0-9]+)+
-// window_pad_pattern ::= [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
+// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
 // int ::=  [-]?[0-9]+
 // negative inf ::= '-inf'
 TokKind HloLexer::LexNumberOrPattern() {
@@ -277,7 +277,8 @@ TokKind HloLexer::LexNumberOrPattern() {
   static LazyRE2 dim_labels_pattern = {
       R"([0-9bf]{3,}_[0-9io]{3,}->[0-9bf]{3,})"};
   static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
-  static LazyRE2 pad_pattern = {R"([0-9]+_[0-9]+(x[0-9]+_[0-9]+)*)"};
+  static LazyRE2 pad_pattern = {
+      R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"};
 
   if (RE2::Consume(&consumable, *dim_labels_pattern)) {
     current_ptr_ = consumable.begin();
@@ -294,7 +295,7 @@ TokKind HloLexer::LexNumberOrPattern() {
   if (RE2::Consume(&consumable, *pad_pattern)) {
     current_ptr_ = consumable.begin();
     str_val_.assign(token_start_, current_ptr_);
-    return TokKind::kWindowPad;
+    return TokKind::kPad;
   }
 
   static LazyRE2 int_pattern = {R"([-]?\d+)"};
@@ -395,8 +396,8 @@ string TokKindToString(TokKind kind) {
       return "kDimLabels";
     case TokKind::kDxD:
       return "kDxD";
-    case TokKind::kWindowPad:
-      return "kWindowPad";
+    case TokKind::kPad:
+      return "kPad";
     case TokKind::kShape:
       return "kShape";
     case TokKind::kOpcode:
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
index 74e6829180a..79c4f271a1d 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h
@@ -45,7 +45,7 @@ class HloLexer {
       case TokKind::kAttributeName:
       case TokKind::kDimLabels:
       case TokKind::kDxD:
-      case TokKind::kWindowPad:
+      case TokKind::kPad:
         return str_val_;
       default:
         LOG(FATAL) << "This token does not have string value";
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 710877b4e04..fed0492a54c 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -99,6 +99,7 @@ class HloParser {
     kSharding,
     kInstructionList,
     kSliceRanges,
+    kPaddingConfig,
   };
 
   struct AttrConfig {
@@ -134,6 +135,7 @@ class HloParser {
   bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
   bool ParseWindow(Window* window);
   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
+  bool ParsePaddingConfig(PaddingConfig* padding);
   bool ParseSharding(OpSharding* sharding);
   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
 
@@ -727,9 +729,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
           /*grad_output=*/operands[4], *epsilon, *feature_index));
       break;
     }
+    case HloOpcode::kPad: {
+      optional<PaddingConfig> padding;
+      attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
+      if (!ParseOperands(&operands, /*expected_size=*/2) ||
+          !ParseAttributes(attrs)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(HloInstruction::CreatePad(
+          shape, operands[0], /*padding_value=*/operands[1], *padding));
+      break;
+    }
     case HloOpcode::kCustomCall:
     case HloOpcode::kReducePrecision:
-    case HloOpcode::kPad:
     case HloOpcode::kRng:
     case HloOpcode::kFusion:
     case HloOpcode::kInfeed:
@@ -1374,6 +1386,14 @@ bool HloParser::ParseAttributes(
           static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
           return true;
         }
+        case AttrTy::kPaddingConfig: {
+          PaddingConfig result;
+          if (!ParsePaddingConfig(&result)) {
+            return false;
+          }
+          static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
+          return true;
+        }
       }
     }();
     if (!success) {
@@ -1774,7 +1794,7 @@ bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
   if (!pad->empty()) {
     return TokenError("sub-attribute 'pad=' already exists");
   }
-  if (lexer_.GetKind() != TokKind::kWindowPad) {
+  if (lexer_.GetKind() != TokKind::kPad) {
     return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
   }
   string str = lexer_.GetStrVal();
@@ -1792,6 +1812,33 @@ bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
   return true;
 }
 
+// This is the inverse of xla::PaddingConfigToString. The padding config string
+// looks like "0_0_0x3_3_1". The string is first split on 'x'; each substring
+// represents one PaddingConfigDimension and consists of 3 (or 2) numbers
+// joined by '_'.
+bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
+  if (lexer_.GetKind() != TokKind::kPad) {
+    return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
+  }
+  string str = lexer_.GetStrVal();
+  std::vector<string> padding_str = Split(str, 'x');
+  for (const auto& padding_dim_str : padding_str) {
+    std::vector<int64> padding_dim;
+    if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
+        (padding_dim.size() != 2 && padding_dim.size() != 3)) {
+      return TokenError(
+          "expects padding config pattern like 'low_high_interior' or "
+          "'low_high'");
+    }
+    auto* dim = padding->add_dimensions();
+    dim->set_edge_padding_low(padding_dim[0]);
+    dim->set_edge_padding_high(padding_dim[1]);
+    dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
+  }
+  lexer_.Lex();
+  return true;
+}
+
 bool HloParser::ParseOpcode(HloOpcode* result) {
   VLOG(1) << "ParseOpcode";
   if (lexer_.GetKind() != TokKind::kOpcode) {
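
A sketch (not part of the patch) of how ParsePaddingConfig decomposes its input, assuming Split and SplitAndParseAsInts resolve to the tensorflow::str_util helpers; the Demo function name is mine:

    #include <vector>
    #include "tensorflow/core/lib/strings/str_util.h"
    #include "tensorflow/core/platform/types.h"

    void Demo() {
      using tensorflow::int64;
      using tensorflow::string;
      // "0_0_0x3_3_1" -> {"0_0_0", "3_3_1"}: one substring per dimension.
      std::vector<string> dims =
          tensorflow::str_util::Split("0_0_0x3_3_1", 'x');
      // "3_3_1" -> {3, 3, 1}: edge_padding_low, edge_padding_high,
      // interior_padding; interior defaults to 0 when only two fields appear.
      std::vector<int64> fields;
      tensorflow::str_util::SplitAndParseAsInts(dims[1], '_', &fields);
    }
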
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
index fbe0409e3d1..d19c6e18774 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc
@@ -36,6 +36,10 @@ string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
   return data.param.test_name;
 }
 
+// For each string below, we check that:
+//  - we parse it to an HloModule successfully, and
+//  - the stringification of the resulting HloModule is equal to our original
+//    string.
 std::vector<TestData> CreateTestCases() {
   // clang-format off
   return std::vector<TestData>({
@@ -509,6 +513,32 @@ ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], varia
   ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
 }
 
+)"
+},
+// pad
+{
+"Pad",
+R"(HloModule Pad1DS3Array_module:
+
+ENTRY %Pad1DS3Array.v3 () -> f32[8] {
+  %constant = f32[3]{0} constant({1, 2, 3})
+  %constant.1 = f32[] constant(0.1)
+  ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
+}
+
+)"
+},
+// pad has interior
+{
+"PadHasInterior",
+R"(HloModule PadHasInterior_module:
+
+ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
+  %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
+  %constant = f32[] constant(-5.123)
+  ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
+}
+
 )"
 }
   });
@@ -523,7 +553,10 @@ class HloParserTest : public ::testing::Test,
         << "'" << s << "' does not contain '" << expected << "'";
   }
 
-  void ExpectSuccess() {
+  // Expects "ToString(Parse(string)) == string", that is, parses the string,
+  // asserts that it succeeded, stringifies the parsed module, and checks that
+  // it equals the original string.
+  void ExpectEqual() {
     const string& original = GetParam().module_string;
     auto result = Parse(original);
     TF_EXPECT_OK(result.status());
@@ -532,7 +565,7 @@ class HloParserTest : public ::testing::Test,
   }
 };
 
-TEST_P(HloParserTest, Run) { ExpectSuccess(); }
+TEST_P(HloParserTest, Run) { ExpectEqual(); }
 
 INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
                         ::testing::ValuesIn(CreateTestCases()),
@@ -793,6 +826,21 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
   TF_EXPECT_OK(Parse(original).status());
 }
 
+TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
+  const string original = R"(HloModule window_pad_module:
+
+ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
+  %input = f32[1,2,1]{2,1,0} parameter(0)
+  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
+  %filter = f32[1,1,1]{2,1,0} parameter(1)
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
+}
+
+)";
+  ExpectHasSubstr(Parse(original).status().error_message(),
+                  "expects padding_low and padding_high separated by '_'");
+}
+
 }  // namespace
 }  // namespace tools
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h
index 15ab8b1cccf..9afd2fac231 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_token.h
+++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h
@@ -59,7 +59,7 @@ enum class TokKind {
   kAttributeName,  // dimensions=
   kDimLabels,      // [0-9bf]+_[0-9io]+->[0-9bf]+
   kDxD,            // [0-9]+(x[0-9]+)+
-  kWindowPad,      // [0-9]+_[0-9]+(x[0-9]+_[0-9]+)*
+  kPad,            // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*
   kShape,          // f32[2,3]{1,0}
   kOpcode,         // add
   kInt,            // 42

From 80646b480fded909ec439e32165223046b445f1f Mon Sep 17 00:00:00 2001
From: Justin Lebar <jlebar@google.com>
Date: Thu, 9 Nov 2017 19:27:20 -0800
Subject: [PATCH 108/115] [XLA] Don't deemphasize nodes inside of
 subcomputations in dumped XLA graphs.

Nodes inside of subcomputations (e.g. fusion computations) are always
printed by the HLO graph dumper.  Before this change, the dumper was not
fully aware of this fact, leading it to mark as "deemphasized" (i.e.
drawn as gray with a dashed outline) nodes that had no business being
deemphasized.

PiperOrigin-RevId: 175247474
---
 tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 67e0238c4af..04b3059fb12 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1303,7 +1303,9 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
 
   auto is_displayed = [&](const HloInstruction* instr) {
     // Constants are displayed inline with their users; they're never omitted.
-    return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant;
+    // Nodes in subcomputations are always shown.
+    return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant ||
+           instr->parent() != root->parent();
   };
 
   // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we

From badd35648851c0b84fdbd997b1f6e9aa20122216 Mon Sep 17 00:00:00 2001
From: Yunxing Dai <yunxing@google.com>
Date: Thu, 9 Nov 2017 20:45:39 -0800
Subject: [PATCH 109/115] Add bfloat16 support to XLA.

This is necessary for providing bfloat16 support in the GPU backend.
RELNOTES: bfloat16 support is now added to XLA infra.
PiperOrigin-RevId: 175252067
---
 tensorflow/compiler/tf2xla/type_util.cc       |   3 +
 tensorflow/compiler/xla/BUILD                 |   1 +
 tensorflow/compiler/xla/literal_util.cc       |  99 ++++++-
 tensorflow/compiler/xla/literal_util.h        |  23 ++
 tensorflow/compiler/xla/literal_util_test.cc  |  62 +++++
 tensorflow/compiler/xla/primitive_util.cc     |   8 +-
 tensorflow/compiler/xla/primitive_util.h      |   7 +
 tensorflow/compiler/xla/service/backend.cc    |   4 +-
 .../xla/service/cpu/cpu_runtime_test.cc       |   4 +-
 .../compiler/xla/service/hlo_evaluator.cc     |   4 +
 tensorflow/compiler/xla/service/hlo_runner.cc |   3 +-
 tensorflow/compiler/xla/shape_util.cc         |   1 +
 .../compiler/xla/tests/literal_test_util.cc   |  13 +-
 .../xla/tests/local_client_test_base.cc       |   3 +-
 tensorflow/compiler/xla/types.h               |   3 +
 tensorflow/compiler/xla/xla_data.proto        |  13 +-
 tensorflow/core/framework/bfloat16.cc         |  30 +--
 tensorflow/core/framework/bfloat16_test.cc    |  92 +++++++
 tensorflow/core/framework/numeric_types.h     | 251 +++++++++++++++++-
 19 files changed, 580 insertions(+), 44 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index 1efbe0ffb17..c969212a1bf 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
     case tensorflow::DT_UINT64:
       *type = xla::U64;
       return Status::OK();
+    case tensorflow::DT_BFLOAT16:
+      *type = xla::BF16;
+      return Status::OK();
     case tensorflow::DT_HALF:
       *type = xla::F16;
       return Status::OK();
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 660f419e464..f6e405744a1 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -77,6 +77,7 @@ cc_library(
     hdrs = ["types.h"],
     visibility = [":friends"],
     deps = [
+        "//tensorflow/core:framework_lite",
         "//tensorflow/core:lib",
         "//third_party/eigen3",
     ],
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 0cb2223ae5a..93d3cd425f0 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -33,6 +33,20 @@ limitations under the License.
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+namespace {
+using tensorflow::int64;
+
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+// Converts between little and big endian, assuming elements in the array are 16
+// bits long.
+void ConvertEndianShort(char* bytes, int64 size) {
+  CHECK_EQ(size % 2, 0);  // size is a byte count and must be even
+  for (int64 i = 0; i < size; i += 2) {
+    std::swap(bytes[i], bytes[i + 1]);
+  }
+}
+}  // namespace
 
 namespace xla {
 
@@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal,
       return CopyRange<int64>(src_literal, src_base, dest_base, copy_size);
     case F16:
       return CopyRange<half>(src_literal, src_base, dest_base, copy_size);
+    case BF16:
+      return CopyRange<bfloat16>(src_literal, src_base, dest_base, copy_size);
     case F32:
       return CopyRange<float>(src_literal, src_base, dest_base, copy_size);
     case F64:
@@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal,
       return *Literal::CreateR0<int64>(0);
     case F16:
       return *Literal::CreateR0<half>(static_cast<half>(0.0f));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
     case F32:
       return *Literal::CreateR0<float>(0);
     case F64:
@@ -285,6 +303,9 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(-std::numeric_limits<float>::infinity()));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no minimum value";
     case OPAQUE:
@@ -321,6 +342,9 @@ Status Literal::Copy(const Literal& src_literal,
     case F16:
       return *Literal::CreateR0<half>(
           static_cast<half>(std::numeric_limits<float>::infinity()));
+    case BF16:
+      return *Literal::CreateR0<bfloat16>(
+          static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
     case TUPLE:
       LOG(FATAL) << "tuple element type has no maximum value";
     case OPAQUE:
@@ -428,6 +452,7 @@ std::unique_ptr<Literal> Literal::Transpose(
   // The shape with affine layout resulting from that operation will be
   // F32[8,11]{0,1}, since it leaves the originally most minor dimension (the
   // one sized 8) as the most minor.
+  //
   // Essentially, given MinMaj(Di) the position of the Di dimension within the
   // minor to major vector, and given T(Di) the index that the original Di
   // dimension has within the transposed array, a layout is affine if
@@ -536,6 +561,9 @@ string Literal::GetAsString(
     }
     case F16:
       return tensorflow::strings::StrCat(Get<half>(multi_index));
+    case BF16:
+      return tensorflow::strings::StrCat(
+          static_cast<float>(Get<bfloat16>(multi_index)));
     default:
       return tensorflow::strings::StrCat(
           "[", PrimitiveType_Name(shape().element_type()), "]");
@@ -743,6 +771,8 @@ void* Literal::MutableInternalData() {
       return reinterpret_cast<void*>(c64s_.data());
     case F16:
       return reinterpret_cast<void*>(f16s_.data());
+    case BF16:
+      return reinterpret_cast<void*>(bf16s_.data());
     default:
       LOG(FATAL) << "primitive type not supported in literals: "
                  << PrimitiveType_Name(shape().element_type());
@@ -785,6 +815,9 @@ void Literal::Reserve(int64 num_elements) {
     case F16:
       Resize<half>(num_elements, static_cast<half>(0.0f));
       break;
+    case BF16:
+      Resize<bfloat16>(num_elements, static_cast<bfloat16>(0.0f));
+      break;
     default:
       LOG(FATAL) << "primitive type not supported in literals: "
                  << PrimitiveType_Name(shape().element_type());
@@ -824,6 +857,9 @@ tensorflow::Status Literal::ValidateLiteral() const {
     case F16:
       actual = f16s().size() / sizeof(half);
       break;
+    case BF16:
+      actual = bf16s().size();
+      break;
     default:
       return tensorflow::errors::Unimplemented(
           "unhandled element type for literal validation: " +
@@ -920,6 +956,7 @@ StatusOr<std::unique_ptr<Literal>> ConvertIfDestTypeMatches(
     CONVERT_IF_TYPES_MATCH(F16)
     CONVERT_IF_TYPES_MATCH(F32)
     CONVERT_IF_TYPES_MATCH(F64)
+    CONVERT_IF_TYPES_MATCH(BF16)
 #undef CONVERT_IF_TYPES_MATCH
     case C64:
       return ConvertToC64<primitive_src_type>(src_literal);
@@ -949,8 +986,9 @@ StatusOr<std::unique_ptr<Literal>> Literal::Convert(
     CONVERT_IF_DEST_TYPE_MATCHES(F16)
     CONVERT_IF_DEST_TYPE_MATCHES(F32)
     CONVERT_IF_DEST_TYPE_MATCHES(F64)
+    CONVERT_IF_DEST_TYPE_MATCHES(BF16)
 #undef CONVERT_IF_DEST_TYPE_MATCHES
-    // Other types are not yet supported.
+      // Other types are not yet supported.
     default:
       return InvalidArgument("Unimplemented: Convert from type %s to type %s",
                              PrimitiveType_Name(shape().element_type()).c_str(),
@@ -1019,6 +1057,8 @@ bool Literal::operator==(const Literal& other) const {
         return EqualElements<double>(*this, other, 0, &multi_index);
       case F16:
         return EqualElements<half>(*this, other, 0, &multi_index);
+      case BF16:
+        return EqualElements<bfloat16>(*this, other, 0, &multi_index);
       case C64:
         return EqualElements<complex64>(*this, other, 0, &multi_index);
       default:
@@ -1128,13 +1168,18 @@ tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice() {
 
 template <>
 tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice<half>() {
-  // TODO - there is an endianess problem here. fix it, or wait for uint16
-  //        support in protobuf
   auto values = mutable_f16s();
   return tensorflow::gtl::MutableArraySlice<half>(values->data(),
                                                   values->size());
 }
 
+template <>
+tensorflow::gtl::MutableArraySlice<bfloat16>
+Literal::GetMutableArraySlice<bfloat16>() {
+  auto values = mutable_bf16s();
+  return {values->data(), values->size()};
+}
+
 template <>
 tensorflow::gtl::ArraySlice<bool> Literal::GetArraySlice<bool>() const {
   CHECK_EQ(shape().element_type(), PRED);
@@ -1205,6 +1250,12 @@ tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const {
                                            f16s().size() / sizeof(half));
 }
 
+template <>
+tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const {
+  CHECK_EQ(shape().element_type(), BF16);
+  return {bf16s().data(), bf16s().size()};
+}
+
 template <>
 tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
     const {
@@ -1253,6 +1304,9 @@ bool Literal::IsAll(int8 value) const {
       return AllElementsEqualValue<double>(*this, value);
     case F16:
       return AllElementsEqualValue<half>(*this, static_cast<half>(value));
+    case BF16:
+      return AllElementsEqualValue<bfloat16>(*this,
+                                             static_cast<bfloat16>(value));
     case PRED:
       if (value == 0) {
         return AllElementsEqualValue<bool>(*this, false);
@@ -1274,6 +1328,9 @@ bool Literal::IsAllFloat(float value) const {
       return AllElementsEqualValue<double>(*this, value);
     case F16:
       return AllElementsEqualValue<half>(*this, static_cast<half>(value));
+    case BF16:
+      return AllElementsEqualValue<bfloat16>(*this,
+                                             static_cast<bfloat16>(value));
     default:
       return false;
   }
@@ -1310,6 +1367,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice<int64> indices) const {
       return Get<complex64>(indices) == complex64(0.0f, 0.0f);
     case F16:
       return Get<half>(indices) == static_cast<half>(0.0f);
+    case BF16:
+      return Get<bfloat16>(indices) == static_cast<bfloat16>(0.0f);
     case PRED:
       return Get<bool>(indices) == false;
     default:
@@ -1377,6 +1436,12 @@ void Literal::Resize<half>(int64 num_elements, half value) {
   mutable_f16s()->resize(num_elements, value);
 }
 
+template <>
+void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value) {
+  CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
+  mutable_bf16s()->resize(num_elements, value);
+}
+
 template <>
 void Literal::Resize<complex64>(int64 num_elements, complex64 value) {
   CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements);
@@ -1425,6 +1490,19 @@ LiteralProto Literal::ToProto() const {
       *proto.mutable_f16s() =
           string(reinterpret_cast<const char*>(f16s_.data()),
                  f16s_.size() * sizeof(half));
+      if (!kLittleEndian) {
+        ConvertEndianShort(const_cast<char*>(proto.mutable_f16s()->data()),
+                           proto.f16s().size());
+      }
+      break;
+    case BF16:
+      *proto.mutable_bf16s() =
+          string(reinterpret_cast<const char*>(bf16s_.data()),
+                 bf16s_.size() * sizeof(bfloat16));
+      if (!kLittleEndian) {
+        ConvertEndianShort(const_cast<char*>(proto.mutable_bf16s()->data()),
+                           proto.bf16s().size());
+      }
       break;
     case F32:
       CopyToRepeatedField(proto.mutable_f32s(), f32s());
@@ -1493,6 +1571,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) {
       CHECK_EQ(0, s.size() % sizeof(half));
       f16s_ = std::vector<half>(s.size() / sizeof(half));
       memcpy(f16s_.data(), s.data(), s.size());
+
+      if (!kLittleEndian) {
+        ConvertEndianShort(reinterpret_cast<char*>(f16s_.data()), s.size());
+      }
+      break;
+    }
+    case BF16: {
+      const string& s(literal_proto.bf16s());
+      CHECK_EQ(0, s.size() % sizeof(bfloat16));
+      bf16s_ = std::vector<bfloat16>(s.size() / sizeof(bfloat16));
+      memcpy(bf16s_.data(), s.data(), s.size());
+
+      if (!kLittleEndian) {
+        ConvertEndianShort(reinterpret_cast<char*>(bf16s_.data()), s.size());
+      }
       break;
     }
     case F32:
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 667f926c464..f37e529caf5 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -163,6 +163,11 @@ class Literal {
   const std::vector<complex64>& c64s() const { return c64s_; }
   std::vector<complex64>* mutable_c64s() { return &c64s_; }
 
+  int bf16s_size() const { return bf16s().size(); }
+  bfloat16 bf16s(int i) const { return bf16s_[i]; }
+  const std::vector<bfloat16>& bf16s() const { return bf16s_; }
+  std::vector<bfloat16>* mutable_bf16s() { return &bf16s_; }
+
   int tuple_literals_size() const { return tuple_literals().size(); }
   const Literal& tuple_literals(int i) const { return tuple_literals_[i]; }
   Literal* add_tuple_literals() {
@@ -622,6 +627,7 @@ class Literal {
   std::vector<uint16> u16s_;
   std::vector<uint32> u32s_;
   std::vector<uint64> u64s_;
+  std::vector<bfloat16> bf16s_;
   std::vector<half> f16s_;
   std::vector<float> f32s_;
   std::vector<double> f64s_;
@@ -674,6 +680,9 @@ tensorflow::gtl::ArraySlice<double> Literal::GetArraySlice<double>() const;
 template <>
 tensorflow::gtl::ArraySlice<half> Literal::GetArraySlice<half>() const;
 
+template <>
+tensorflow::gtl::ArraySlice<bfloat16> Literal::GetArraySlice<bfloat16>() const;
+
 template <>
 tensorflow::gtl::ArraySlice<complex64> Literal::GetArraySlice<complex64>()
     const;
@@ -714,6 +723,9 @@ tensorflow::gtl::MutableArraySlice<double> Literal::GetMutableArraySlice();
 template <>
 tensorflow::gtl::MutableArraySlice<half> Literal::GetMutableArraySlice();
 
+template <>
+tensorflow::gtl::MutableArraySlice<bfloat16> Literal::GetMutableArraySlice();
+
 template <>
 tensorflow::gtl::MutableArraySlice<complex64> Literal::GetMutableArraySlice();
 
@@ -747,6 +759,9 @@ void Literal::Resize<double>(int64 num_elements, double value);
 template <>
 void Literal::Resize<half>(int64 num_elements, half value);
 
+template <>
+void Literal::Resize<bfloat16>(int64 num_elements, bfloat16 value);
+
 template <>
 void Literal::Resize<complex64>(int64 num_elements, complex64 value);
 
@@ -990,6 +1005,14 @@ inline half Literal::Get<half>(
   return GetArraySlice<half>()[linear_index];
 }
 
+template <>
+inline bfloat16 Literal::Get<bfloat16>(
+    tensorflow::gtl::ArraySlice<int64> multi_index) const {
+  CHECK(shape().element_type() == BF16);
+  int64 linear_index = LinearIndex(multi_index);
+  return GetArraySlice<bfloat16>()[linear_index];
+}
+
 template <typename NativeT>
 void Literal::Set(tensorflow::gtl::ArraySlice<int64> multi_index,
                   NativeT value) {
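
A hypothetical usage sketch (not part of the patch) of the new BF16 literal support; it assumes the NativeToPrimitiveType<bfloat16> specialization and the accessors added in this change make the existing CreateR1 machinery work for bfloat16:

    #include "tensorflow/compiler/xla/literal_util.h"

    void Demo() {
      using xla::bfloat16;
      auto literal = xla::Literal::CreateR1<bfloat16>(
          {bfloat16(1.0f), bfloat16(2.0f)});
      // BF16 payloads round-trip through the proto via the new bf16s bytes
      // field, stored in little endian byte order like f16s.
      xla::LiteralProto proto = literal->ToProto();
    }
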
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
index 6d596da4ada..1e081017598 100644
--- a/tensorflow/compiler/xla/literal_util_test.cc
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
 
   auto c64_lit = Literal::CreateR0<complex64>({3.14f, 2.78f});
   ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString());
+
+  auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
+  ASSERT_EQ("0.5", bf16_lit->ToString());
+
+  // 3.14 will be rounded to 3.140625 in bfloat16 format (round to nearest
+  // even), matching the assertion below.
+  auto bf16_lit_truncated =
+      Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
+  ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
+
+  auto bf16_lit_truncated2 =
+      Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
+  ASSERT_EQ("9", bf16_lit_truncated2->ToString());
 }
 
 TEST_F(LiteralUtilTest, LiteralVectorToString) {
@@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) {
   EXPECT_FALSE(Literal::CreateR2<half>({{h8}, {h9}})->IsAll(8));
   EXPECT_FALSE(Literal::CreateR2<half>({{h9}, {h8}})->IsAll(8));
 
+  bfloat16 b8(8.0f);
+  bfloat16 b9(9.0f);
+
+  EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b8}, {b8}})->IsAll(8));
+  EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b8}, {b9}})->IsAll(8));
+  EXPECT_FALSE(Literal::CreateR2<bfloat16>({{b9}, {b8}})->IsAll(8));
+
+  // 9.001 will be rounded to 9.0.
+  bfloat16 b91(9.001f);
+  bfloat16 b90(9.00f);
+  EXPECT_TRUE(Literal::CreateR2<bfloat16>({{b91}, {b90}})->IsAll(9.0));
+
   complex64 c8_9 = {8, 9};
   EXPECT_FALSE(Literal::CreateR2<complex64>({{c8_9}, {c8_9}})->IsAll(8));
 
@@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) {
   EXPECT_EQ(output, *expected);
 }
 
+TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) {
+  Literal output;
+  bfloat16 h(0.25f);
+  output.PopulateWithValue<bfloat16>(h, {});
+  auto expected = Literal::CreateR0<bfloat16>(h);
+  EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) {
+  Literal output;
+  bfloat16 h(0.5f);
+  output.PopulateWithValue<bfloat16>(h, {3});
+  auto expected = Literal::CreateR1<bfloat16>({h, h, h});
+  EXPECT_EQ(output, *expected);
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) {
+  Literal output;
+  bfloat16 h(2.0f);
+  output.PopulateWithValue<bfloat16>(h, {2, 2});
+  auto expected = Literal::CreateR2<bfloat16>({{h, h}, {h, h}});
+  EXPECT_EQ(output, *expected);
+}
+
 TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
   Literal output;
   output.PopulateWithValue<float>(2.5f, {});
@@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
     {{half(26.0), half(0.0), half(28.0), half(0.0)},
      {half(0.0), half(31.0), half(0.0), half(33.0)}},
   }}, layout_r4_dim0major_);
+  auto bf16 = Literal::CreateR4WithLayout<bfloat16>({{
+    {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)},
+     {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}},
+    {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)},
+     {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}},
+    {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)},
+     {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}},
+  }}, layout_r4_dim0major_);
   auto f32 = Literal::CreateR4WithLayout<float>({{
     {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}},
     {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}},
@@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) {
   conv = s8->Convert(PRED).ConsumeValueOrDie();
   EXPECT_EQ(*conv, *pred);
 
+  conv = bf16->Convert(S32).ConsumeValueOrDie();
+  EXPECT_EQ(*conv, *s32);
+
+  conv = bf16->Convert(F32).ConsumeValueOrDie();
+  EXPECT_EQ(*conv, *f32);
+
   conv = pred->Convert(S32).ConsumeValueOrDie();
   EXPECT_EQ(*conv, *int32_pred);
 
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 2113b5e06f3..2bce56b7bd2 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType<double>() {
   return F64;
 }
 
+template <>
+PrimitiveType NativeToPrimitiveType<bfloat16>() {
+  return BF16;
+}
+
 template <>
 PrimitiveType NativeToPrimitiveType<half>() {
   return F16;
@@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType<complex64>() {
 }
 
 bool IsFloatingPointType(PrimitiveType type) {
-  return type == F16 || type == F32 || type == F64;
+  return type == F16 || type == F32 || type == F64 || type == BF16;
 }
 
 bool IsComplexType(PrimitiveType type) { return type == C64; }
@@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) {
     case S16:
     case U16:
     case F16:
+    case BF16:
       return 16;
 
     case U32:
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index a49c8b86fcf..19c6a138885 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -77,6 +77,8 @@ template <>
 PrimitiveType NativeToPrimitiveType<double>();
 template <>
 PrimitiveType NativeToPrimitiveType<half>();
+template <>
+PrimitiveType NativeToPrimitiveType<bfloat16>();
 
 // Complex
 template <>
@@ -167,6 +169,11 @@ struct PrimitiveTypeToNative<F16> {
   using type = half;
 };
 
+template <>
+struct PrimitiveTypeToNative<BF16> {
+  using type = bfloat16;
+};
+
 // Complex
 template <>
 struct PrimitiveTypeToNative<C64> {
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 9abe30e3f37..05f2d062784 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#define EIGEN_USE_THREADS
+
 #include "tensorflow/compiler/xla/service/backend.h"
 
 #include <algorithm>
 #include <string>
 #include <utility>
 
-#define EIGEN_USE_THREADS
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/service/compiler.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
index f8e260dd901..f385829cdf5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
+#define EIGEN_USE_THREADS
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
 
 #include <memory>
 #include <string>
 #include <tuple>
 
-#define EIGEN_USE_THREADS
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 88b77ccdd03..a722d1b3d99 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1450,6 +1450,10 @@ HloEvaluator::HloEvaluator() {
   typed_visitors_[F32] = MakeUnique<TypedVisitor<float>>(this);
   typed_visitors_[F64] = MakeUnique<TypedVisitor<double>>(this);
   typed_visitors_[C64] = MakeUnique<TypedVisitor<complex64>>(this);
+
+  typed_visitors_[BF16] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
+    return Unimplemented("HloEvaluator: unhandled primitive type: BF16.");
+  });
   typed_visitors_[TUPLE] = MakeUnique<FunctionVisitor>([](HloInstruction*) {
     return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE.");
   });
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index f463e57d995..158fb9a546c 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#define EIGEN_USE_THREADS
 
 #include "tensorflow/compiler/xla/service/hlo_runner.h"
 
@@ -19,8 +20,6 @@ limitations under the License.
 #include <string>
 #include <utility>
 
-#define EIGEN_USE_THREADS
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index b5eb81dfc6a..4d0bafa9087 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -263,6 +263,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
     case S32:
     case S64:
     case F16:
+    case BF16:
     case F32:
     case F64:
       return true;
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
index 95a52ecd2f5..75c9a0d3fb5 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.cc
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -116,16 +116,18 @@ template <typename FloatT, typename UnsignedT>
 ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
   auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
   auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+  auto lhs_double = static_cast<double>(lhs);
+  auto rhs_double = static_cast<double>(rhs);
   if (ulhs != urhs) {
     return ::testing::AssertionFailure() << tensorflow::strings::Printf(
                "floating values are not bitwise-equal; and equality testing "
                "was requested: %s=%g=%a vs %s=%g=%a",
                tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
                    .c_str(),
-               lhs, lhs,
+               lhs_double, lhs_double,
                tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
                    .c_str(),
-               rhs, rhs);
+               rhs_double, rhs_double);
   }
   return ::testing::AssertionSuccess();
 }
@@ -149,6 +151,10 @@ template <typename NativeT>
 // Specializations for floating types that do bitwise comparisons when equality
 // comparison is requested.
 template <>
+::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
+  return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+}
+template <>
 ::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
   return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
 }
@@ -238,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
     case U64:
       match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
       break;
+    case BF16:
+      match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
+      break;
     case F32:
       match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
       break;
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index c11e1df0a78..d98875dbc20 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#define EIGEN_USE_THREADS
 
 #include "tensorflow/compiler/xla/tests/local_client_test_base.h"
 
 #include <vector>
 
-#define EIGEN_USE_THREADS
-
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/map_util.h"
diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h
index 3b19ca321ca..9fa4297523b 100644
--- a/tensorflow/compiler/xla/types.h
+++ b/tensorflow/compiler/xla/types.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include <complex>
 
 #include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/platform/types.h"
 
 #include <Eigen/Core>
@@ -32,6 +33,8 @@ using ::tensorflow::int16;
 using ::tensorflow::int32;
 using ::tensorflow::int64;
 
+using ::tensorflow::bfloat16;
+
 using ::tensorflow::uint8;
 using ::tensorflow::uint16;
 using ::tensorflow::uint32;
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 71466047080..eac8f2ff07e 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -46,6 +46,12 @@ enum PrimitiveType {
   // converted to f16 from f32 at arbitrary points in the computation.
   F16 = 10;
   F32 = 11;
+
+  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
+  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
+  // and 7 bits for the mantissa.
+  BF16 = 16;
+
   F64 = 12;
 
   // Complex values of fixed width.
@@ -63,6 +69,8 @@ enum PrimitiveType {
   // An opaque type used for passing context specific data to a custom
   // operation.
   OPAQUE = 14;
+
+  // Next = 17
 }
 
 // Describes the value held inside padding elements.
@@ -310,7 +318,10 @@ message LiteralProto {
   repeated double f64s = 9;
   repeated float c64s = 12;  // Stored as interleaved real, imag floats.
   repeated LiteralProto tuple_literals = 10;
-  bytes f16s = 11;  // Note: the F16s are encoded in little endian byte order
+  // The F16s and BF16s are encoded in little endian byte order
+  bytes f16s = 11;
+  bytes bf16s = 13;
+  // Next = 14
 }
 
 message WindowDimension {
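
As a concrete check of the 1/8/7 layout described above (my own illustration, not part of the patch): BF16 is exactly the top 16 bits of the corresponding F32, so a value that needs no rounding converts by truncation:

    #include <cstdint>
    #include <cstring>

    // 1.0f has bits 0x3F800000; its top 16 bits, 0x3F80, are bfloat16 1.0:
    // sign 0, exponent 01111111 (bias 127), 7-bit mantissa 0.
    uint16_t TruncateToBF16(float v) {
      uint32_t bits;
      std::memcpy(&bits, &v, sizeof(bits));
      return static_cast<uint16_t>(bits >> 16);
    }
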
diff --git a/tensorflow/core/framework/bfloat16.cc b/tensorflow/core/framework/bfloat16.cc
index a5ac0e1a8df..1a6f355c774 100644
--- a/tensorflow/core/framework/bfloat16.cc
+++ b/tensorflow/core/framework/bfloat16.cc
@@ -18,32 +18,24 @@ limitations under the License.
 namespace tensorflow {
 
 void FloatToBFloat16(const float* src, bfloat16* dst, int64 size) {
-  const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
-  uint16_t* q = reinterpret_cast<uint16_t*>(dst);
-#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    for (; size != 0; p += 2, q++, size--) {  
-      *q = p[0];  
-    }  
-#else
-    for (; size != 0; p += 2, q++, size--) {  
-     *q = p[1];  
-    }  
-#endif
+  for (int64 i = 0; i < size; ++i) {
+    dst[i] = bfloat16(src[i]);
+  }
 }
 
 void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size) {
   const uint16_t* p = reinterpret_cast<const uint16_t*>(src);
   uint16_t* q = reinterpret_cast<uint16_t*>(dst);
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    for (; size != 0; p++, q += 2, size--) {  
-      q[0] = *p;  
-      q[1] = 0;  
+  for (; size != 0; p++, q += 2, size--) {
+    q[0] = *p;
+    q[1] = 0;
-    }
+  }
-#else  
-    for (; size != 0; p++, q += 2, size--) {  
-      q[0] = 0;  
-      q[1] = *p;  
-    } 
+#else
+  for (; size != 0; p++, q += 2, size--) {
+    q[0] = 0;
+    q[1] = *p;
+  }
 #endif
 }
 
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index af4e6a44116..a25b764ea21 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/core/framework/bfloat16.h"
 
+#include "tensorflow/core/lib/core/casts.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
 
@@ -27,6 +28,97 @@ TEST(Bfloat16Test, Simple) {
   EXPECT_EQ(0x4140, a.value);
 }
 
+float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
+                    uint32_t low_mantissa) {
+  return bit_cast<float>((sign << 31) + (exponent << 23) +
+                         (high_mantissa << 16) + low_mantissa);
+}
+
+struct Bfloat16TestParam {
+  float input;
+  float expected;
+};
+
+class Bfloat16Test : public ::testing::Test,
+                     public ::testing::WithParamInterface<Bfloat16TestParam> {};
+
+TEST_P(Bfloat16Test, RoundOrTruncate) {
+  bfloat16 a(GetParam().input);
+  if (std::isnan(GetParam().input)) {
+    EXPECT_TRUE(std::isnan(float(a)));
+    return;
+  }
+  EXPECT_EQ(GetParam().expected, float(a));
+}
+
+INSTANTIATE_TEST_CASE_P(
+    Bfloat16Test_Instantiation, Bfloat16Test,
+    ::testing::Values(
+        // More than half.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1111010111000011),
+            BinaryToFloat(0, 0b10000000, 0b1001001, 0b0000000000000000)},
+
+        Bfloat16TestParam{
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1111010111000011),
+            BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
+
+        // Exact half.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+        // NaN stays at NaN.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b11111111, 0b0000000, 0b0000000000000001),
+            BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
+
+        // NaN stays at NaN -- no exponents overflow.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b11111111, 0b1111111, 0b1111111111111111),
+            BinaryToFloat(0, 0b11111111, 0b1000000, 0b0000000000000000)},
+
+        // More than half, round to an odd number.
+        Bfloat16TestParam{
+            BinaryToFloat(1, 0b10000000, 0b1001000, 0b1100000000000000),
+            BinaryToFloat(1, 0b10000000, 0b1001001, 0b0000000000000000)},
+
+        // Less than half, truncate.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+        // Less than half, truncate.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0100000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+        // Exact at half, but result is already even.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b10000000, 0b1001000, 0b0000000000000000)},
+
+        // Denormal values.
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b00000000, 0b1001000, 0b1000000000000000),
+            BinaryToFloat(0, 0b00000000, 0b1001000, 0b0000000000000000)},
+        Bfloat16TestParam{
+            BinaryToFloat(0, 0b00000000, 0b1111111, 0b1100000000000000),
+            BinaryToFloat(0, 0b00000001, 0b0000000, 0b0000000000000000)}));
+
+TEST(Bfloat16Test, RoundWithFractionOverflow) {
+  // Still works with fraction overflow -- rounds to 4.0.
+  //
+  // Input 3.9960938:
+  // Sign |  Exp (8 bit) | Frac (first 7 bit) | Frac (last 16 bit)
+  //  0     1 0 0 0 0 0 0      1 1 1 1 1 1 1     1100000000000000
+  //
+  // Should round to 4.0:
+  // Sign |  Exp (8 bit)  | Frac (first 7 bit)
+  //  0     1 0 0 0 0 0 1      0 0 0 0 0 0 0
+  bfloat16 a(3.9960938f);
+  EXPECT_EQ(4.0, float(a));
+}
+
 TEST(Bfloat16Test, Conversion) {
   float a[100];
   for (int i = 0; i < 100; ++i) {
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index a630bee38d8..d005de2af1e 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -44,29 +44,262 @@ typedef Eigen::QUInt16 quint16;
 // see framework/bfloat16.h for description.
 struct bfloat16 {
   EIGEN_DEVICE_FUNC bfloat16() {}
-  EIGEN_DEVICE_FUNC explicit bfloat16(const float v) {
-    const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
+
+  explicit EIGEN_DEVICE_FUNC bfloat16(float v) {
+    uint32_t input;
+    memcpy(&input, &v, sizeof(uint32_t));
+
+    if ((~input & 0x7f800000) == 0 && (input & 0x007fffff) != 0) {
+      // If the value is a NaN, squash it to a qNaN with msb of fraction set,
+      // this makes sure after truncation we don't end up with an inf.
+      //
+      // qNaN magic: All exponent bits set + most significant bit of fraction
+      // set.
+      value = 0x7fc0;
+    } else {
+      // Fast rounding algorithm that rounds a half value to nearest even. This
+      // reduces expected error when we convert a large number of floats. Here
+      // is how it works:
+      //
+      // Definitions:
+      // To convert a float 32 to bfloat16, a float 32 can be viewed as 32 bits
+      // with the following tags:
+      //
+      // Sign |  Exp (8 bits) | Frac (23 bits)
+      //  S     EEEEEEEE         FFFFFFLRTTTTTTTTTTTTTTT
+      //
+      //  S: Sign bit.
+      //  E: Exponent bits.
+      //  F: First 6 bits of fraction.
+      //  L: Least significant bit of resulting bfloat16 if we truncate away the
+      //  rest of the float32. This is also the 7th bit of fraction
+      //  R: Rounding bit, 8th bit of fraction.
+      //  T: Sticky bits, rest of fraction, 15 bits.
+      //
+      // To round half to nearest even, there are 3 cases where we want to round
+      // down (simply truncate the result of the bits away, which consists of
+      // rounding bit and sticky bits) and two cases where we want to round up
+      // (truncate then add one to the result).
+      //
+      // The fast converting algorithm simply adds lsb (L) to 0x7fff (15 bits of
+      // 1s) as the rounding bias, adds the rounding bias to the input, then
+      // truncates the last 16 bits away.
+      //
+      // To understand how it works, we can analyze this algorithm case by case:
+      //
+      // 1. L = 0, R = 0:
+      //   Expect: round down, this is less than half value.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 0 = 0x7fff
+      //   - Adding rounding bias to input may create any carry, depending on
+      //   whether there is any value set to 1 in T bits.
+      //   - R may be set to 1 if there is a carry.
+      //   - L remains 0.
+      //   - Note that this case also handles Inf and -Inf, where all fraction
+      //   bits, including L, R and Ts are all 0. The output remains Inf after
+      //   this algorithm.
+      //
+      // 2. L = 1, R = 0:
+      //   Expect: round down, this is less than half value.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 1 = 0x8000
+      //   - Adding rounding bias to input doesn't change sticky bits but
+      //   adds 1 to rounding bit.
+      //   - L remains 1.
+      //
+      // 3. L = 0, R = 1, all of T are 0:
+      //   Expect: round down, this is exactly at half, the result is already
+      //   even (L=0).
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 0 = 0x7fff
+      //   - Adding rounding bias to input sets all sticky bits to 1, but
+      //   doesn't create a carry.
+      //   - R remains 1.
+      //   - L remains 0.
+      //
+      // 4. L = 1, R = 1:
+      //   Expect: round up, this is exactly at half, the result needs to be
+      //   round to the next even number.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 1 = 0x8000
+      //   - Adding rounding bias to input doesn't change sticky bits, but
+      //   creates a carry from rounding bit.
+      //   - The carry sets L to 0, creates another carry bit and propagates
+      //   forward to the F bits.
+      //   - If all the F bits are 1, a carry then propagates to the exponent
+      //   bits, which then creates the minimum value with the next exponent
+      //   value. Note that we won't have the case where exponents are all 1,
+      //   since that's either a NaN (handled in the other if condition) or inf
+      //   (handled in case 1).
+      //
+      // 5. L = 0, R = 1, any of T is 1:
+      //   Expect: round up, this is greater than half.
+      //
+      //   Algorithm:
+      //   - Rounding bias: 0x7fff + 0 = 0x7fff
+      //   - Adding rounding bias to input creates a carry from sticky bits,
+      //   sets rounding bit to 0, then creates another carry.
+      //   - The second carry sets L to 1.
+      //
+      // Examples:
+      //
+      //  Exact half value that is already even:
+      //    Input:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
+      //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0     1000000000000000
+      //
+      //     This falls into case 3. We truncate the rest of 16 bits and no
+      //     carry is created into F and L:
+      //
+      //    Output:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
+      //     S     E E E E E E E E      F F F F F F L
+      //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
+      //
+      //  Exact half value, round to next even number:
+      //    Input:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
+      //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 0 1     1000000000000000
+      //
+      //     This falls into case 4. We create a carry from R and T,
+      //     which then propagates into L and F:
+      //
+      //    Output:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
+      //     S     E E E E E E E E      F F F F F F L
+      //     0     0 0 0 0 0 0 0 0      0 0 0 0 0 1 0
+      //
+      //
+      //  Max denormal value round to min normal value:
+      //    Input:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
+      //     0     0 0 0 0 0 0 0 0      1 1 1 1 1 1 1     1111111111111111
+      //
+      //     This falls into case 4. We create a carry from R and T,
+      //     propagate into L and F, which then propagates into exponent
+      //     bits:
+      //
+      //    Output:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
+      //     S     E E E E E E E E      F F F F F F L
+      //     0     0 0 0 0 0 0 0 1      0 0 0 0 0 0 0
+      //
+      //  Max normal value round to Inf:
+      //    Input:
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit) | Frac (last 16 bit)
+      //     S     E E E E E E E E      F F F F F F L     RTTTTTTTTTTTTTTT
+      //     0     1 1 1 1 1 1 1 0      1 1 1 1 1 1 1     1111111111111111
+      //
+      //     This falls into case 4. We create a carry from R and T,
+      //     propagate into L and F, which then propagates into exponent
+      //     bits:
+      //
+      //    Sign |  Exp (8 bit)     | Frac (first 7 bit)
+      //     S     E E E E E E E E      F F F F F F L
+      //     0     1 1 1 1 1 1 1 1      0 0 0 0 0 0 0
+      //
+      //
+      // Least significant bit of resulting bfloat.
+      uint32_t lsb = (input >> 16) & 1;
+      uint32_t rounding_bias = 0x7fff + lsb;
+      input += rounding_bias;
+      value = static_cast<uint16_t>(input >> 16);
+    }
+  }
+
+  template <class T>
+  explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
+      : bfloat16(static_cast<float>(val)) {}
+
+  EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+    float result;
+
+    uint16_t* q = reinterpret_cast<uint16_t*>(&result);
+
 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
-    value = p[0];
+    q[0] = value;
+    q[1] = 0;
 #else
-    value = p[1];
+    q[0] = 0;
+    q[1] = value;
 #endif
+    return result;
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator bool() const {
+    return static_cast<bool>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator Eigen::half() const {
+    return static_cast<Eigen::half>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator short() const {
+    return static_cast<short>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator int() const {
+    return static_cast<int>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator char() const {
+    return static_cast<char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator signed char() const {
+    return static_cast<signed char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned char() const {
+    return static_cast<unsigned char>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned int() const {
+    return static_cast<unsigned int>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned long() const {
+    return static_cast<unsigned long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator unsigned long long() const {
+    return static_cast<unsigned long long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator long long() const {
+    return static_cast<long long>(float(*this));
+  }
+
+  EIGEN_DEVICE_FUNC explicit operator double() const {
+    return static_cast<double>(float(*this));
   }
 
   uint16_t value;
 };
 
+inline bool operator==(const bfloat16 a, const bfloat16 b) {
+  return a.value == b.value;
+}
+
+inline bool operator!=(const bfloat16 a, const bfloat16 b) {
+  return a.value != b.value;
+}
+
 }  // end namespace tensorflow
 
 namespace Eigen {
 template <>
 struct NumTraits<tensorflow::bfloat16> : GenericNumTraits<uint16_t> {};
 
-EIGEN_STRONG_INLINE bool operator==(const tensorflow::bfloat16 a,
-                                    const tensorflow::bfloat16 b) {
-  return a.value == b.value;
-}
-
+using ::tensorflow::operator==;
+using ::tensorflow::operator!=;
 }  // namespace Eigen
 
 #ifdef COMPILER_MSVC
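
A quick way to sanity-check the rounding rules documented in the constructor
above is to exercise the bias trick on raw bits. A minimal standalone sketch
(the helper name is hypothetical, and the NaN handling performed by the real
constructor is omitted):

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    // Mirrors the constructor's default path: round-to-nearest-even on the
    // upper 16 bits of an IEEE-754 float.
    static uint16_t FloatToBFloat16Bits(float f) {
      uint32_t input;
      std::memcpy(&input, &f, sizeof(input));  // safe bit-cast
      uint32_t lsb = (input >> 16) & 1;        // the L bit of the result
      input += 0x7fff + lsb;                   // rounding bias; may carry into L
      return static_cast<uint16_t>(input >> 16);
    }

    int main() {
      // 1.0f (bits 0x3f800000) is exactly representable: no rounding occurs.
      assert(FloatToBFloat16Bits(1.0f) == 0x3f80);
      // 0x3f808000 lies exactly halfway between 0x3f80 and 0x3f81;
      // round-to-nearest-even keeps L = 0, matching case 3 above.
      float halfway;
      const uint32_t kHalfwayBits = 0x3f808000;
      std::memcpy(&halfway, &kHalfwayBits, sizeof(halfway));
      assert(FloatToBFloat16Bits(halfway) == 0x3f80);
      return 0;
    }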

From 3c41cb6bff409f37e35d2e1b2619d5dc6742dbe5 Mon Sep 17 00:00:00 2001
From: Saurabh Saxena <srbs@google.com>
Date: Thu, 9 Nov 2017 21:01:00 -0800
Subject: [PATCH 110/115] Get rid of IteratorBase::is_exhausted flag since it
 is not possible to rely on it unless we lock each call to GetNext, which is
 not preferable. Each iterator now handles saving/restoring its exhausted
 state. As a guideline, we always reset the input_impl(s) once they are
 exhausted; for non-terminal iterators this serves as the indicator of
 exhaustion and also reduces memory overhead. Each iterator should also
 handle calls to GetNextInternal when it is exhausted; this is fixed here for
 some datasets. Also fixes a bug in dataset_serialization_test_base: we were
 not saving a checkpoint after exhausting the iterator, so
 verify_exhausted_iterator was not really testing restoring an exhausted
 iterator.

PiperOrigin-RevId: 175253023
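
The shared save/restore convention, condensed into a single sketch here
instead of being repeated for every kernel (names are taken from the diffs
below; a null input_impl_ now encodes "exhausted"):

    Status SaveInternal(IteratorStateWriter* writer) override {
      mutex_lock l(mu_);
      if (!input_impl_) {
        // A null input_impl_ encodes "exhausted".
        TF_RETURN_IF_ERROR(
            writer->WriteScalar(full_name("input_impl_empty"), ""));
      } else {
        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
      }
      return Status::OK();
    }

    Status RestoreInternal(OpKernelContext* ctx,
                           IteratorStateReader* reader) override {
      mutex_lock l(mu_);
      if (reader->Contains(full_name("input_impl_empty"))) {
        input_impl_.reset();  // Restore into the exhausted state.
      } else {
        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
      }
      return Status::OK();
    }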
---
 .../dataset_serialization_test_base.py        |  4 +--
 tensorflow/core/kernels/batch_dataset_op.cc   | 21 +++++++++++--
 .../core/kernels/concatenate_dataset_op.cc    |  8 ++++-
 tensorflow/core/kernels/dataset.cc            |  1 -
 tensorflow/core/kernels/dataset.h             | 23 ++------------
 tensorflow/core/kernels/range_dataset_op.cc   |  1 -
 tensorflow/core/kernels/reader_dataset_ops.cc |  1 -
 tensorflow/core/kernels/repeat_dataset_op.cc  | 18 +++++++++--
 tensorflow/core/kernels/shuffle_dataset_op.cc | 31 ++++++++++---------
 tensorflow/core/kernels/skip_dataset_op.cc    | 21 +++++++++++--
 tensorflow/core/kernels/take_dataset_op.cc    | 17 ++++++++--
 tensorflow/core/kernels/zip_dataset_op.cc     | 31 +++++++++++++++----
 12 files changed, 120 insertions(+), 57 deletions(-)

diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index 369b789a521..07fecf04fae 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -337,11 +337,11 @@ class DatasetSerializationTestBase(test.TestCase):
           num_iters = end - start
           for _ in range(num_iters):
             outputs.append(sess.run(get_next_op))
-          self._save(sess, saver)
-          ckpt_saved = True
           if i == len(break_points) and verify_exhausted:
             with self.assertRaises(errors.OutOfRangeError):
               sess.run(get_next_op)
+          self._save(sess, saver)
+          ckpt_saved = True
 
     return outputs
 
diff --git a/tensorflow/core/kernels/batch_dataset_op.cc b/tensorflow/core/kernels/batch_dataset_op.cc
index 2e52ad39f8e..6a5fd17a9e6 100644
--- a/tensorflow/core/kernels/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/batch_dataset_op.cc
@@ -143,9 +143,13 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
         // Each row of `batch_elements` is a tuple of tensors from the
         // input iterator.
         std::vector<std::vector<Tensor>> batch_elements;
-        batch_elements.reserve(dataset()->batch_size_);
         {
           mutex_lock l(mu_);
+          if (!input_impl_) {
+            *end_of_sequence = true;
+            return Status::OK();
+          }
+          batch_elements.reserve(dataset()->batch_size_);
           *end_of_sequence = false;
           for (int i = 0; i < dataset()->batch_size_ && !*end_of_sequence;
                ++i) {
@@ -154,6 +158,8 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
                                                     end_of_sequence));
             if (!*end_of_sequence) {
               batch_elements.emplace_back(std::move(batch_element_tuple));
+            } else {
+              input_impl_.reset();
             }
           }
         }
@@ -194,14 +200,23 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
      protected:
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        if (!input_impl_) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impl_empty"), ""));
+        } else {
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        }
         return Status::OK();
       }
 
       Status RestoreInternal(OpKernelContext* ctx,
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        if (!reader->Contains(full_name("input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        } else {
+          input_impl_.reset();
+        }
         return Status::OK();
       }
 
diff --git a/tensorflow/core/kernels/concatenate_dataset_op.cc b/tensorflow/core/kernels/concatenate_dataset_op.cc
index 711c234129f..c3bd89c479f 100644
--- a/tensorflow/core/kernels/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/concatenate_dataset_op.cc
@@ -104,6 +104,10 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
         while (i_ < 2) {
           TF_RETURN_IF_ERROR(
               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -140,7 +144,9 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
         } else if (i_ == 2) {
           input_impl_.reset();
         }
-        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        if (input_impl_) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        }
         return Status::OK();
       }
 
diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc
index 0414875a5d5..fcfa2956f78 100644
--- a/tensorflow/core/kernels/dataset.cc
+++ b/tensorflow/core/kernels/dataset.cc
@@ -126,7 +126,6 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
   MakeDataset(ctx, input, another_input, output);
 }
 
-const char IteratorBase::kIteratorExhausted[] = "ITERATOR_EXHAUSTED";
 const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
 const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
     "_DATASET_GRAPH_OUTPUT_NODE";
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index 4a42ac80c37..aa4f436b390 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -306,27 +306,14 @@ class IteratorBase {
 
   // Saves the state of this iterator.
   virtual Status Save(IteratorStateWriter* writer) {
-    if (is_exhausted_) {
-      LOG(INFO) << "Iterator exhausted.";
-      return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
-    } else {
-      return SaveInternal(writer);
-    }
+    return SaveInternal(writer);
   }
 
   // Restores the state of this iterator.
   virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
-    if (reader->Contains(kIteratorExhausted)) {
-      LOG(INFO) << "Iterator exhausted. Nothing to restore.";
-      is_exhausted_ = true;
-      return Status::OK();
-    } else {
-      return RestoreInternal(ctx, reader);
-    }
+    return RestoreInternal(ctx, reader);
   }
 
-  static const char kIteratorExhausted[];
-
  protected:
   // This is needed so that sub-classes of IteratorBase can call
   // `SaveInternal` on their parent iterators, e.g., in
@@ -354,8 +341,6 @@ class IteratorBase {
                                  IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
   }
-
-  bool is_exhausted_ = false;  // Whether the iterator has been exhausted.
 };
 
 // Represents a (potentially infinite) range of outputs, where each
@@ -491,10 +476,6 @@ class DatasetIterator : public IteratorBase {
   Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
                  bool* end_of_sequence) final {
     port::Tracing::TraceMe activity(params_.prefix);
-    if (is_exhausted_) {
-      *end_of_sequence = true;
-      return Status::OK();
-    }
     return GetNextInternal(ctx, out_tensors, end_of_sequence);
   }
 
diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc
index 7adfcc4f8d2..e7ae840fc7d 100644
--- a/tensorflow/core/kernels/range_dataset_op.cc
+++ b/tensorflow/core/kernels/range_dataset_op.cc
@@ -99,7 +99,6 @@ class RangeDatasetOp : public DatasetOpKernel {
         if ((dataset()->step_ > 0 && next_ >= dataset()->stop_) ||
             (dataset()->step_ < 0 && next_ <= dataset()->stop_)) {
           *end_of_sequence = true;
-          is_exhausted_ = true;
           return Status::OK();
         }
         Tensor value_tensor(cpu_allocator(), DT_INT64, {});
diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc
index 39ef92a5dec..c08e42be1d9 100644
--- a/tensorflow/core/kernels/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/reader_dataset_ops.cc
@@ -402,7 +402,6 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
           // Iteration ends when there are no more files to process.
           if (current_file_index_ == dataset()->filenames_.size()) {
             *end_of_sequence = true;
-            is_exhausted_ = true;
             return Status::OK();
           }
 
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 6c0f4118e6d..0167b9ea64b 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -117,6 +117,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
         while (i_ < dataset()->count_) {
           TF_RETURN_IF_ERROR(
               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -127,7 +131,6 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
           input_impl_ = dataset()->input_->MakeIterator(prefix());
         }
         *end_of_sequence = true;
-        is_exhausted_ = true;
         input_impl_.reset();
         return Status::OK();
       }
@@ -136,7 +139,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        if (!input_impl_) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impl_empty"), ""));
+        } else {
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        }
         return Status::OK();
       }
 
@@ -144,7 +152,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        if (!reader->Contains(full_name("input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        } else {
+          input_impl_.reset();
+        }
         return Status::OK();
       }
 
diff --git a/tensorflow/core/kernels/shuffle_dataset_op.cc b/tensorflow/core/kernels/shuffle_dataset_op.cc
index 2146ba2aa17..dd0ab57e9dc 100644
--- a/tensorflow/core/kernels/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/shuffle_dataset_op.cc
@@ -105,8 +105,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
         mutex_lock l(mu_);
         int64 start_micros = ctx->env()->NowMicros();
         int64 num_log_entries = 0;
-        while (!end_of_input_sequence_ &&
-               buffer_.size() < dataset()->buffer_size_) {
+        while (input_impl_ && buffer_.size() < dataset()->buffer_size_) {
           if (ctx->env()->NowMicros() >
               ((num_log_entries + 1) * kLogIntervalMicros) + start_micros) {
             num_log_entries++;
@@ -114,9 +113,10 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
                       << buffer_.size() << " of " << dataset()->buffer_size_;
           }
           std::vector<Tensor> input_element;
+          bool end_of_input_sequence;
           TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &input_element,
-                                                  &end_of_input_sequence_));
-          if (!end_of_input_sequence_) {
+                                                  &end_of_input_sequence));
+          if (!end_of_input_sequence) {
             buffer_.emplace_back(std::move(input_element));
           } else {
             input_impl_.reset();
@@ -135,7 +135,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
           std::swap(buffer_[index], buffer_.back());
           buffer_.pop_back();
         } else {
-          DCHECK(end_of_input_sequence_);
+          DCHECK(input_impl_ == nullptr);
           *end_of_sequence = true;
         }
         return Status::OK();
@@ -148,11 +148,11 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
         // Save the tensors in the buffer.
         TF_RETURN_IF_ERROR(
             writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
-        for (int i = 0; i < buffer_.size(); i++) {
+        for (size_t i = 0; i < buffer_.size(); i++) {
           TF_RETURN_IF_ERROR(writer->WriteScalar(
               full_name(strings::StrCat("buffer_", i, "_size")),
               buffer_[i].size()));
-          for (int j = 0; j < buffer_[i].size(); j++) {
+          for (size_t j = 0; j < buffer_[i].size(); j++) {
             TF_RETURN_IF_ERROR(writer->WriteTensor(
                 full_name(strings::StrCat("buffer_", i, "_", j)),
                 buffer_[i][j]));
@@ -165,7 +165,7 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
 
         // Save input iterator if it hasn't been exhausted else write
         // "end_of_input_sequence".
-        if (end_of_input_sequence_) {
+        if (!input_impl_) {
           TF_RETURN_IF_ERROR(
               writer->WriteScalar(full_name("end_of_input_sequence"), ""));
         } else {
@@ -180,10 +180,15 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
         buffer_.clear();
 
         // Restore the buffer.
-        int64 buffer_size;
-        TF_RETURN_IF_ERROR(
-            reader->ReadScalar(full_name("buffer_size"), &buffer_size));
-        for (int i = 0; i < buffer_size; i++) {
+        size_t buffer_size;
+        {
+          int64 temp;
+          TF_RETURN_IF_ERROR(
+              reader->ReadScalar(full_name("buffer_size"), &temp));
+          buffer_size = static_cast<size_t>(temp);
+        }
+        buffer_.reserve(buffer_size);
+        for (size_t i = 0; i < buffer_size; i++) {
           int64 list_size;
           TF_RETURN_IF_ERROR(reader->ReadScalar(
               full_name(strings::StrCat("buffer_", i, "_size")), &list_size));
@@ -205,7 +210,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
           input_impl_ = dataset()->input_->MakeIterator(prefix());
           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
         } else {
-          end_of_input_sequence_ = true;
           input_impl_.reset();
         }
         return Status::OK();
@@ -230,7 +234,6 @@ class ShuffleDatasetOp : public UnaryDatasetOpKernel {
       mutex mu_;
       std::vector<std::vector<Tensor>> buffer_ GUARDED_BY(mu_);
       std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
-      bool end_of_input_sequence_ GUARDED_BY(mu_) = false;
       const int64 seed_ GUARDED_BY(mu_);
       const int64 seed2_ GUARDED_BY(mu_);
       random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/skip_dataset_op.cc b/tensorflow/core/kernels/skip_dataset_op.cc
index 05152db1ae2..7ee945dd4c4 100644
--- a/tensorflow/core/kernels/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/skip_dataset_op.cc
@@ -118,6 +118,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
 
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+
         // Keep calling GetNext().  TODO(vrv): Figure out a way to
         // skip records without reading, perhaps by adding an
         // interface to iterator.
@@ -138,6 +143,9 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
         // Return GetNext() on the underlying iterator.
         TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, out_tensors,
                                                 end_of_sequence));
+        if (*end_of_sequence) {
+          input_impl_.reset();
+        }
         return Status::OK();
       }
 
@@ -145,7 +153,12 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        if (input_impl_) {
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        } else {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impl_empty"), ""));
+        }
         return Status::OK();
       }
 
@@ -153,7 +166,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        if (!reader->Contains(full_name("input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        } else {
+          input_impl_.reset();
+        }
         return Status::OK();
       }
 
diff --git a/tensorflow/core/kernels/take_dataset_op.cc b/tensorflow/core/kernels/take_dataset_op.cc
index f9f675abdae..fb294a96b15 100644
--- a/tensorflow/core/kernels/take_dataset_op.cc
+++ b/tensorflow/core/kernels/take_dataset_op.cc
@@ -118,6 +118,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);  // TODO(mrry): Make locking less conservative.
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
         while (i_ < dataset()->count_) {
           TF_RETURN_IF_ERROR(
               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -136,7 +140,12 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
-        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        if (input_impl_) {
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        } else {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impl_empty"), ""));
+        }
         return Status::OK();
       }
 
@@ -144,7 +153,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
         TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
-        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        if (!reader->Contains(full_name("input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        } else {
+          input_impl_.reset();
+        }
         return Status::OK();
       }
 
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index 30d64ea6343..f466c8b268d 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -109,6 +109,10 @@ class ZipDatasetOp : public DatasetOpKernel {
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
         mutex_lock l(mu_);
+        if (input_impls_.empty()) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
         out_tensors->clear();
         out_tensors->reserve(dataset()->output_dtypes().size());
         for (const auto& input_impl : input_impls_) {
@@ -116,28 +120,43 @@ class ZipDatasetOp : public DatasetOpKernel {
           TF_RETURN_IF_ERROR(
               input_impl->GetNext(ctx, &input_tensors, end_of_sequence));
           if (*end_of_sequence) {
-            return Status::OK();
+            break;
           }
           out_tensors->insert(out_tensors->end(), input_tensors.begin(),
                               input_tensors.end());
         }
-        *end_of_sequence = false;
+        if (*end_of_sequence) {
+          out_tensors->clear();
+          input_impls_.clear();
+        } else {
+          *end_of_sequence = false;
+        }
         return Status::OK();
       }
 
      protected:
       Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        for (auto& input_impl : input_impls_)
-          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
+        if (input_impls_.empty()) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impls_empty"), ""));
+        } else {
+          for (auto& input_impl : input_impls_)
+            TF_RETURN_IF_ERROR(SaveParent(writer, input_impl));
+        }
         return Status::OK();
       }
 
       Status RestoreInternal(OpKernelContext* ctx,
                              IteratorStateReader* reader) override {
         mutex_lock l(mu_);
-        for (auto& input_impl : input_impls_)
-          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
+        if (reader->Contains(full_name("input_impls_empty"))) {
+          input_impls_.clear();
+        } else {
+          DCHECK_EQ(input_impls_.size(), dataset()->inputs_.size());
+          for (auto& input_impl : input_impls_)
+            TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl));
+        }
         return Status::OK();
       }
 

From 23dc70389b3bf51886156de88fae6b922619a6ff Mon Sep 17 00:00:00 2001
From: Sanjoy Das <sanjoy@google.com>
Date: Thu, 9 Nov 2017 22:39:33 -0800
Subject: [PATCH 111/115] [XLA:CPU] Make one of the tile dimensions in the LLVM
 IR GEMV tunable.

The tiling dimension corresponding to the number of vector registers in the
tile can be changed easily.  Expose this value as a backend-specific flag so
that we can experiment with it to find a good default value.

This CL also fixes a bug exposed by a variable tiling factor in the row-major
GEMV implementation.  This wasn't caught before because having tile_rows ==
tile_cols hid the bug.

PiperOrigin-RevId: 175258553
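
For reference, the knob travels through the backend extra-options string map.
A minimal sketch of setting and reading it (the mutable map accessor is the
standard generated proto API and is assumed here; it is not part of this
change):

    // Client side: request a tiling factor of 16.
    xla::DebugOptions debug_options;
    (*debug_options.mutable_xla_backend_extra_options())
        ["xla_llvm_dot_tiling_factor"] = "16";

    // Backend side (mirrors cpu_options.cc below): parse, else keep default.
    int64 tiling_factor = 8;  // kDefaultTilingFactor
    const auto& opts = config.debug_options().xla_backend_extra_options();
    auto it = opts.find("xla_llvm_dot_tiling_factor");
    if (it != opts.end()) {
      tensorflow::strings::safe_strto64(it->second, &tiling_factor);
    }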
---
 tensorflow/compiler/xla/service/cpu/BUILD        |  2 ++
 .../compiler/xla/service/cpu/cpu_options.cc      | 16 ++++++++++++++++
 .../compiler/xla/service/cpu/cpu_options.h       |  2 ++
 .../compiler/xla/service/cpu/dot_op_emitter.cc   | 11 ++++++++---
 .../compiler/xla/service/cpu/dot_op_emitter.h    |  9 +++++++++
 5 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 10ec677e2f2..4f6e69ebd4e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -280,6 +280,7 @@ cc_library(
     srcs = ["dot_op_emitter.cc"],
     hdrs = ["dot_op_emitter.h"],
     deps = [
+        ":cpu_options",
         ":cpu_runtime",
         ":ir_emission_utils",
         "//tensorflow/compiler/xla:shape_util",
@@ -719,6 +720,7 @@ cc_library(
     hdrs = ["cpu_options.h"],
     deps = [
         "//tensorflow/compiler/xla/service:hlo_module_config",
+        "//tensorflow/core:lib",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index dba140d1120..09f028463af 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -15,11 +15,14 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
 
+#include "tensorflow/core/lib/strings/numbers.h"
+
 namespace {
 
 const char* const kXlaParallelCpuOption = "xla_cpu_parallel";
 const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
 const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
+const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
 
 }  // namespace
 
@@ -45,6 +48,19 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) {
   return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0;
 }
 
+tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
+    const HloModuleConfig& config) {
+  const auto& extra_options_map =
+      config.debug_options().xla_backend_extra_options();
+  auto it = extra_options_map.find(kLlvmIrDotTilingFactor);
+  int64 tiling_factor;
+  if (it != extra_options_map.end() &&
+      tensorflow::strings::safe_strto64(it->second, &tiling_factor)) {
+    return tiling_factor;
+  }
+  return tensorflow::gtl::nullopt;
+}
+
 }  // namespace options
 }  // namespace cpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 5dc24ebc7b8..6ba0fd24538 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -27,6 +27,8 @@ namespace options {
 bool CpuParallelBackendRequested(const HloModuleConfig& config);
 bool OptimizeForSizeRequested(const HloModuleConfig& config);
 bool VectorizedReduceDisabled(const HloModuleConfig& config);
+tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
+    const HloModuleConfig& config);
 
 }  // namespace options
 }  // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 1cbd4094a35..2a447a54b01 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -366,7 +366,7 @@ class RowMajorMatrixVectorProductEmitter {
         result_(result),
         ir_builder_(ir_builder),
         ksl_(ir_builder_),
-        vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") {
+        vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") {
     CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast<uint64>(tile_cols_)));
   }
 
@@ -573,11 +573,15 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
     return false;
   }
 
+  int64 tiling_factor = GetGemvTilingFactor();
+  CHECK_GT(tiling_factor, 0);
+
   if (is_column_major_matrix_vector) {
     VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
             << " and k = " << k;
     ColumnMajorMatrixVectorProductEmitter emitter(
-        dot_.shape().element_type(), 8, 8, m, k,
+        dot_.shape().element_type(), /*tile_rows=*/8,
+        /*tile_cols=*/tiling_factor, m, k,
         swap_operands ? rhs_array_.GetBasePointer()
                       : lhs_array_.GetBasePointer(),
         swap_operands ? lhs_array_.GetBasePointer()
@@ -588,7 +592,8 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
     VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
             << " and k = " << k;
     RowMajorMatrixVectorProductEmitter emitter(
-        dot_.shape().element_type(), 8, 8, m, k,
+        dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
+        /*tile_cols=*/8, m, k,
         swap_operands ? rhs_array_.GetBasePointer()
                       : lhs_array_.GetBasePointer(),
         swap_operands ? lhs_array_.GetBasePointer()
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 182e1b8c680..470bf6ffb4c 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
 
 #include "llvm/IR/IRBuilder.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
@@ -105,6 +106,14 @@ class DotOpEmitter {
   // of rank 2 as well).
   MatMultDims GetMatMultDims() const;
 
+  // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector
+  // registers.
+  int64 GetGemvTilingFactor() const {
+    const int64 kDefaultTilingFactor = 8;
+    return options::LlvmIrGemvTilingFactor(hlo_module_config_)
+        .value_or(kDefaultTilingFactor);
+  }
+
   const HloInstruction& dot_;
   const bool transpose_lhs_;
   const bool transpose_rhs_;

From f6931a687874190bb6f5cbc927da2bdc97a18b38 Mon Sep 17 00:00:00 2001
From: Asim Shankar <ashankar@google.com>
Date: Thu, 9 Nov 2017 23:29:56 -0800
Subject: [PATCH 112/115] Java/OS X: Workaround for how the framework library
 is packaged in the .jar by the release process.

See #13872

PiperOrigin-RevId: 175261983
---
 .../java/org/tensorflow/NativeLibrary.java    | 43 ++++++++++++++-----
 1 file changed, 32 insertions(+), 11 deletions(-)

diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
index 2b431eebf5f..499757e8cf4 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
@@ -43,7 +43,6 @@ final class NativeLibrary {
   private static final boolean DEBUG =
       System.getProperty("org.tensorflow.NativeLibrary.DEBUG") != null;
   private static final String JNI_LIBNAME = "tensorflow_jni";
-  private static final String FRAMEWORK_LIBNAME = "tensorflow_framework";
 
   public static void load() {
     if (isLoaded() || tryLoadLibrary()) {
@@ -59,12 +58,15 @@ final class NativeLibrary {
     }
     // Native code is not present, perhaps it has been packaged into the .jar file containing this.
     // Extract the JNI library itself
-    final String jniResourceName = makeResourceName(JNI_LIBNAME);
+    final String jniLibName = System.mapLibraryName(JNI_LIBNAME);
+    final String jniResourceName = makeResourceName(jniLibName);
     log("jniResourceName: " + jniResourceName);
     final InputStream jniResource =
         NativeLibrary.class.getClassLoader().getResourceAsStream(jniResourceName);
     // Extract the JNI's dependency
-    final String frameworkResourceName = makeResourceName(FRAMEWORK_LIBNAME);
+    final String frameworkLibName =
+        maybeAdjustForMacOS(System.mapLibraryName("tensorflow_framework"));
+    final String frameworkResourceName = makeResourceName(frameworkLibName);
     log("frameworkResourceName: " + frameworkResourceName);
     final InputStream frameworkResource =
         NativeLibrary.class.getClassLoader().getResourceAsStream(frameworkResourceName);
@@ -88,12 +90,15 @@ final class NativeLibrary {
       tempPath.deleteOnExit();
       final String tempDirectory = tempPath.toString();
       if (frameworkResource != null) {
-        extractResource(frameworkResource, FRAMEWORK_LIBNAME, tempDirectory);
+        extractResource(frameworkResource, frameworkLibName, tempDirectory);
       } else {
-        log(frameworkResourceName + " not found. This is fine assuming " + jniResourceName
-            + " is not built to depend on it.");
+        log(
+            frameworkResourceName
+                + " not found. This is fine assuming "
+                + jniResourceName
+                + " is not built to depend on it.");
       }
-      System.load(extractResource(jniResource, JNI_LIBNAME, tempDirectory));
+      System.load(extractResource(jniResource, jniLibName, tempDirectory));
     } catch (IOException e) {
       throw new UnsatisfiedLinkError(
           String.format(
@@ -121,9 +126,27 @@ final class NativeLibrary {
     }
   }
 
+  private static String maybeAdjustForMacOS(String libFilename) {
+    if (!System.getProperty("os.name").contains("OS X")) {
+      return libFilename;
+    }
+    // This is macOS, and the TensorFlow release process might have set up dependencies on
+    // libtensorflow_framework.so instead of libtensorflow_framework.dylib. Adjust for that.
+    final ClassLoader cl = NativeLibrary.class.getClassLoader();
+    if (cl.getResource(makeResourceName(libFilename)) != null) {
+      return libFilename;
+    }
+    // libtensorflow_framework.dylib not found; try libtensorflow_framework.so
+    final String suffix = ".dylib";
+    if (!libFilename.endsWith(suffix)) {
+      return libFilename;
+    }
+    return libFilename.substring(0, libFilename.length() - suffix.length()) + ".so";
+  }
+
   private static String extractResource(
       InputStream resource, String resourceName, String extractToDirectory) throws IOException {
-    final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName));
+    final File dst = new File(extractToDirectory, resourceName);
     dst.deleteOnExit();
     final String dstPath = dst.toString();
     log("extracting native library to: " + dstPath);
@@ -157,9 +180,7 @@ final class NativeLibrary {
   }
 
   private static String makeResourceName(String baseName) {
-    return "org/tensorflow/native/"
-        + String.format("%s-%s/", os(), architecture())
-        + System.mapLibraryName(baseName);
+    return "org/tensorflow/native/" + String.format("%s-%s/", os(), architecture()) + baseName;
   }
 
   private static long copy(InputStream src, File dstFile) throws IOException {

From 8d46b72fdcf675245addb006aadcf358ddf7dd7d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 10 Nov 2017 02:48:04 -0800
Subject: [PATCH 113/115] Correct comment in K-FAC's layer_collection

PiperOrigin-RevId: 175275184
---
 tensorflow/contrib/kfac/python/ops/layer_collection.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 4eabb59b3e4..7300a7998c2 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -448,10 +448,10 @@ class LayerCollection(object):
         tf.get_variable_scope().reuse.
 
     Raises:
-      ValueError: If reuse=True and name != None.
-      ValueError: If reuse=True and seed != None.
-      KeyError: If reuse=True and no existing LossFunction with 'name' found.
-      KeyError: If reuse=False and existing LossFunction with 'name' found.
+      ValueError: If reuse == True and name == None.
+      ValueError: If reuse == True and seed != None.
+      KeyError: If reuse == True and no existing LossFunction with 'name' found.
+      KeyError: If reuse == False and existing LossFunction with 'name' found.
     """
     name = name or self._graph.unique_name(
         "register_categorical_predictive_distribution")

From 593dfb6a340ed5348f935f725285c659b574327c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 10 Nov 2017 03:30:53 -0800
Subject: [PATCH 114/115] Extend the Array class with more functionality.

PiperOrigin-RevId: 175277161
---
 tensorflow/compiler/xla/BUILD                 |   1 +
 tensorflow/compiler/xla/array.h               | 159 +++++++++++++++++-
 tensorflow/compiler/xla/array_test.cc         |  45 +++++
 .../compiler/xla/client/computation_builder.h |   1 +
 .../compiler/xla/service/hlo_instruction.h    |   5 +
 .../compiler/xla/service/hlo_sharding.cc      |   3 +-
 6 files changed, 205 insertions(+), 9 deletions(-)

diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index f6e405744a1..515b572b0eb 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -340,6 +340,7 @@ cc_library(
     name = "array",
     hdrs = ["array.h"],
     deps = [
+        ":status",
         ":types",
         "//tensorflow/core:lib",
     ],
diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h
index ba898d1f4e9..213e0bac6c7 100644
--- a/tensorflow/compiler/xla/array.h
+++ b/tensorflow/compiler/xla/array.h
@@ -23,8 +23,10 @@ limitations under the License.
 #include <iterator>
 #include <memory>
 #include <random>
+#include <type_traits>
 #include <vector>
 
+#include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -35,10 +37,63 @@ limitations under the License.
 
 namespace xla {
 
+namespace array_impl {
+
+// conjunction
+//
+// Performs a compile-time logical AND operation on the passed types (which
+// must have `::value` members convertible to `bool`). Short-circuits if it
+// encounters any `false` members (and does not compare the `::value` members
+// of any remaining arguments).
+//
+// This metafunction is designed to be a drop-in replacement for the C++17
+// `std::conjunction` metafunction.
+template <typename... Ts>
+struct conjunction;
+
+template <typename T, typename... Ts>
+struct conjunction<T, Ts...>
+    : std::conditional<T::value, conjunction<Ts...>, T>::type {};
+
+template <>
+struct conjunction<> : std::true_type {};
+
+// A type trait that is valid when all elements in a parameter pack are of
+// integral type.
+template <typename... T>
+using pack_is_integral = conjunction<std::is_integral<T>...>;
+
+// Compares three same-sized vectors elementwise. Returns false if any
+// values[i] is outside the half-open range
+// [range_starts[i], range_ends[i]).
+template <typename C1, typename C2, typename C3>
+bool all_inside_range(const C1& values, const C2& range_starts,
+                      const C3& range_ends) {
+  for (size_t i = 0, e = values.size(); i < e; ++i) {
+    if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace array_impl
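+
+// A compile-time illustration of the trait (a sanity check only, not
+// load-bearing for the change): all-integral packs are accepted, mixed
+// packs are rejected.
+static_assert(array_impl::pack_is_integral<int, unsigned long>::value,
+              "integral pack accepted");
+static_assert(!array_impl::pack_is_integral<int, float>::value,
+              "mixed pack rejected");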
+
 // General N dimensional array class with arbitrary value type.
 template <typename T>
 class Array {
  public:
+  // Type inference can have a hard time parsing very deep initializer list
+  // nests, especially if one or more dimensions are 1, as the compiler just
+  // sees a single-element integer initializer. These typedefs allow casting
+  // explicitly with less typing.
+  using InitializerList1D = std::initializer_list<T>;
+  using InitializerList2D = std::initializer_list<InitializerList1D>;
+  using InitializerList3D = std::initializer_list<InitializerList2D>;
+  using InitializerList4D = std::initializer_list<InitializerList3D>;
+
+  using value_type = T;
+
   // Creates a new array with the specified dimensions.
   explicit Array(tensorflow::gtl::ArraySlice<int64> sizes)
       : Array(sizes, T()) {}
@@ -53,7 +108,7 @@ class Array {
   // Creates a 2D array from the given nested initializer list. The outer
   // initializer list is the first dimension, the inner is the second dimension.
   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
-  Array(std::initializer_list<std::initializer_list<T>> values)
+  Array(InitializerList2D values)
       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
     int64 idx = 0;
     for (const auto& it1 : values) {
@@ -67,8 +122,7 @@ class Array {
 
   // Creates a 3D array from the given nested initializer list. The outer
   // initializer list is the first dimension, and so on.
-  Array(std::initializer_list<std::initializer_list<std::initializer_list<T>>>
-            values)
+  Array(InitializerList3D values)
       : Array(ToInt64Vector({values.size(), values.begin()->size(),
                              values.begin()->begin()->size()})) {
     int64 idx = 0;
@@ -85,9 +139,7 @@ class Array {
 
   // Creates a 4D array from the given nested initializer list. The outer
   // initializer list is the first dimension, and so on.
-  Array(std::initializer_list<
-        std::initializer_list<std::initializer_list<std::initializer_list<T>>>>
-            values)
+  Array(InitializerList4D values)
       : Array(ToInt64Vector({values.size(), values.begin()->size(),
                              values.begin()->begin()->size(),
                              values.begin()->begin()->begin()->size()})) {
@@ -173,10 +225,46 @@ class Array {
     }
   }
 
+  // Invokes a callback with the (indices, value_ptr) for each cell in the
+  // array. If a callback returns a non-OK status, returns that status
+  // immediately; otherwise returns Status::OK().
+  Status EachStatus(
+      std::function<Status(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+    std::vector<int64> index(sizes_.size());
+    for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+      Status s = f(index, &values_[i]);
+      if (!s.ok()) {
+        return s;
+      }
+    }
+    return Status::OK();
+  }
+
+  // Invokes a callback with the (indices, value) for each cell in the array.
+  // If a callback returns a non-OK status, returns that status immediately;
+  // otherwise returns Status::OK().
+  Status EachStatus(
+      std::function<Status(tensorflow::gtl::ArraySlice<int64>, T)> f) const {
+    std::vector<int64> index(sizes_.size());
+    for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+      Status s = f(index, values_[i]);
+      if (!s.ok()) {
+        return s;
+      }
+    }
+    return Status::OK();
+  }
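+
+  // Usage sketch (illustrative; assumes an error helper such as
+  // InvalidArgument from xla/util.h is visible at the call site):
+  //   Status s = arr.EachStatus(
+  //       [](tensorflow::gtl::ArraySlice<int64> /*index*/, int v) {
+  //         return v < 0 ? InvalidArgument("negative cell") : Status::OK();
+  //       });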
+
   // Returns the value at the cell specified by the indexes. The number of
   // arguments have to match with the number of dimensions for the array.
+  //
+  // The type trait is required to avoid this overload participating too
+  // eagerly; a parameter pack can take zero or more elements, so we must
+  // restrict this to only parameter packs that are all of integral type.
   template <typename... Dims>
-  const T& operator()(Dims... dims) const {
+  typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
+                          const T&>::type
+  operator()(Dims... dims) const {
     // We are using a std::array to avoid having to allocate memory in this
     // function for performance reasons.
     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
@@ -186,7 +274,9 @@ class Array {
   // Returns the value at the cell specified by the indexes. The number of
   // arguments have to match with the number of dimensions for the array.
   template <typename... Dims>
-  T& operator()(Dims... dims) {
+  typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
+                          T&>::type
+  operator()(Dims... dims) {
     // We are using a std::array to avoid having to allocate memory in this
     // function for performance reasons.
     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
@@ -255,6 +345,59 @@ class Array {
 
   bool operator!=(const Array<T>& other) const { return !(*this == other); }
 
+  // Performs the equivalent of a slice operation on this array.
+  Array<T> Slice(tensorflow::gtl::ArraySlice<int64> starts,
+                 tensorflow::gtl::ArraySlice<int64> limits) const {
+    CHECK_EQ(starts.size(), num_dimensions());
+    CHECK_EQ(limits.size(), num_dimensions());
+
+    std::vector<int64> sizes;
+    std::transform(starts.begin(), starts.end(), limits.begin(),
+                   std::back_inserter(sizes),
+                   [](int64 start, int64 limit) { return limit - start; });
+    Array<T> result(sizes);
+
+    std::vector<int64> index(sizes_.size());
+    int64 slice_i = 0;
+    for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+      if (array_impl::all_inside_range(index, starts, limits)) {
+        // Even though the bounds of result are different from our bounds, we're
+        // iterating in the same order. So we can simply write successive linear
+        // indices instead of recalculating a multi-dimensional index.
+        result.values_[slice_i++] = values_[i];
+      }
+    }
+    return result;
+  }
+
+  // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
+  void UpdateSlice(const Array<T>& from,
+                   tensorflow::gtl::ArraySlice<int64> start_indices) {
+    CHECK_EQ(from.num_dimensions(), num_dimensions());
+    std::vector<int64> limit_indices;
+    std::transform(start_indices.begin(), start_indices.end(),
+                   from.dimensions().begin(), std::back_inserter(limit_indices),
+                   std::plus<int64>{});
+    std::vector<int64> index(sizes_.size());
+    int64 from_i = 0;
+    for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
+      if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
+        // Even though the bounds of from are different from our bounds, we're
+        // iterating in the same order. So we can simply write successive linear
+        // indices instead of recalculating a multi-dimensional index.
+        values_[i] = from.values_[from_i++];
+      }
+    }
+  }
+
+  // Performs an in-place reshape, modifying the dimensions but not the
+  // underlying data.
+  void Reshape(tensorflow::gtl::ArraySlice<int64> new_dimensions) {
+    int64 old_num_elements = num_elements();
+    sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
+    CHECK_EQ(num_elements(), old_num_elements);
+  }
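+
+  // Usage sketch (illustrative only; Reshape must preserve the element
+  // count, as enforced by the CHECK above):
+  //   Array<int> a({2, 6});
+  //   a.Reshape({3, 4});  // fine: 12 elements before and after
+  //   a.Reshape({5, 5});  // would CHECK-fail: 25 != 12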
+
   // Returns a string representation of the array suitable for debugging.
   string ToString() const {
     std::vector<string> pieces;
diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc
index 093784f541b..8b941947747 100644
--- a/tensorflow/compiler/xla/array_test.cc
+++ b/tensorflow/compiler/xla/array_test.cc
@@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) {
   EXPECT_EQ(arr(1, 2), 61);
 }
 
+TEST(ArrayTest, DynamicIndexingReadWrite) {
+  Array<int> arr({2, 3});
+
+  std::vector<int64> index1 = {1, 1};
+  std::vector<int64> index2 = {1, 2};
+  EXPECT_EQ(arr(index1), 0);
+  EXPECT_EQ(arr(index2), 0);
+  arr(index1) = 51;
+  arr(index2) = 61;
+  EXPECT_EQ(arr(1, 1), 51);
+  EXPECT_EQ(arr(1, 2), 61);
+}
+
 TEST(ArrayTest, IndexingReadWriteBool) {
   Array<bool> arr{{false, true, false}, {false, true, false}};
 
@@ -141,5 +154,37 @@ TEST(ArrayTest, Each) {
   EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum);
 }
 
+TEST(ArrayTest, Slice) {
+  Array<int64> arr({2, 4});
+  arr.FillWithMultiples(1);
+
+  Array<int64> identity_slice = arr.Slice({0, 0}, {2, 4});
+  EXPECT_EQ(identity_slice.dimensions(), arr.dimensions());
+  for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end();
+       it1 != e; ++it1, ++it2) {
+    EXPECT_EQ(*it1, *it2);
+  }
+
+  Array<int64> sub_slice = arr.Slice({1, 0}, {2, 2});
+  EXPECT_EQ(sub_slice.dimensions(), (std::vector<int64>{1, 2}));
+  const string expected = R"([[4, 5]])";
+  EXPECT_EQ(expected, sub_slice.ToString());
+}
+
+TEST(ArrayTest, UpdateSlice) {
+  Array<int64> arr({3, 4});
+  arr.FillWithMultiples(1);
+
+  Array<int64> sub_arr({2, 2});
+  sub_arr.FillWithMultiples(3);
+
+  arr.UpdateSlice(sub_arr, {1, 1});
+
+  const string expected = R"([[0, 1, 2, 3],
+ [4, 0, 3, 7],
+ [8, 6, 9, 11]])";
+  EXPECT_EQ(expected, arr.ToString());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 8e1b4be1f3e..4c6e320557f 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -68,6 +68,7 @@ class ShardingBuilder {
                          const TileAssignment& tile_assignment) {
     OpSharding result;
     result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
+    *result.mutable_tile_shape() = tile_shape;
     for (int64 dim : tile_assignment.dimensions()) {
       result.add_tile_assignment_dimensions(dim);
     }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 64a88164a70..d174f05aa6b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -863,6 +863,11 @@ class HloInstruction {
     return *window_;
   }
 
+  // Sets the window data in a windowed operation such as convolution.
+  void set_window(const Window& window) {
+    window_ = MakeUnique<Window>(window);
+  }
+
   // Returns the padding configuration for a pad node.
   //
   // Precondition: opcode() == HloOpcode::kPad
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index bc5663513b9..73566634542 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -249,7 +249,8 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
     return HloSharding(tuple_shardings);
   } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
     return Replicate();
-  } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) {
+  } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
+             proto.tile_assignment_devices().size() == 1) {
     return HloSharding(proto.tile_assignment_devices(0));
   }
   // Some versions of gcc cannot infer the TileAssignment constructor from a

From 8cf98b7d0e9118dc45f06a5fed9bfc62b2a86c44 Mon Sep 17 00:00:00 2001
From: Yifei Feng <fengyifei2026@gmail.com>
Date: Fri, 10 Nov 2017 08:19:04 -0800
Subject: [PATCH 115/115] Add "no_pip" to contrib/data/python/kernel_tests

Add "no_pip" to tests under contrib/data/python/kernel_tests that depend on tensorflow.contrib.data.python.kernel_tests
---
 tensorflow/contrib/data/python/kernel_tests/BUILD | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index d811683ecda..241fc2ab4f1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -11,6 +11,7 @@ py_test(
     size = "small",
     srcs = ["batch_dataset_op_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_pip"],
     deps = [
         ":dataset_serialization_test",
         "//tensorflow/contrib/data/python/ops:dataset_ops",
@@ -364,6 +365,7 @@ py_test(
     size = "small",
     srcs = ["sequence_dataset_op_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_pip"],
     deps = [
         ":dataset_serialization_test",
         "//tensorflow/contrib/data/python/ops:dataset_ops",
@@ -428,6 +430,7 @@ py_test(
     size = "small",
     srcs = ["zip_dataset_op_test.py"],
     srcs_version = "PY2AND3",
+    tags = ["no_pip"],
     deps = [
         ":dataset_serialization_test",
         "//tensorflow/contrib/data/python/ops:dataset_ops",