From 1038927c096ecc81ca48665871d1be390444b121 Mon Sep 17 00:00:00 2001
From: Saurabh Saxena <srbs@google.com>
Date: Mon, 23 Oct 2017 11:07:10 -0700
Subject: [PATCH] Add SerializeIterator op that serializes an IteratorResource
 into a variant tensor. Add DeserializeIterator op that builds
 IteratorResource from a variant tensor. Move BundleReaderWrapper and
 BundleWriterWrapper from dataset.h to iterator_ops.cc. Add generic key-value
 store interfaces IteratorStateReader and IteratorStateWriter for
 reading/writing state of iterators. Get rid of IteratorBundleReader and
 IteratorBundleWriter.

PiperOrigin-RevId: 173140858
---
 .../contrib/data/python/kernel_tests/BUILD    |   4 +
 .../python/kernel_tests/iterator_ops_test.py  |  29 +-
 .../kernel_tests/range_dataset_op_test.py     |  67 ++--
 .../kernel_tests/reader_dataset_ops_test.py   |  25 +-
 tensorflow/core/BUILD                         |   1 +
 tensorflow/core/framework/iterator.proto      |  17 +
 tensorflow/core/kernels/BUILD                 |   1 +
 tensorflow/core/kernels/dataset.h             | 195 ++++------
 tensorflow/core/kernels/iterator_ops.cc       | 363 +++++++++++++-----
 tensorflow/core/kernels/parse_tensor_op.cc    |   1 +
 tensorflow/core/kernels/range_dataset_op.cc   |  11 +-
 tensorflow/core/kernels/reader_dataset_ops.cc |  17 +-
 tensorflow/core/kernels/repeat_dataset_op.cc  |  13 +-
 .../core/ops/compat/ops_history.v1.pbtxt      |  24 --
 tensorflow/core/ops/dataset_ops.cc            |  42 +-
 tensorflow/python/kernel_tests/BUILD          |   5 +
 .../python/kernel_tests/iterator_ops_test.py  |  29 +-
 .../kernel_tests/range_dataset_op_test.py     |  67 ++--
 .../kernel_tests/reader_dataset_ops_test.py   |  26 +-
 19 files changed, 544 insertions(+), 393 deletions(-)
 create mode 100644 tensorflow/core/framework/iterator.proto

diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index c34c9dad9b5..b3175e3e56c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -185,6 +185,7 @@ py_test(
         "//tensorflow/python:function",
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:gradients",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:script_ops",
@@ -252,6 +253,8 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variables",
@@ -274,6 +277,7 @@ py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:lib",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:tensor_shape",
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 20f6d6ba34f..bda9a2a4a37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
 
   def testIncorrectIteratorRestore(self):
 
-    def _iterator_checkpoint_prefix():
+    def _path():
       return os.path.join(self.get_temp_dir(), "iterator")
 
+    def _save_op(iterator_resource):
+      iterator_state_variant = gen_dataset_ops.serialize_iterator(
+          iterator_resource)
+      save_op = io_ops.write_file(
+          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+      return save_op
+
+    def _restore_op(iterator_resource):
+      iterator_state_variant = parsing_ops.parse_tensor(
+          io_ops.read_file(_path()), dtypes.variant)
+      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                        iterator_state_variant)
+      return restore_op
+
     def _build_range_dataset_graph():
       start = 1
       stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = _iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     def _build_reader_dataset_graph():
       filenames = ["test"]  # Does not exist but we don't care in this test.
-      path = _iterator_checkpoint_prefix()
       iterator = readers.FixedLengthRecordDataset(
           filenames, 1, 0, 0).make_initializable_iterator()
       init_op = iterator.initializer
       get_next_op = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next_op, save_op, restore_op
 
     # Saving iterator for RangeDataset graph.
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index c8a0072809c..c944eb4a49c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -29,6 +29,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -193,6 +195,21 @@ class RangeDatasetTest(test.TestCase):
   def _iterator_checkpoint_prefix(self):
     return os.path.join(self.get_temp_dir(), "iterator")
 
+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_prefix(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def testSaveRestore(self):
 
     def _build_graph(start, stop):
@@ -200,10 +217,8 @@ class RangeDatasetTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -246,14 +261,13 @@ class RangeDatasetTest(test.TestCase):
 
   def testRestoreWithoutBuildingDatasetGraph(self):
 
-    def _build_graph(start, stop, num_epochs, path):
+    def _build_graph(start, stop, num_epochs):
       dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -262,10 +276,8 @@ class RangeDatasetTest(test.TestCase):
     num_epochs = 5
     break_point = 5
     break_epoch = 3
-    path = self._iterator_checkpoint_prefix()
     with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
-                                                   path)
+      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
       with self.test_session(graph=g) as sess:
         sess.run(variables.global_variables_initializer())
         sess.run(init_op)
@@ -282,8 +294,7 @@ class RangeDatasetTest(test.TestCase):
       output_shapes = tensor_shape.scalar()
       iterator = iterator_ops.Iterator.from_structure(output_types,
                                                       output_shapes)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      restore_op = self._restore_op(iterator._iterator_resource)
       get_next = iterator.get_next()
       with self.test_session(graph=g) as sess:
         sess.run(restore_op)
@@ -302,10 +313,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -343,10 +352,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -379,10 +386,8 @@ class RangeDatasetTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     start = 2
@@ -424,10 +429,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     start = 2
@@ -471,10 +474,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     start = 2
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index c9f88f3dfc9..2682e8bdfa3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
@@ -276,18 +277,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
   def _iterator_checkpoint_path(self):
     return os.path.join(self.get_temp_dir(), "iterator")
 
+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_path(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def _build_iterator_graph(self, num_epochs):
     filenames = self._createFiles()
-    path = self._iterator_checkpoint_path()
     dataset = (readers.FixedLengthRecordDataset(
         filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                .repeat(num_epochs))
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next_op = iterator.get_next()
-    save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-    restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                  path)
+    save_op = self._save_op(iterator._iterator_resource)
+    restore_op = self._restore_op(iterator._iterator_resource)
     return init_op, get_next_op, save_op, restore_op
 
   def _restore_iterator(self):
@@ -295,8 +309,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
     output_shapes = tensor_shape.scalar()
     iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
     get_next = iterator.get_next()
-    restore_op = gen_dataset_ops.restore_iterator(
-        iterator._iterator_resource, self._iterator_checkpoint_path())
+    restore_op = self._restore_op(iterator._iterator_resource)
     return restore_op, get_next
 
   def testSaveRestore(self):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6ad93a73f4e..c4f880da9d3 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [
     "framework/function.proto",
     "framework/graph.proto",
     "framework/graph_transfer_info.proto",
+    "framework/iterator.proto",
     "framework/kernel_def.proto",
     "framework/log_memory.proto",
     "framework/node_def.proto",
diff --git a/tensorflow/core/framework/iterator.proto b/tensorflow/core/framework/iterator.proto
new file mode 100644
index 00000000000..7e5f5ea2e0c
--- /dev/null
+++ b/tensorflow/core/framework/iterator.proto
@@ -0,0 +1,17 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "IteratorProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util";
+
+// Protocol buffer representing the metadata for an iterator's state stored
+// as a Variant tensor.
+message IteratorStateMetadata {
+  // A user-specified version string.
+  string version = 1;
+
+  // Keys for tensors in the VariantTensorDataProto.
+  repeated string keys = 2;
+}
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index d931f12f6dd..f5bfa60199f 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -6061,6 +6061,7 @@ tf_kernel_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
     ],
 )
 
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index f9ffc4e065b..a906113466d 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -17,12 +17,14 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/core/common_runtime/graph_runner.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/framework/variant_encode_decode.h"
 #include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/graph_def_builder.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -39,54 +41,25 @@ namespace tensorflow {
 
 class ResourceMgr;
 
-class BundleReaderWrapper {
+// Interface for reading values from a key-value store.
+// Used for restoring iterator state.
+class IteratorStateReader {
  public:
-  BundleReaderWrapper(BundleReader* bundle_reader)
-      : bundle_reader_(bundle_reader) {}
+  virtual Status ReadScalar(StringPiece key, int64* val) = 0;
+  virtual Status ReadScalar(StringPiece key, string* val) = 0;
+  virtual bool Contains(StringPiece key) = 0;
 
-  // Reads a scalar value.
-  template <typename T>
-  Status ReadScalar(StringPiece key, T* val) {
-    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
-    TF_RETURN_IF_ERROR(Lookup(key, &val_t));
-    *val = val_t.scalar<T>()();
-    return Status::OK();
-  }
-
-  bool Contains(StringPiece key) { return bundle_reader_->Contains(key); }
-
- private:
-  Status Lookup(StringPiece key, Tensor* val) {
-    return bundle_reader_->Lookup(key, val);
-  }
-
-  BundleReader* bundle_reader_;
+  virtual ~IteratorStateReader() {}
 };
 
-class BundleWriterWrapper {
+// Interface for writing values to a key-value store.
+// Used for saving iterator state.
+class IteratorStateWriter {
  public:
-  // Note: We intentionally do not provide a constructor that builds a
-  // BundleWriter from the checkpoint path because we want the caller to be
-  // in-charge of calling BundleWriter::Finish(). If we expose the Finish()
-  // method here it may be called pre-maturely by users of this object.
-  explicit BundleWriterWrapper(BundleWriter* bundle_writer)
-      : bundle_writer_(bundle_writer) {}
+  virtual Status WriteScalar(StringPiece key, const int64& val) = 0;
+  virtual Status WriteScalar(StringPiece key, const string& val) = 0;
 
-  // Writes a scalar value.
-  template <typename T>
-  Status WriteScalar(StringPiece key, const T val) {
-    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
-    val_t.scalar<T>()() = val;
-    TF_RETURN_IF_ERROR(Add(key, val_t));
-    return Status::OK();
-  }
-
- private:
-  Status Add(StringPiece key, const Tensor& val) {
-    return bundle_writer_->Add(key, val);
-  }
-
-  BundleWriter* bundle_writer_;
+  virtual ~IteratorStateWriter() {}
 };
 
 // Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
@@ -249,10 +222,6 @@ class IteratorContext {
 // range of outputs is typically represented by an `DatasetBase`,
 // defined below.
 class IteratorBase {
- protected:
-  class IteratorBundleReader;
-  class IteratorBundleWriter;
-
  public:
   virtual ~IteratorBase() {}
 
@@ -284,75 +253,17 @@ class IteratorBase {
   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
 
   // Saves the state of this iterator.
-  virtual Status Save(OpKernelContext* ctx, const string& path) {
-    BundleWriter bundle_writer(ctx->env(), path);
-    TF_RETURN_IF_ERROR(bundle_writer.status());
-    IteratorBundleWriter writer(&bundle_writer);
-    TF_RETURN_IF_ERROR(Save(ctx, &writer));
-    return bundle_writer.Finish();
-  }
-
-  virtual Status Restore(OpKernelContext* ctx, const string& path) {
-    if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
-      return errors::NotFound(
-          "Failed to restore Iterator state. No file found at ",
-          MetaFilename(path));
-    }
-    BundleReader bundle_reader(ctx->env(), path);
-    TF_RETURN_IF_ERROR(bundle_reader.status());
-    IteratorBundleReader reader(&bundle_reader);
-    return Restore(ctx, &reader);
-  }
-
-  static const char kIteratorExhausted[];
-
- protected:
-  // This is needed so that sub-classes of IteratorBase can call
-  // `RestoreInternal` on their parent iterators, e.g., in
-  // `RepeatDataasetOp::Dataset`.
-  class IteratorBundleReader : public BundleReaderWrapper {
-   public:
-    IteratorBundleReader(BundleReader* bundle_reader)
-        : BundleReaderWrapper(bundle_reader) {}
-
-    // Restores the state of a parent iterator recursively.
-    Status RestoreParent(OpKernelContext* ctx,
-                         const std::unique_ptr<IteratorBase>& parent) {
-      return parent->RestoreInternal(ctx, this);
-    }
-  };
-
-  // This is needed so that sub-classes of IteratorBase can call
-  // `SaveInternal` on their parent iterators, e.g., in
-  // `RepeatDataasetOp::Dataset`.
-  class IteratorBundleWriter : public BundleWriterWrapper {
-   public:
-    IteratorBundleWriter(BundleWriter* bundle_writer)
-        : BundleWriterWrapper(bundle_writer) {}
-    // Saves the state of a parent iterator recursively.
-    Status SaveParent(OpKernelContext* ctx,
-                      const std::unique_ptr<IteratorBase>& parent) {
-      return parent->SaveInternal(ctx, this);
-    }
-  };
-
-  virtual Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) {
+  virtual Status Save(IteratorStateWriter* writer) {
     if (is_exhausted_) {
       LOG(INFO) << "Iterator exhausted.";
-      return writer->WriteScalar<string>(kIteratorExhausted,
-                                         kIteratorExhausted);
+      return writer->WriteScalar(kIteratorExhausted, kIteratorExhausted);
     } else {
-      return SaveInternal(ctx, writer);
+      return SaveInternal(writer);
     }
   }
 
-  // Saves the state of this iterator.
-  virtual Status SaveInternal(OpKernelContext* ctx,
-                              IteratorBundleWriter* writer) {
-    return errors::Unimplemented("SaveInternal");
-  }
-
-  virtual Status Restore(OpKernelContext* ctx, IteratorBundleReader* reader) {
+  // Restores the state of this iterator.
+  virtual Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
     if (reader->Contains(kIteratorExhausted)) {
       LOG(INFO) << "Iterator exhausted. Nothing to restore.";
       is_exhausted_ = true;
@@ -362,9 +273,33 @@ class IteratorBase {
     }
   }
 
-  // Restores the state of this iterator.
+  static const char kIteratorExhausted[];
+
+ protected:
+  // This is needed so that sub-classes of IteratorBase can call
+  // `SaveInternal` on their parent iterators, e.g., in
+  // `RepeatDataasetOp::Dataset`.
+  Status SaveParent(IteratorStateWriter* writer,
+                    const std::unique_ptr<IteratorBase>& parent) {
+    return parent->SaveInternal(writer);
+  }
+
+  // This is needed so that sub-classes of IteratorBase can call
+  // `RestoreInternal` on their parent iterators, e.g., in
+  // `RepeatDataasetOp::Dataset`.
+  Status RestoreParent(OpKernelContext* ctx, IteratorStateReader* reader,
+                       const std::unique_ptr<IteratorBase>& parent) {
+    return parent->RestoreInternal(ctx, reader);
+  }
+
+  // Saves the state of this iterator recursively.
+  virtual Status SaveInternal(IteratorStateWriter* writer) {
+    return errors::Unimplemented("SaveInternal");
+  }
+
+  // Restores the state of this iterator recursively.
   virtual Status RestoreInternal(OpKernelContext* ctx,
-                                 IteratorBundleReader* reader) {
+                                 IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
   }
 
@@ -404,7 +339,7 @@ class DatasetBase : public core::RefCounted {
   virtual string DebugString() = 0;
 
   // Serializes the dataset and writes it to the `writer`.
-  virtual Status Save(BundleWriterWrapper* writer) const {
+  virtual Status Save(IteratorStateWriter* writer) const {
     return errors::Unimplemented("DatasetBase::Save");
   }
 
@@ -435,20 +370,14 @@ class GraphDatasetBase : public DatasetBase {
 
   const string op_name() const { return op_name_; }
 
-  Status Save(BundleWriterWrapper* writer) const override {
-    GraphDefBuilder b;
-    DatasetGraphDefBuilder db(&b);
-    Node* node = nullptr;
-    TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
-    string output_name = node->name();
-    GraphDef graph_def;
-    TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+  Status Save(IteratorStateWriter* writer) const override {
     string serialized_graph_def;
-    graph_def.SerializeToString(&serialized_graph_def);
+    string output_node;
+    TF_RETURN_IF_ERROR(Serialize(&serialized_graph_def, &output_node));
     TF_RETURN_IF_ERROR(
-        writer->WriteScalar<string>(kDatasetGraphKey, serialized_graph_def));
+        writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
     TF_RETURN_IF_ERROR(
-        writer->WriteScalar<string>(kDatasetGraphOutputNodeKey, output_name));
+        writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
     return Status::OK();
   }
 
@@ -460,6 +389,18 @@ class GraphDatasetBase : public DatasetBase {
   static const char kDatasetGraphOutputNodeKey[];
 
  private:
+  Status Serialize(string* serialized_graph_def, string* output_node) const {
+    GraphDefBuilder b;
+    DatasetGraphDefBuilder db(&b);
+    Node* node = nullptr;
+    TF_RETURN_IF_ERROR(AsGraphDefInternal(&db, &node));
+    *output_node = node->name();
+    GraphDef graph_def;
+    TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+    graph_def.SerializeToString(serialized_graph_def);
+    return Status::OK();
+  }
+
   const string op_name_;
 };
 
@@ -505,18 +446,18 @@ class DatasetIterator : public IteratorBase {
     return GetNextInternal(ctx, out_tensors, end_of_sequence);
   }
 
- protected:
-  Status Save(OpKernelContext* ctx, IteratorBundleWriter* writer) final {
+  Status Save(IteratorStateWriter* writer) final {
     TF_RETURN_IF_ERROR(dataset()->Save(writer));
-    return IteratorBase::Save(ctx, writer);
+    return IteratorBase::Save(writer);
   }
 
+ protected:
   // Internal implementation of GetNext that is wrapped in tracing logic.
   virtual Status GetNextInternal(IteratorContext* ctx,
                                  std::vector<Tensor>* out_tensors,
                                  bool* end_of_sequence) = 0;
 
-  string full_name(const string& name) {
+  string full_name(const string& name) const {
     return strings::StrCat(prefix(), ":", name);
   }
 
diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc
index df13edc83ae..b7c1fff2a9d 100644
--- a/tensorflow/core/kernels/iterator_ops.cc
+++ b/tensorflow/core/kernels/iterator_ops.cc
@@ -16,9 +16,11 @@ limitations under the License.
 
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/framework/iterator.pb.h"
 #include "tensorflow/core/framework/partial_tensor_shape.h"
 #include "tensorflow/core/framework/resource_op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/threadpool.h"
@@ -35,6 +37,8 @@ namespace {
 // See documentation in ../ops/dataset_ops.cc for a high-level
 // description of the following ops.
 
+const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
+
 Status VerifyTypesMatch(const DataTypeVector& expected,
                         const DataTypeVector& received) {
   if (expected.size() != received.size()) {
@@ -93,10 +97,10 @@ class IteratorResource : public ResourceBase {
     }
   }
 
-  Status Save(OpKernelContext* ctx, const string& path) {
+  Status Save(IteratorStateWriter* writer) {
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
     if (captured_iterator) {
-      return captured_iterator->Save(ctx, path);
+      return captured_iterator->Save(writer);
     } else {
       return errors::FailedPrecondition(
           "Save() failed because the iterator has not been initialized. "
@@ -105,49 +109,34 @@ class IteratorResource : public ResourceBase {
     }
   }
 
-  Status Restore(OpKernelContext* ctx, const string& path) {
-    if (!(ctx->env()->FileExists(MetaFilename(path)).ok())) {
-      return errors::NotFound(
-          "Failed to restore Iterator state. No file found at ",
-          MetaFilename(path));
+  Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
+    string serialized_graph_def;
+    TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
+                                          &serialized_graph_def));
+    GraphDef graph_def;
+    if (!graph_def.ParseFromString(serialized_graph_def)) {
+      return errors::Internal("Error parsing dataset GraphDef.");
     }
+    string output_node;
+    TF_RETURN_IF_ERROR(reader->ReadScalar(
+        GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
+    DatasetBase* dataset = nullptr;
+    Graph graph(OpRegistry::Global());
+    TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
+    std::vector<Tensor> outputs;
+    GraphRunner graph_runner(ctx->env());
+    TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
+                                        {output_node}, &outputs));
+    TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
 
-    BundleReader bundle_reader(ctx->env(), path);
-    TF_RETURN_IF_ERROR(bundle_reader.status());
-    BundleReaderWrapper reader(&bundle_reader);
-    if (reader.Contains(GraphDatasetBase::kDatasetGraphKey)) {
-      string serialized_graph_def;
-      TF_RETURN_IF_ERROR(reader.ReadScalar(GraphDatasetBase::kDatasetGraphKey,
-                                           &serialized_graph_def));
-      GraphDef graph_def;
-      graph_def.ParseFromString(serialized_graph_def);
-      // TODO(srbs): Is there a way of getting the op registry of the original
-      // graph.
-      Graph graph(OpRegistry::Global());
-      TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
-      string output_node;
-      TF_RETURN_IF_ERROR(reader.ReadScalar(
-          GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
-      std::vector<Tensor> outputs;
-      GraphRunner graph_runner(ctx->env());
-      TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
-                                          {output_node}, &outputs));
-      DatasetBase* dataset;
-      TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
-      TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
-    } else if (reader.Contains(IteratorBase::kIteratorExhausted)) {
-      TF_RETURN_IF_ERROR(set_iterator(std::unique_ptr<IteratorBase>(
-          new ExhaustedIterator(output_dtypes_, output_shapes_))));
-    }
+    TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
 
     if (captured_iterator) {
-      // TODO(srbs): Figure a way to pass bundle_reader here.
-      return captured_iterator->Restore(ctx, path);
+      return captured_iterator->Restore(ctx, reader);
     } else {
       return errors::FailedPrecondition(
-          "Failed to restore iterator from ", path,
-          ". Make sure the checkpoint ",
+          "Failed to restore iterator. Make sure the checkpoint ",
           "is not corrupt. If the checkpoint does not contain the GraphDef, ",
           "you will need to initialize your iterator before restoring.");
     }
@@ -174,43 +163,194 @@ class IteratorResource : public ResourceBase {
   }
 
  private:
-  // A no-op iterator which always sets end_of_sequence = true. An instance of
-  // this is returned when attempting to restore an exhausted iterator. This is
-  // needed because the Dataset GraphDef may not have been saved for exhausted
-  // iterators so the actual Iterator can not be built.
-  class ExhaustedIterator : public IteratorBase {
-   public:
-    ExhaustedIterator(const DataTypeVector& output_dtypes,
-                      const std::vector<PartialTensorShape>& output_shapes)
-        : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
-    Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
-                   bool* end_of_sequence) final {
-      *end_of_sequence = true;
-      return Status::OK();
-    }
-
-    const DataTypeVector& output_dtypes() const override {
-      return output_dtypes_;
-    }
-
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      return output_shapes_;
-    }
-
-    virtual const std::vector<PartialTensorShape>& output_shapes() {
-      return output_shapes_;
-    }
-
-   private:
-    const DataTypeVector output_dtypes_;
-    const std::vector<PartialTensorShape> output_shapes_;
-  };
-
   std::shared_ptr<IteratorBase> iterator_;
   const DataTypeVector output_dtypes_;
   const std::vector<PartialTensorShape> output_shapes_;
 };
 
+// Helper class for reading data from a VariantTensorData object.
+class VariantTensorDataReader : public IteratorStateReader {
+ public:
+  explicit VariantTensorDataReader(const VariantTensorData* data)
+      : data_(data) {
+    PreProcess();
+  }
+
+  // Returns OK iff the initialization was successful, i.e.,
+  // pre-processing did not have errors.
+  Status status() const { return status_; }
+
+  Status ReadScalar(StringPiece key, int64* val) override {
+    return ReadScalarInternal(key, val);
+  }
+
+  Status ReadScalar(StringPiece key, string* val) override {
+    return ReadScalarInternal(key, val);
+  }
+
+  bool Contains(StringPiece key) override {
+    return map_.find(key.ToString()) != map_.end();
+  }
+
+ private:
+  void PreProcess() {
+    string metadata;
+    data_->get_metadata(&metadata);
+    IteratorStateMetadata proto;
+    if (!proto.ParseFromString(metadata)) {
+      status_ = errors::Internal("Error parsing IteratorStateMetadata.");
+      return;
+    }
+    size_t num_entries = proto.keys_size();
+    CHECK_EQ(num_entries, data_->tensors_size());
+    for (size_t i = 0; i < num_entries; i++) {
+      map_[proto.keys(i)] = i;
+    }
+  }
+
+  template <typename T>
+  Status ReadScalarInternal(StringPiece key, T* val) {
+    if (map_.find(key.ToString()) == map_.end()) {
+      return errors::NotFound(key);
+    }
+    *val = data_->tensors(map_[key.ToString()]).scalar<T>()();
+    return Status::OK();
+  }
+
+  std::map<string, size_t> map_;
+  const VariantTensorData* data_;  // Not owned.
+  Status status_;
+};
+
+// Helper class for writing data to a VariantTensorData object.
+class VariantTensorDataWriter : public IteratorStateWriter {
+ public:
+  // Does not take ownership of data.
+  explicit VariantTensorDataWriter(VariantTensorData* data) : data_(data) {}
+
+  Status WriteScalar(StringPiece key, const int64& val) override {
+    return WriteScalarInternal(key, val);
+  }
+
+  Status WriteScalar(StringPiece key, const string& val) override {
+    return WriteScalarInternal(key, val);
+  }
+
+  // Writes the metadata to `data_`.
+  Status Flush() {
+    string metadata;
+    if (!metadata_proto_.SerializeToString(&metadata)) {
+      return errors::Internal("Unable to serialize IteratorStateMetadata.");
+    }
+    data_->set_metadata(metadata);
+    return Status::OK();
+  }
+
+ private:
+  template <typename T>
+  Status WriteScalarInternal(StringPiece key, const T& val) {
+    // Write key to the metadata proto. This gets written to `data_`
+    // when `Flush()` is called. We do this lazily to avoid multiple
+    // serialization calls.
+    metadata_proto_.add_keys(key.ToString());
+
+    // Update tensors.
+    Tensor val_t = Tensor(DataTypeToEnum<T>::v(), TensorShape({}));
+    val_t.scalar<T>()() = val;
+    *(data_->add_tensors()) = std::move(val_t);
+    return Status::OK();
+  }
+
+  VariantTensorData* data_;
+  // TODO(srbs): Set the version string.
+  IteratorStateMetadata metadata_proto_;
+};
+
+// Wrapper for encoding/decoding the iterator state stored in a Variant tensor.
+// The get() method returns an IteratorStateReader which can be used
+// to restore iterator state.
+//
+// Usage example:
+//
+// Encoding:
+//
+//   Tensor t(DT_VARIANT, TensorShape({}));
+//   t->scalar<Variant>()() = IteratorStateVariant(iterator_resource);
+//
+// Encode() sets the type_name of the VariantTensorData object to
+// IteratorStateVariant::TypeName().
+//
+// Decoding:
+//
+//   Variant v = <VariantTensorDataProto object>;
+//   DecodeUnaryVariant(&v);
+//   IteratorStateVariant* wrapper = v.get<IteratorStateVariant>();
+//   iterator_resource->Restore(ctx, wrapper->get())
+//
+// The type_name of the VariantTensorData object to be decoded must
+// match IteratorStateVariant::TypeName().
+class IteratorStateVariant {
+ public:
+  IteratorStateVariant() : data_(nullptr) {}
+  IteratorStateVariant(const IteratorStateVariant& other) : data_(nullptr) {
+    if (other.data_) {
+      Decode(*other.data_);
+    }
+  }
+  // Initializes this object with the current state of the iterator so
+  // that it can be written on the next call to Encode().
+  Status InitializeFromIterator(IteratorResource* iterator_resource) {
+    data_.reset(new VariantTensorData());
+    data_->set_type_name(TypeName());
+    VariantTensorDataWriter writer(data_.get());
+    TF_RETURN_IF_ERROR(iterator_resource->Save(&writer));
+    TF_RETURN_IF_ERROR(writer.Flush());
+    return Status::OK();
+  }
+  string TypeName() const { return kIteratorVariantTypeName; }
+  void Encode(VariantTensorData* data) const { *data = *data_; }
+  bool Decode(const VariantTensorData& data) {
+    if (data.type_name() != TypeName()) {
+      return false;
+    }
+    std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
+    *tensor_data = data;
+    std::unique_ptr<VariantTensorDataReader> reader(
+        new VariantTensorDataReader(tensor_data.get()));
+    status_ = reader->status();
+    if (!status_.ok()) {
+      return false;
+    }
+    data_ = std::move(tensor_data);
+    reader_ = std::move(reader);
+    return true;
+  }
+  IteratorStateReader* get() { return reader_.get(); }
+  Status status() const { return status_; }
+  string DebugString() const {
+    if (data_) {
+      return strings::StrCat("IteratorStateVariant<",
+                             "data: ", data_->DebugString(),
+                             " status: ", status_.ToString(), ">");
+    } else {
+      return strings::StrCat("IteratorStateVariant<empty>");
+    }
+  }
+
+ private:
+  std::unique_ptr<IteratorStateReader> reader_;
+  Status status_;
+  std::unique_ptr<VariantTensorData> data_;
+};
+
+// Register the reader class in the global variant decode_fn registry
+// so that a Variant containing a serialized representation of iterator state
+// can be decoded using DecodeUnaryVariant. If we don't do this we will need
+// to manually decode the returned Variant using MaybeDecodeAndCopy in
+// DeserializeIteratorOp which is not recommended.
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
+                                       kIteratorVariantTypeName);
+
 // TODO(mrry): Can we simply use the template kernel here?
 class IteratorHandleOp : public ResourceOpKernel<IteratorResource> {
  public:
@@ -294,37 +434,6 @@ class ToSingleElementOp : public OpKernel {
   }
 };
 
-class SaveIteratorOp : public OpKernel {
- public:
-  explicit SaveIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
-  void Compute(OpKernelContext* ctx) override {
-    IteratorResource* iterator_resource;
-    OP_REQUIRES_OK(
-        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
-    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
-                errors::InvalidArgument("SaveIteratorOp: path must be scalar"));
-    const string& path = ctx->input(1).scalar<string>()();
-    OP_REQUIRES_OK(ctx, iterator_resource->Save(ctx, path));
-  }
-};
-
-class RestoreIteratorOp : public OpKernel {
- public:
-  explicit RestoreIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
-  void Compute(OpKernelContext* ctx) override {
-    IteratorResource* iterator_resource;
-    OP_REQUIRES_OK(
-        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
-    OP_REQUIRES(
-        ctx, TensorShapeUtils::IsScalar(ctx->input(1).shape()),
-        errors::InvalidArgument("RestoreIteratorOp: path must be scalar"));
-    const string& path = ctx->input(1).scalar<string>()();
-    OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, path));
-  }
-};
-
 class OneShotIteratorOp : public AsyncOpKernel {
  public:
   explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -644,15 +753,55 @@ class IteratorFromStringHandleOp : public OpKernel {
   std::vector<PartialTensorShape> output_shapes_;
 };
 
+class SerializeIteratorOp : public OpKernel {
+ public:
+  explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& resource_handle_t = ctx->input(0);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+                errors::InvalidArgument("resource_handle must be a scalar"));
+
+    // Validate that the handle corresponds to a real resource, and
+    // that it is an IteratorResource.
+    IteratorResource* iterator_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+    iterator_resource->Unref();
+    Tensor* variant_t;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &variant_t));
+    IteratorStateVariant v;
+    OP_REQUIRES_OK(ctx, v.InitializeFromIterator(iterator_resource));
+    variant_t->scalar<Variant>()() = v;
+  }
+};
+
+class DeserializeIteratorOp : public OpKernel {
+ public:
+  explicit DeserializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // Validate that the handle corresponds to a real resource, and
+    // that it is an IteratorResource.
+    IteratorResource* iterator_resource;
+    OP_REQUIRES_OK(
+        ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+
+    Variant variant = ctx->input(1).scalar<Variant>()();
+    auto* wrapper = variant.get<IteratorStateVariant>();
+    OP_REQUIRES(ctx, wrapper != nullptr,
+                errors::InvalidArgument(
+                    "DeserializeIteratorOp: Unable to parse variant tensor."));
+    OP_REQUIRES_OK(ctx, wrapper->status());
+    OP_REQUIRES_OK(ctx, iterator_resource->Restore(ctx, wrapper->get()));
+  }
+};
+
 REGISTER_KERNEL_BUILDER(Name("Iterator").Device(DEVICE_CPU), IteratorHandleOp);
 REGISTER_KERNEL_BUILDER(Name("MakeIterator").Device(DEVICE_CPU),
                         MakeIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
                         ToSingleElementOp);
-REGISTER_KERNEL_BUILDER(Name("SaveIterator").Device(DEVICE_CPU),
-                        SaveIteratorOp);
-REGISTER_KERNEL_BUILDER(Name("RestoreIterator").Device(DEVICE_CPU),
-                        RestoreIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
                         OneShotIteratorOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
@@ -661,6 +810,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
                         IteratorToStringHandleOp);
 REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandle").Device(DEVICE_CPU),
                         IteratorFromStringHandleOp);
+REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
+                        SerializeIteratorOp);
+REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
+                        DeserializeIteratorOp);
 
 }  // namespace
 
diff --git a/tensorflow/core/kernels/parse_tensor_op.cc b/tensorflow/core/kernels/parse_tensor_op.cc
index ab91a6ef677..6b599612ad7 100644
--- a/tensorflow/core/kernels/parse_tensor_op.cc
+++ b/tensorflow/core/kernels/parse_tensor_op.cc
@@ -92,6 +92,7 @@ class SerializeTensorOp : public OpKernel {
       Name("SerializeTensor").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       SerializeTensorOp<T>);
 TF_CALL_ALL_TYPES(REGISTER)
+TF_CALL_variant(REGISTER)
 #undef REGISTER
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/range_dataset_op.cc b/tensorflow/core/kernels/range_dataset_op.cc
index a57c21a590b..7adfcc4f8d2 100644
--- a/tensorflow/core/kernels/range_dataset_op.cc
+++ b/tensorflow/core/kernels/range_dataset_op.cc
@@ -112,19 +112,16 @@ class RangeDatasetOp : public DatasetOpKernel {
       }
 
      protected:
-      Status SaveInternal(OpKernelContext* ctx,
-                          IteratorBundleWriter* writer) override {
+      Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(
-            writer->WriteScalar<int64>(full_name("next"), next_));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("next"), next_));
         return Status::OK();
       }
 
       Status RestoreInternal(OpKernelContext* ctx,
-                             IteratorBundleReader* reader) override {
+                             IteratorStateReader* reader) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(
-            reader->ReadScalar<int64>(full_name("next"), &next_));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("next"), &next_));
         return Status::OK();
       }
 
diff --git a/tensorflow/core/kernels/reader_dataset_ops.cc b/tensorflow/core/kernels/reader_dataset_ops.cc
index b455c28e07c..fb88c55f73b 100644
--- a/tensorflow/core/kernels/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/reader_dataset_ops.cc
@@ -356,31 +356,30 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
       }
 
      protected:
-      Status SaveInternal(OpKernelContext* ctx,
-                          IteratorBundleWriter* writer) override {
+      Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(
-            full_name("current_file_index"), current_file_index_));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+                                               current_file_index_));
 
         // `input_buffer_` is empty if
         // 1. GetNext has not been called even once.
         // 2. All files have been read and iterator has been exhausted.
         int64 current_pos = input_buffer_ ? input_buffer_->Tell() : -1;
         TF_RETURN_IF_ERROR(
-            writer->WriteScalar<int64>(full_name("current_pos"), current_pos));
+            writer->WriteScalar(full_name("current_pos"), current_pos));
         return Status::OK();
       }
 
       Status RestoreInternal(OpKernelContext* ctx,
-                             IteratorBundleReader* reader) override {
+                             IteratorStateReader* reader) override {
         mutex_lock l(mu_);
         int64 current_file_index;
-        TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(
-            full_name("current_file_index"), &current_file_index));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+                                              &current_file_index));
         current_file_index_ = size_t(current_file_index);
         int64 current_pos;
         TF_RETURN_IF_ERROR(
-            reader->ReadScalar<int64>(full_name("current_pos"), &current_pos));
+            reader->ReadScalar(full_name("current_pos"), &current_pos));
 
         // Seek to current_pos.
         input_buffer_.reset();
diff --git a/tensorflow/core/kernels/repeat_dataset_op.cc b/tensorflow/core/kernels/repeat_dataset_op.cc
index 5d836927d22..9813e99a70b 100644
--- a/tensorflow/core/kernels/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/repeat_dataset_op.cc
@@ -124,19 +124,18 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
       }
 
      protected:
-      Status SaveInternal(OpKernelContext* ctx,
-                          IteratorBundleWriter* writer) override {
+      Status SaveInternal(IteratorStateWriter* writer) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(writer->WriteScalar<int64>(full_name("i"), i_));
-        TF_RETURN_IF_ERROR(writer->SaveParent(ctx, input_impl_));
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+        TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
         return Status::OK();
       }
 
       Status RestoreInternal(OpKernelContext* ctx,
-                             IteratorBundleReader* reader) override {
+                             IteratorStateReader* reader) override {
         mutex_lock l(mu_);
-        TF_RETURN_IF_ERROR(reader->ReadScalar<int64>(full_name("i"), &i_));
-        TF_RETURN_IF_ERROR(reader->RestoreParent(ctx, input_impl_));
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+        TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
         return Status::OK();
       }
 
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 6772024263d..c5ceb14a096 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -28753,18 +28753,6 @@ op {
   }
   is_stateful: true
 }
-op {
-  name: "RestoreIterator"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
 op {
   name: "RestoreSlice"
   input_arg {
@@ -29548,18 +29536,6 @@ op {
   }
   is_stateful: true
 }
-op {
-  name: "SaveIterator"
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  input_arg {
-    name: "path"
-    type: DT_STRING
-  }
-  is_stateful: true
-}
 op {
   name: "SaveSlices"
   input_arg {
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 566049179a1..8b77e3f9f0b 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -598,24 +598,6 @@ This operation may be executed multiple times. Each execution will reset the
 iterator in `iterator` to the first element of `dataset`.
 )doc");
 
-REGISTER_OP("SaveIterator")
-    .Input("iterator: resource")
-    .Input("path: string")
-    .SetShapeFn(shape_inference::NoOutputs)
-    .Doc(R"doc(
-Saves the state of the `iterator` at `path`.
-
-This state can be restored using "RestoreIterator".
-)doc");
-
-REGISTER_OP("RestoreIterator")
-    .Input("iterator: resource")
-    .Input("path: string")
-    .SetShapeFn(shape_inference::NoOutputs)
-    .Doc(R"doc(
-Restores the state of the `iterator` from the checkpoint saved at `path` using "SaveIterator".
-)doc");
-
 REGISTER_OP("OneShotIterator")
     .Output("handle: resource")
     .Attr("dataset_factory: func")
@@ -737,4 +719,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
   element produced by the resulting iterator.
 )doc");
 
+REGISTER_OP("SerializeIterator")
+    .Input("resource_handle: resource")
+    .Output("serialized: variant")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+  resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+    .Input("resource_handle: resource")
+    .Input("serialized: variant")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+  resource.
+)doc");
+
 }  // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0e36c3498ae..b02bae95fd5 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2886,7 +2886,9 @@ tf_py_test(
         "//tensorflow/python:dataset_ops_gen",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:platform",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:variables",
@@ -2907,7 +2909,9 @@ tf_py_test(
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:lib",
+        "//tensorflow/python:parsing_ops",
         "//tensorflow/python:tensor_shape",
         "//tensorflow/python:util",
         "//tensorflow/python/data/ops:iterator_ops",
@@ -3022,6 +3026,7 @@ tf_py_test(
         "//tensorflow/python:function",
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:gradients",
+        "//tensorflow/python:io_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:script_ops",
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index b5ec9f7db03..2128ef4ae17 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
 
   def testIncorrectIteratorRestore(self):
 
-    def _iterator_checkpoint_prefix():
+    def _path():
       return os.path.join(self.get_temp_dir(), "iterator")
 
+    def _save_op(iterator_resource):
+      iterator_state_variant = gen_dataset_ops.serialize_iterator(
+          iterator_resource)
+      save_op = io_ops.write_file(
+          _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+      return save_op
+
+    def _restore_op(iterator_resource):
+      iterator_state_variant = parsing_ops.parse_tensor(
+          io_ops.read_file(_path()), dtypes.variant)
+      restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                        iterator_state_variant)
+      return restore_op
+
     def _build_range_dataset_graph():
       start = 1
       stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = _iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     def _build_reader_dataset_graph():
       filenames = ["test"]  # Does not exist but we don't care in this test.
-      path = _iterator_checkpoint_prefix()
       iterator = readers.FixedLengthRecordDataset(
           filenames, 1, 0, 0).make_initializable_iterator()
       init_op = iterator.initializer
       get_next_op = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = _save_op(iterator._iterator_resource)
+      restore_op = _restore_op(iterator._iterator_resource)
       return init_op, get_next_op, save_op, restore_op
 
     # Saving iterator for RangeDataset graph.
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 82919671556..0c530522b83 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -27,6 +27,8 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase):
   def _iterator_checkpoint_prefix(self):
     return os.path.join(self.get_temp_dir(), "iterator")
 
+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_prefix(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def testSaveRestore(self):
 
     def _build_graph(start, stop):
@@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase):
 
   def testRestoreWithoutBuildingDatasetGraph(self):
 
-    def _build_graph(start, stop, num_epochs, path):
+    def _build_graph(start, stop, num_epochs):
       dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase):
     num_epochs = 5
     break_point = 5
     break_epoch = 3
-    path = self._iterator_checkpoint_prefix()
     with ops.Graph().as_default() as g:
-      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
-                                                   path)
+      init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
       with self.test_session(graph=g) as sess:
         sess.run(variables.global_variables_initializer())
         sess.run(init_op)
@@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase):
       output_shapes = tensor_shape.scalar()
       iterator = iterator_ops.Iterator.from_structure(output_types,
                                                       output_shapes)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      restore_op = self._restore_op(iterator._iterator_resource)
       get_next = iterator.get_next()
       with self.test_session(graph=g) as sess:
         sess.run(restore_op)
@@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase):
       iterator = dataset.make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     # Saving and restoring in different sessions.
@@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase):
                                            stop).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     start = 2
@@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     start = 2
@@ -447,10 +450,8 @@ class RangeDatasetTest(test.TestCase):
           start, stop).repeat(num_epochs).make_initializable_iterator()
       init_op = iterator.initializer
       get_next = iterator.get_next()
-      path = self._iterator_checkpoint_prefix()
-      save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-      restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                    path)
+      save_op = self._save_op(iterator._iterator_resource)
+      restore_op = self._restore_op(iterator._iterator_resource)
       return init_op, get_next, save_op, restore_op
 
     start = 2
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index 38420328efe..c8e7333b4b9 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
 from tensorflow.python.platform import test
 from tensorflow.python.util import compat
 
@@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
   def _iterator_checkpoint_path(self):
     return os.path.join(self.get_temp_dir(), "iterator")
 
+  def _save_op(self, iterator_resource):
+    iterator_state_variant = gen_dataset_ops.serialize_iterator(
+        iterator_resource)
+    save_op = io_ops.write_file(
+        self._iterator_checkpoint_path(),
+        parsing_ops.serialize_tensor(iterator_state_variant))
+    return save_op
+
+  def _restore_op(self, iterator_resource):
+    iterator_state_variant = parsing_ops.parse_tensor(
+        io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+    restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+                                                      iterator_state_variant)
+    return restore_op
+
   def _build_iterator_graph(self, num_epochs):
     filenames = self._createFiles()
-    path = self._iterator_checkpoint_path()
     dataset = (readers.FixedLengthRecordDataset(
         filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
                .repeat(num_epochs))
     iterator = dataset.make_initializable_iterator()
     init_op = iterator.initializer
     get_next_op = iterator.get_next()
-    save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
-    restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
-                                                  path)
+    save_op = self._save_op(iterator._iterator_resource)
+    restore_op = self._restore_op(iterator._iterator_resource)
     return init_op, get_next_op, save_op, restore_op
 
   def _restore_iterator(self):
@@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
     output_shapes = tensor_shape.scalar()
     iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
     get_next = iterator.get_next()
-    restore_op = gen_dataset_ops.restore_iterator(
-        iterator._iterator_resource, self._iterator_checkpoint_path())
+    restore_op = self._restore_op(iterator._iterator_resource)
     return restore_op, get_next
 
   def testSaveRestore(self):