Consolidating experimental/ops/iterator_ops._Saveable and ops/iterator_ops._IteratorSaveable.

PiperOrigin-RevId: 286256630
Change-Id: I91d97553a64656c575ae5a5771a9145978065910
This commit is contained in:
Rohan Jain 2019-12-18 13:43:15 -08:00 committed by TensorFlower Gardener
parent a519b79c05
commit 3378750f4c

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
@ -69,24 +68,8 @@ def make_saveable_from_iterator(iterator):
Note: Not all iterators support checkpointing yet. Attempting to save the
state of an unsupported iterator will throw an error.
"""
return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""
def __init__(self, iterator_resource):
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
specs = [
saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
iterator_resource.name + "-state")
]
super(_Saveable, self).__init__(iterator_resource, specs,
iterator_resource.name)
def restore(self, restored_tensors, unused_restored_shapes):
with ops.colocate_with(self.op):
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
return iterator_ops._IteratorSaveable(iterator._iterator_resource, # pylint: disable=protected-access
iterator._iterator_resource.name) # pylint: disable=protected-access
@tf_export("data.experimental.CheckpointInputPipelineHook")
@ -189,7 +172,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
if (self._checkpoint_saver_hook._saver is None and
self._checkpoint_saver_hook._scaffold is None):
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
saveables = [_Saveable(i) for i in iterators]
saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
self._latest_filename)
# pylint: enable=protected-access