Consolidating experimental/ops/iterator_ops._Saveable and ops/iterator_ops._IteratorSaveable.
PiperOrigin-RevId: 286256630 Change-Id: I91d97553a64656c575ae5a5771a9145978065910
This commit is contained in:
parent
a519b79c05
commit
3378750f4c
@ -19,7 +19,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
@ -69,24 +68,8 @@ def make_saveable_from_iterator(iterator):
|
||||
Note: Not all iterators support checkpointing yet. Attempting to save the
|
||||
state of an unsupported iterator will throw an error.
|
||||
"""
|
||||
return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
|
||||
|
||||
|
||||
class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
||||
"""SaveableObject for saving/restoring iterator state."""
|
||||
|
||||
def __init__(self, iterator_resource):
|
||||
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
|
||||
specs = [
|
||||
saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
|
||||
iterator_resource.name + "-state")
|
||||
]
|
||||
super(_Saveable, self).__init__(iterator_resource, specs,
|
||||
iterator_resource.name)
|
||||
|
||||
def restore(self, restored_tensors, unused_restored_shapes):
|
||||
with ops.colocate_with(self.op):
|
||||
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
|
||||
return iterator_ops._IteratorSaveable(iterator._iterator_resource, # pylint: disable=protected-access
|
||||
iterator._iterator_resource.name) # pylint: disable=protected-access
|
||||
|
||||
|
||||
@tf_export("data.experimental.CheckpointInputPipelineHook")
|
||||
@ -189,7 +172,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
|
||||
if (self._checkpoint_saver_hook._saver is None and
|
||||
self._checkpoint_saver_hook._scaffold is None):
|
||||
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
|
||||
saveables = [_Saveable(i) for i in iterators]
|
||||
saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
|
||||
self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
|
||||
self._latest_filename)
|
||||
# pylint: enable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user