Consolidating experimental/ops/iterator_ops._Saveable and ops/iterator_ops._IteratorSaveable.

PiperOrigin-RevId: 286256630
Change-Id: I91d97553a64656c575ae5a5771a9145978065910
This commit is contained in:
Rohan Jain 2019-12-18 13:43:15 -08:00 committed by TensorFlower Gardener
parent a519b79c05
commit 3378750f4c

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
@ -69,24 +68,8 @@ def make_saveable_from_iterator(iterator):
Note: Not all iterators support checkpointing yet. Attempting to save the Note: Not all iterators support checkpointing yet. Attempting to save the
state of an unsupported iterator will throw an error. state of an unsupported iterator will throw an error.
""" """
return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access return iterator_ops._IteratorSaveable(iterator._iterator_resource, # pylint: disable=protected-access
iterator._iterator_resource.name) # pylint: disable=protected-access
class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
"""SaveableObject for saving/restoring iterator state."""
def __init__(self, iterator_resource):
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
specs = [
saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
iterator_resource.name + "-state")
]
super(_Saveable, self).__init__(iterator_resource, specs,
iterator_resource.name)
def restore(self, restored_tensors, unused_restored_shapes):
with ops.colocate_with(self.op):
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
@tf_export("data.experimental.CheckpointInputPipelineHook") @tf_export("data.experimental.CheckpointInputPipelineHook")
@ -189,7 +172,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
if (self._checkpoint_saver_hook._saver is None and if (self._checkpoint_saver_hook._saver is None and
self._checkpoint_saver_hook._scaffold is None): self._checkpoint_saver_hook._scaffold is None):
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS) iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
saveables = [_Saveable(i) for i in iterators] saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
self._checkpoint_saver_hook._saver = _CustomSaver(saveables, self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
self._latest_filename) self._latest_filename)
# pylint: enable=protected-access # pylint: enable=protected-access