Consolidating experimental/ops/iterator_ops._Saveable and ops/iterator_ops._IteratorSaveable.
PiperOrigin-RevId: 286256630 Change-Id: I91d97553a64656c575ae5a5771a9145978065910
This commit is contained in:
parent
a519b79c05
commit
3378750f4c
@ -19,7 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import gen_dataset_ops
|
|
||||||
from tensorflow.python.training import basic_session_run_hooks
|
from tensorflow.python.training import basic_session_run_hooks
|
||||||
from tensorflow.python.training import checkpoint_management
|
from tensorflow.python.training import checkpoint_management
|
||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
@ -69,24 +68,8 @@ def make_saveable_from_iterator(iterator):
|
|||||||
Note: Not all iterators support checkpointing yet. Attempting to save the
|
Note: Not all iterators support checkpointing yet. Attempting to save the
|
||||||
state of an unsupported iterator will throw an error.
|
state of an unsupported iterator will throw an error.
|
||||||
"""
|
"""
|
||||||
return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
|
return iterator_ops._IteratorSaveable(iterator._iterator_resource, # pylint: disable=protected-access
|
||||||
|
iterator._iterator_resource.name) # pylint: disable=protected-access
|
||||||
|
|
||||||
class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
|
||||||
"""SaveableObject for saving/restoring iterator state."""
|
|
||||||
|
|
||||||
def __init__(self, iterator_resource):
|
|
||||||
serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
|
|
||||||
specs = [
|
|
||||||
saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
|
|
||||||
iterator_resource.name + "-state")
|
|
||||||
]
|
|
||||||
super(_Saveable, self).__init__(iterator_resource, specs,
|
|
||||||
iterator_resource.name)
|
|
||||||
|
|
||||||
def restore(self, restored_tensors, unused_restored_shapes):
|
|
||||||
with ops.colocate_with(self.op):
|
|
||||||
return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export("data.experimental.CheckpointInputPipelineHook")
|
@tf_export("data.experimental.CheckpointInputPipelineHook")
|
||||||
@ -189,7 +172,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
|
|||||||
if (self._checkpoint_saver_hook._saver is None and
|
if (self._checkpoint_saver_hook._saver is None and
|
||||||
self._checkpoint_saver_hook._scaffold is None):
|
self._checkpoint_saver_hook._scaffold is None):
|
||||||
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
|
iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
|
||||||
saveables = [_Saveable(i) for i in iterators]
|
saveables = [iterator_ops._IteratorSaveable(i, i.name) for i in iterators]
|
||||||
self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
|
self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
|
||||||
self._latest_filename)
|
self._latest_filename)
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
Loading…
Reference in New Issue
Block a user