Builds a static tf.train.Saver from a checkpointable object graph
Moves around some SaveableObjects to support a freeze method for python state saveables, and makes sure that the object graph proto is included in the frozen Saver. This should be useful for embedding in SavedModels, where variables can be updated and the resulting checkpoints (saved from the SaverDef in the SavedModel) will still support Keras-style object-based restoration into Python programs (with better eager support and less fragile variable matching). This is also a step toward Estimators saving object-based checkpoints. PiperOrigin-RevId: 212017296
This commit is contained in:
parent
bb096a0735
commit
ca92311cbd
@ -17,11 +17,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import functools
|
||||
import json
|
||||
import weakref
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
|
||||
return self._checkpoint_position
|
||||
|
||||
|
||||
class PythonStringStateSaveable(saveable_object.SaveableObject):
|
||||
class NoRestoreSaveable(saveable_object.SaveableObject):
|
||||
"""Embeds a tensor in a checkpoint with no restore ops."""
|
||||
|
||||
def __init__(self, tensor, name, dtype=None):
|
||||
spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
|
||||
super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class PythonStateSaveable(saveable_object.SaveableObject):
|
||||
"""An interface for saving/restoring volatile Python state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def feed_dict_additions(self):
|
||||
"""When running a graph, indicates fresh state to feed.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping `Tensor`s to current Python state.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def freeze(self):
|
||||
"""Create a new `SaveableObject` which freezes current state as a constant.
|
||||
|
||||
Used when executing eagerly to embed the current state as a constant, or
|
||||
when creating a static tf.train.Saver with the frozen current Python state.
|
||||
|
||||
Returns:
|
||||
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
|
||||
no Python state associated with it).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PythonStringStateSaveable(PythonStateSaveable):
|
||||
"""Saves Python state in a checkpoint."""
|
||||
|
||||
def __init__(self, name, state_callback, restore_callback=None):
|
||||
@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
|
||||
restore_callback: A function taking a Python string, used to restore
|
||||
state. Optional; defaults to doing nothing.
|
||||
"""
|
||||
self._state_callback = state_callback
|
||||
self._restore_callback = restore_callback
|
||||
if context.executing_eagerly():
|
||||
self._save_string = (
|
||||
lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
|
||||
else:
|
||||
with ops.device("/cpu:0"):
|
||||
self._save_string = constant_op.constant("", dtype=dtypes.string)
|
||||
self.feed_dict_additions = (
|
||||
lambda: {self._save_string: state_callback()})
|
||||
spec = saveable_object.SaveSpec(
|
||||
self._save_string, "", name, dtype=dtypes.string)
|
||||
super(PythonStringStateSaveable, self).__init__(
|
||||
self._save_string, [spec], name)
|
||||
|
||||
def feed_dict_additions(self):
|
||||
"""When running a graph, indicates fresh state to feed."""
|
||||
return {self._save_string: self._state_callback()}
|
||||
|
||||
def freeze(self):
|
||||
"""Create a frozen `SaveableObject` which saves the current state."""
|
||||
return NoRestoreSaveable(
|
||||
tensor=self._state_callback,
|
||||
dtype=dtypes.string,
|
||||
name=self.name)
|
||||
|
||||
def python_restore(self, restored_strings):
|
||||
"""Called to restore Python state."""
|
||||
if self._restore_callback:
|
||||
@ -309,7 +357,7 @@ class _CheckpointPosition(object):
|
||||
if self._checkpoint.saveable_object_cache is not None:
|
||||
self._checkpoint.saveable_object_cache.setdefault(
|
||||
self.checkpointable, {})[serialized_tensor.name] = [saveable]
|
||||
if isinstance(saveable, PythonStringStateSaveable):
|
||||
if isinstance(saveable, PythonStateSaveable):
|
||||
python_saveables.append(saveable)
|
||||
else:
|
||||
named_saveables[serialized_tensor.checkpoint_key] = saveable
|
||||
@ -819,7 +867,7 @@ class CheckpointableBase(object):
|
||||
def _state_callback():
|
||||
dereferenced_self = weak_self()
|
||||
if dereferenced_self:
|
||||
return json.dumps(self,
|
||||
return json.dumps(dereferenced_self,
|
||||
default=serialization.get_json_type,
|
||||
sort_keys=True).encode("utf8")
|
||||
else:
|
||||
|
@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_io_ops as io_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
@ -557,7 +556,14 @@ def _serialize_checkpointables(
|
||||
object_graph_proto = (
|
||||
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
|
||||
named_saveables = []
|
||||
feed_additions = {}
|
||||
if saveables_cache is None:
|
||||
# No SaveableObject caching. Either we're executing eagerly, or building a
|
||||
# static save which is specialized to the current Python state.
|
||||
feed_additions = None
|
||||
else:
|
||||
# If we are caching SaveableObjects, we need to build up a feed_dict with
|
||||
# functions computing volatile Python state to be saved with the checkpoint.
|
||||
feed_additions = {}
|
||||
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
|
||||
assert node_ids[checkpointable] == checkpoint_id
|
||||
object_proto = object_graph_proto.nodes.add()
|
||||
@ -616,18 +622,25 @@ def _serialize_checkpointables(
|
||||
for saveable in saveables:
|
||||
if hasattr(saveable, "full_name"):
|
||||
attribute.full_name = saveable.full_name
|
||||
saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
|
||||
if saveable_feed_dict_fn is not None:
|
||||
saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
|
||||
for new_feed_key in saveable_feed_dict.keys():
|
||||
if new_feed_key in feed_additions:
|
||||
raise AssertionError(
|
||||
("The object %s tried to feed a value for the Tensor %s "
|
||||
"when saving, but another object is already feeding a "
|
||||
"value.")
|
||||
% (checkpointable, new_feed_key))
|
||||
feed_additions.update(saveable_feed_dict)
|
||||
named_saveables.extend(saveables)
|
||||
if isinstance(saveable, base.PythonStateSaveable):
|
||||
if feed_additions is None:
|
||||
assert saveables_cache is None
|
||||
# If we're not caching saveables, then we're either executing
|
||||
# eagerly or building a static save/restore (e.g. for a
|
||||
# SavedModel). In either case, we should embed the current Python
|
||||
# state in the graph rather than relying on a feed dict.
|
||||
saveable = saveable.freeze()
|
||||
else:
|
||||
saveable_feed_dict = saveable.feed_dict_additions()
|
||||
for new_feed_key in saveable_feed_dict.keys():
|
||||
if new_feed_key in feed_additions:
|
||||
raise AssertionError(
|
||||
("The object %s tried to feed a value for the Tensor %s "
|
||||
"when saving, but another object is already feeding a "
|
||||
"value.")
|
||||
% (checkpointable, new_feed_key))
|
||||
feed_additions.update(saveable_feed_dict)
|
||||
named_saveables.append(saveable)
|
||||
|
||||
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
|
||||
child_proto = object_proto.children.add()
|
||||
@ -827,16 +840,6 @@ def capture_dependencies(template):
|
||||
yield
|
||||
|
||||
|
||||
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
||||
|
||||
def __init__(self, tensor, name):
|
||||
spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
|
||||
super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
return control_flow_ops.no_op()
|
||||
|
||||
|
||||
class _LoadStatus(object):
|
||||
"""Abstract base for load status callbacks."""
|
||||
|
||||
@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
|
||||
else:
|
||||
return self._root_checkpointable_ref
|
||||
|
||||
def _gather_saveables(
|
||||
self, object_graph_tensor=None, saveable_object_cache=None):
|
||||
"""Wraps _serialize_object_graph to include the object graph proto."""
|
||||
assert ((object_graph_tensor is None and saveable_object_cache is None)
|
||||
or (object_graph_tensor is not None
|
||||
and saveable_object_cache is not None))
|
||||
(named_saveable_objects, graph_proto,
|
||||
feed_additions) = _serialize_object_graph(
|
||||
self._root_checkpointable,
|
||||
saveables_cache=saveable_object_cache)
|
||||
if object_graph_tensor is None:
|
||||
with ops.device("/cpu:0"):
|
||||
object_graph_tensor = constant_op.constant(
|
||||
graph_proto.SerializeToString(), dtype=dtypes.string)
|
||||
else:
|
||||
feed_additions.update(
|
||||
{object_graph_tensor: graph_proto.SerializeToString()})
|
||||
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
|
||||
named_saveable_objects.append(
|
||||
base.NoRestoreSaveable(
|
||||
tensor=object_graph_tensor,
|
||||
name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||
return named_saveable_objects, graph_proto, feed_additions
|
||||
|
||||
def freeze(self):
|
||||
"""Creates a `tf.train.Saver` with the current object graph frozen."""
|
||||
named_saveable_objects, _, _ = self._gather_saveables(
|
||||
object_graph_tensor=None, saveable_object_cache=None)
|
||||
return saver_lib.Saver(
|
||||
var_list=named_saveable_objects, max_to_keep=None)
|
||||
|
||||
def _prepare_save(self,
|
||||
object_graph_tensor=None,
|
||||
saveable_object_cache=None):
|
||||
"""Create or retrieve save ops.
|
||||
|
||||
When graph building, `saveable_object_cache` will typically be non-`None`,
|
||||
meaning that existing `SaveableObject`s are re-used across calls to
|
||||
`_prepare_save` even if the object graph has grown. This avoids
|
||||
unnecessarily re-creating save ops.
|
||||
|
||||
Args:
|
||||
object_graph_tensor: A `Tensor` to which the current object graph will be
|
||||
fed.
|
||||
saveable_object_cache: A dictionary; if specified, used to cache
|
||||
`SaveableObject`s.
|
||||
|
||||
Returns:
|
||||
A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
|
||||
to feed when running save ops. The feed dict contains the current object
|
||||
graph and any Python state to be saved in the checkpoint.
|
||||
"""
|
||||
(named_saveable_objects, graph_proto,
|
||||
feed_additions) = self._gather_saveables(
|
||||
object_graph_tensor=object_graph_tensor,
|
||||
saveable_object_cache=saveable_object_cache)
|
||||
if (self._last_save_object_graph != graph_proto
|
||||
# When executing eagerly, we need to re-create SaveableObjects each time
|
||||
# save() is called so they pick up new Tensors passed to their
|
||||
# constructors. That means the Saver needs to be copied with a new
|
||||
# var_list.
|
||||
or context.executing_eagerly()):
|
||||
if self._last_save_object_graph is not None:
|
||||
self._last_save_saver = _copy_saver_with_new_var_list(
|
||||
old_saver=self._last_save_saver,
|
||||
new_var_list=named_saveable_objects)
|
||||
else:
|
||||
self._last_save_saver = saver_lib.Saver(
|
||||
var_list=named_saveable_objects, max_to_keep=None)
|
||||
self._last_save_object_graph = graph_proto
|
||||
return self._last_save_saver, feed_additions
|
||||
|
||||
def save(self, file_prefix, checkpoint_number=None, session=None):
|
||||
"""Save a training checkpoint.
|
||||
|
||||
@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
|
||||
Returns:
|
||||
The full path to the checkpoint.
|
||||
"""
|
||||
named_variables, graph_proto, feed_additions = _serialize_object_graph(
|
||||
self._root_checkpointable,
|
||||
saveables_cache=self._saveable_object_cache)
|
||||
if not context.executing_eagerly():
|
||||
if session is None:
|
||||
session = ops.get_default_session()
|
||||
feed_additions = {}
|
||||
graph_building = not context.executing_eagerly()
|
||||
if graph_building:
|
||||
if self._object_graph_feed_tensor is None:
|
||||
with ops.device("/cpu:0"):
|
||||
self._object_graph_feed_tensor = constant_op.constant(
|
||||
"", dtype=dtypes.string)
|
||||
object_graph_tensor = self._object_graph_feed_tensor
|
||||
feed_additions.update(
|
||||
{object_graph_tensor: graph_proto.SerializeToString()})
|
||||
else:
|
||||
object_graph_tensor = None
|
||||
|
||||
saver, new_feed_additions = self._prepare_save(
|
||||
object_graph_tensor=object_graph_tensor,
|
||||
saveable_object_cache=self._saveable_object_cache)
|
||||
if new_feed_additions:
|
||||
feed_additions.update(new_feed_additions)
|
||||
if not graph_building:
|
||||
session = None
|
||||
with ops.device("/cpu:0"):
|
||||
object_graph_tensor = constant_op.constant(
|
||||
graph_proto.SerializeToString(), dtype=dtypes.string)
|
||||
assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
|
||||
named_variables.append(
|
||||
_NoRestoreSaveable(
|
||||
tensor=object_graph_tensor,
|
||||
name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||
if (self._last_save_object_graph != graph_proto
|
||||
# When executing eagerly, we need to re-create SaveableObjects each time
|
||||
# save() is called so they pick up new Tensors passed to their
|
||||
# constructors. That means the Saver needs to be copied with a new
|
||||
# var_list.
|
||||
or context.executing_eagerly()):
|
||||
if self._last_save_object_graph is not None:
|
||||
self._last_save_saver = _copy_saver_with_new_var_list(
|
||||
old_saver=self._last_save_saver, new_var_list=named_variables)
|
||||
else:
|
||||
self._last_save_saver = saver_lib.Saver(
|
||||
var_list=named_variables, max_to_keep=None)
|
||||
self._last_save_object_graph = graph_proto
|
||||
elif session is None:
|
||||
session = ops.get_default_session()
|
||||
|
||||
with ops.device("/cpu:0"):
|
||||
save_path = self._last_save_saver.save(
|
||||
save_path = saver.save(
|
||||
sess=_SessionWithFeedDictAdditions(
|
||||
session=session, feed_additions=feed_additions),
|
||||
save_path=file_prefix,
|
||||
@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
|
||||
return load_status
|
||||
|
||||
|
||||
def frozen_saver(root_checkpointable):
|
||||
"""Creates a static `tf.train.Saver` from a checkpointable object.
|
||||
|
||||
The returned `Saver` saves object-based checkpoints, but these checkpoints
|
||||
will no longer reflect structural changes to the object graph, only changes to
|
||||
the values of `Variable`s added as dependencies of the root object before
|
||||
`freeze` was called.
|
||||
|
||||
`restore` works on the returned `Saver`, but requires that the object graph of
|
||||
the checkpoint being loaded exactly matches the object graph when `freeze` was
|
||||
called. This is in contrast the object-based restore performed by
|
||||
`tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
|
||||
object graph and the current Python object graph.
|
||||
|
||||
Args:
|
||||
root_checkpointable: A checkpointable object to save.
|
||||
|
||||
Returns:
|
||||
A `tf.train.Saver` which saves object-based checkpoints for the object graph
|
||||
frozen at the time `frozen_saver` was called.
|
||||
"""
|
||||
return CheckpointableSaver(root_checkpointable).freeze()
|
||||
|
||||
|
||||
@tf_export("train.Checkpoint")
|
||||
class Checkpoint(tracking.Checkpointable):
|
||||
"""Groups checkpointable objects, saving and restoring them.
|
||||
|
@ -559,6 +559,46 @@ class CheckpointingTests(test.TestCase):
|
||||
self.assertEqual(training_continuation + 1,
|
||||
self.evaluate(root.save_counter))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testFreezing(self):
|
||||
with self.cached_session(use_gpu=True) as session:
|
||||
# Save an object-based checkpoint using a frozen saver
|
||||
directory = self.get_temp_dir()
|
||||
prefix = os.path.join(directory, "ckpt")
|
||||
v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
|
||||
checkpoint = checkpointable_utils.Checkpoint(v=v)
|
||||
self.evaluate(v.assign(3))
|
||||
# Create the save counter so assert_consumed doesn't complain about it not
|
||||
# existing in the checkpoint on restore.
|
||||
self.evaluate(checkpoint.save_counter.assign(12))
|
||||
saver = checkpointable_utils.frozen_saver(checkpoint)
|
||||
save_path = saver.save(session, prefix)
|
||||
self.evaluate(v.assign(10))
|
||||
# Use the frozen saver to restore the same object graph
|
||||
saver.restore(session, save_path)
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
|
||||
# Restore using another frozen saver on an identical object graph
|
||||
del v, checkpoint, saver
|
||||
v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
|
||||
checkpoint = checkpointable_utils.Checkpoint(v=v)
|
||||
saver = checkpointable_utils.frozen_saver(checkpoint)
|
||||
saver.restore(session, save_path)
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
|
||||
# Restore as an object-based checkpoint
|
||||
del v, checkpoint, saver
|
||||
checkpoint = checkpointable_utils.Checkpoint()
|
||||
status = checkpoint.restore(save_path)
|
||||
v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
|
||||
if context.executing_eagerly():
|
||||
self.assertEqual(12, self.evaluate(checkpoint.save_counter))
|
||||
self.assertEqual(0, self.evaluate(v))
|
||||
checkpoint.v = v
|
||||
status.assert_consumed().run_restore_ops()
|
||||
self.assertEqual(3, self.evaluate(v))
|
||||
self.assertEqual(12, self.evaluate(checkpoint.save_counter))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testCustomNumbering(self):
|
||||
directory = self.get_temp_dir()
|
||||
|
Loading…
Reference in New Issue
Block a user