Builds a static tf.train.Saver from a checkpointable object graph

Moves around some SaveableObjects to support a freeze method for Python state saveables, and makes sure that the object graph proto is included in the frozen Saver.

This should be useful for embedding in SavedModels, where variables can be updated and the resulting checkpoints (saved from the SaverDef in the SavedModel) will still support Keras-style object-based restoration into Python programs (with better eager support and less fragile variable matching). This is also a step toward Estimators saving object-based checkpoints.

PiperOrigin-RevId: 212017296
Allen Lavoie 2018-09-07 12:24:30 -07:00 committed by TensorFlower Gardener
parent bb096a0735
commit ca92311cbd
3 changed files with 235 additions and 63 deletions
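A minimal usage sketch of the frozen saver described in the commit message (not part of the commit): it builds a small object graph, freezes a static tf.train.Saver for it, and extracts the SaverDef that a SavedModel could embed. The internal import path tensorflow.python.training.checkpointable.util, the variable, and the checkpoint prefix are assumptions for illustration; only frozen_saver, tf.train.Checkpoint, and standard tf.train.Saver methods are relied on.

# Sketch only: freeze a static Saver for an object graph and grab its SaverDef.
# The module path and checkpoint prefix below are assumptions, not from the diff.
import tensorflow as tf
from tensorflow.python.training.checkpointable import util as checkpointable_utils

graph = tf.Graph()
with graph.as_default():
  v = tf.get_variable("v", initializer=3.0, use_resource=True)
  root = tf.train.Checkpoint(v=v)
  # Static tf.train.Saver specialized to the object graph as it exists now;
  # the object graph proto is written into the checkpoint alongside v.
  saver = checkpointable_utils.frozen_saver(root)
  saver_def = saver.as_saver_def()  # could be embedded in a MetaGraphDef / SavedModel

  with tf.Session() as session:
    session.run(v.initializer)
    save_path = saver.save(session, "/tmp/object_based_ckpt")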

View File

@@ -17,11 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import functools
import json
import weakref
import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
return self._checkpoint_position
class PythonStringStateSaveable(saveable_object.SaveableObject):
class NoRestoreSaveable(saveable_object.SaveableObject):
"""Embeds a tensor in a checkpoint with no restore ops."""
def __init__(self, tensor, name, dtype=None):
spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
def restore(self, restored_tensors, restored_shapes):
return control_flow_ops.no_op()
@six.add_metaclass(abc.ABCMeta)
class PythonStateSaveable(saveable_object.SaveableObject):
"""An interface for saving/restoring volatile Python state."""
@abc.abstractmethod
def feed_dict_additions(self):
"""When running a graph, indicates fresh state to feed.
Returns:
A dictionary mapping `Tensor`s to current Python state.
"""
pass
@abc.abstractmethod
def freeze(self):
"""Create a new `SaveableObject` which freezes current state as a constant.
Used when executing eagerly to embed the current state as a constant, or
when creating a static tf.train.Saver with the frozen current Python state.
Returns:
A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
no Python state associated with it).
"""
pass
class PythonStringStateSaveable(PythonStateSaveable):
"""Saves Python state in a checkpoint."""
def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
restore_callback: A function taking a Python string, used to restore
state. Optional; defaults to doing nothing.
"""
self._state_callback = state_callback
self._restore_callback = restore_callback
if context.executing_eagerly():
self._save_string = (
lambda: constant_op.constant(state_callback(), dtype=dtypes.string))
else:
with ops.device("/cpu:0"):
self._save_string = constant_op.constant("", dtype=dtypes.string)
self.feed_dict_additions = (
lambda: {self._save_string: state_callback()})
spec = saveable_object.SaveSpec(
self._save_string, "", name, dtype=dtypes.string)
super(PythonStringStateSaveable, self).__init__(
self._save_string, [spec], name)
def feed_dict_additions(self):
"""When running a graph, indicates fresh state to feed."""
return {self._save_string: self._state_callback()}
def freeze(self):
"""Create a frozen `SaveableObject` which saves the current state."""
return NoRestoreSaveable(
tensor=self._state_callback,
dtype=dtypes.string,
name=self.name)
def python_restore(self, restored_strings):
"""Called to restore Python state."""
if self._restore_callback:
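A rough sketch of the PythonStateSaveable contract implemented above (not part of the commit): in a graph build the saveable feeds its current Python state at save time, while freeze() produces a plain SaveableObject with the state embedded, suitable for eager execution or a static save. The import path tensorflow.python.training.checkpointable.base and the state dictionary are assumptions for illustration.

# Sketch only: exercising feed_dict_additions() and freeze() on a
# PythonStringStateSaveable. Module path and state are assumptions.
import json
import tensorflow as tf
from tensorflow.python.training.checkpointable import base as checkpointable

state = {"learning_rate": 0.1}  # hypothetical volatile Python state

with tf.Graph().as_default():
  saveable = checkpointable.PythonStringStateSaveable(
      name="py_state",
      state_callback=lambda: json.dumps(state).encode("utf8"),
      restore_callback=lambda s: state.update(json.loads(s)))
  # Graph building: the fresh state is fed when the save op runs.
  feeds = saveable.feed_dict_additions()  # maps a string tensor to the JSON bytes
  # Eager execution or a static save: freeze() embeds the current state and
  # returns a SaveableObject with no Python state attached.
  frozen = saveable.freeze()
  assert not isinstance(frozen, checkpointable.PythonStateSaveable)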
@@ -309,7 +357,7 @@ class _CheckpointPosition(object):
if self._checkpoint.saveable_object_cache is not None:
self._checkpoint.saveable_object_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
if isinstance(saveable, PythonStringStateSaveable):
if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
else:
named_saveables[serialized_tensor.checkpoint_key] = saveable
@@ -819,7 +867,7 @@ class CheckpointableBase(object):
def _state_callback():
dereferenced_self = weak_self()
if dereferenced_self:
return json.dumps(self,
return json.dumps(dereferenced_self,
default=serialization.get_json_type,
sort_keys=True).encode("utf8")
else:

View File

@@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
@@ -557,7 +556,14 @@ def _serialize_checkpointables(
object_graph_proto = (
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
named_saveables = []
feed_additions = {}
if saveables_cache is None:
# No SaveableObject caching. Either we're executing eagerly, or building a
# static save which is specialized to the current Python state.
feed_additions = None
else:
# If we are caching SaveableObjects, we need to build up a feed_dict with
# functions computing volatile Python state to be saved with the checkpoint.
feed_additions = {}
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
assert node_ids[checkpointable] == checkpoint_id
object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@
for saveable in saveables:
if hasattr(saveable, "full_name"):
attribute.full_name = saveable.full_name
saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
if saveable_feed_dict_fn is not None:
saveable_feed_dict = saveable_feed_dict_fn() # pylint: disable=not-callable
for new_feed_key in saveable_feed_dict.keys():
if new_feed_key in feed_additions:
raise AssertionError(
("The object %s tried to feed a value for the Tensor %s "
"when saving, but another object is already feeding a "
"value.")
% (checkpointable, new_feed_key))
feed_additions.update(saveable_feed_dict)
named_saveables.extend(saveables)
if isinstance(saveable, base.PythonStateSaveable):
if feed_additions is None:
assert saveables_cache is None
# If we're not caching saveables, then we're either executing
# eagerly or building a static save/restore (e.g. for a
# SavedModel). In either case, we should embed the current Python
# state in the graph rather than relying on a feed dict.
saveable = saveable.freeze()
else:
saveable_feed_dict = saveable.feed_dict_additions()
for new_feed_key in saveable_feed_dict.keys():
if new_feed_key in feed_additions:
raise AssertionError(
("The object %s tried to feed a value for the Tensor %s "
"when saving, but another object is already feeding a "
"value.")
% (checkpointable, new_feed_key))
feed_additions.update(saveable_feed_dict)
named_saveables.append(saveable)
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
child_proto = object_proto.children.add()
@@ -827,16 +840,6 @@ def capture_dependencies(template):
yield
class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
def __init__(self, tensor, name):
spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
def restore(self, restored_tensors, restored_shapes):
return control_flow_ops.no_op()
class _LoadStatus(object):
"""Abstract base for load status callbacks."""
@@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
else:
return self._root_checkpointable_ref
def _gather_saveables(
self, object_graph_tensor=None, saveable_object_cache=None):
"""Wraps _serialize_object_graph to include the object graph proto."""
assert ((object_graph_tensor is None and saveable_object_cache is None)
or (object_graph_tensor is not None
and saveable_object_cache is not None))
(named_saveable_objects, graph_proto,
feed_additions) = _serialize_object_graph(
self._root_checkpointable,
saveables_cache=saveable_object_cache)
if object_graph_tensor is None:
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
else:
feed_additions.update(
{object_graph_tensor: graph_proto.SerializeToString()})
assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
named_saveable_objects.append(
base.NoRestoreSaveable(
tensor=object_graph_tensor,
name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects, graph_proto, feed_additions
def freeze(self):
"""Creates a `tf.train.Saver` with the current object graph frozen."""
named_saveable_objects, _, _ = self._gather_saveables(
object_graph_tensor=None, saveable_object_cache=None)
return saver_lib.Saver(
var_list=named_saveable_objects, max_to_keep=None)
def _prepare_save(self,
object_graph_tensor=None,
saveable_object_cache=None):
"""Create or retrieve save ops.
When graph building, `saveable_object_cache` will typically be non-`None`,
meaning that existing `SaveableObject`s are re-used across calls to
`_prepare_save` even if the object graph has grown. This avoids
unnecessarily re-creating save ops.
Args:
object_graph_tensor: A `Tensor` to which the current object graph will be
fed.
saveable_object_cache: A dictionary; if specified, used to cache
`SaveableObject`s.
Returns:
A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
to feed when running save ops. The feed dict contains the current object
graph and any Python state to be saved in the checkpoint.
"""
(named_saveable_objects, graph_proto,
feed_additions) = self._gather_saveables(
object_graph_tensor=object_graph_tensor,
saveable_object_cache=saveable_object_cache)
if (self._last_save_object_graph != graph_proto
# When executing eagerly, we need to re-create SaveableObjects each time
# save() is called so they pick up new Tensors passed to their
# constructors. That means the Saver needs to be copied with a new
# var_list.
or context.executing_eagerly()):
if self._last_save_object_graph is not None:
self._last_save_saver = _copy_saver_with_new_var_list(
old_saver=self._last_save_saver,
new_var_list=named_saveable_objects)
else:
self._last_save_saver = saver_lib.Saver(
var_list=named_saveable_objects, max_to_keep=None)
self._last_save_object_graph = graph_proto
return self._last_save_saver, feed_additions
def save(self, file_prefix, checkpoint_number=None, session=None):
"""Save a training checkpoint.
@@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
Returns:
The full path to the checkpoint.
"""
named_variables, graph_proto, feed_additions = _serialize_object_graph(
self._root_checkpointable,
saveables_cache=self._saveable_object_cache)
if not context.executing_eagerly():
if session is None:
session = ops.get_default_session()
feed_additions = {}
graph_building = not context.executing_eagerly()
if graph_building:
if self._object_graph_feed_tensor is None:
with ops.device("/cpu:0"):
self._object_graph_feed_tensor = constant_op.constant(
"", dtype=dtypes.string)
object_graph_tensor = self._object_graph_feed_tensor
feed_additions.update(
{object_graph_tensor: graph_proto.SerializeToString()})
else:
object_graph_tensor = None
saver, new_feed_additions = self._prepare_save(
object_graph_tensor=object_graph_tensor,
saveable_object_cache=self._saveable_object_cache)
if new_feed_additions:
feed_additions.update(new_feed_additions)
if not graph_building:
session = None
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
named_variables.append(
_NoRestoreSaveable(
tensor=object_graph_tensor,
name=base.OBJECT_GRAPH_PROTO_KEY))
if (self._last_save_object_graph != graph_proto
# When executing eagerly, we need to re-create SaveableObjects each time
# save() is called so they pick up new Tensors passed to their
# constructors. That means the Saver needs to be copied with a new
# var_list.
or context.executing_eagerly()):
if self._last_save_object_graph is not None:
self._last_save_saver = _copy_saver_with_new_var_list(
old_saver=self._last_save_saver, new_var_list=named_variables)
else:
self._last_save_saver = saver_lib.Saver(
var_list=named_variables, max_to_keep=None)
self._last_save_object_graph = graph_proto
elif session is None:
session = ops.get_default_session()
with ops.device("/cpu:0"):
save_path = self._last_save_saver.save(
save_path = saver.save(
sess=_SessionWithFeedDictAdditions(
session=session, feed_additions=feed_additions),
save_path=file_prefix,
@@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
return load_status
def frozen_saver(root_checkpointable):
"""Creates a static `tf.train.Saver` from a checkpointable object.
The returned `Saver` saves object-based checkpoints, but these checkpoints
will no longer reflect structural changes to the object graph, only changes to
the values of `Variable`s added as dependencies of the root object before
`freeze` was called.
`restore` works on the returned `Saver`, but requires that the object graph of
the checkpoint being loaded exactly matches the object graph when `freeze` was
called. This is in contrast to the object-based restore performed by
`tf.train.Checkpoint` which attempts a fuzzy matching between a checkpoint's
object graph and the current Python object graph.
Args:
root_checkpointable: A checkpointable object to save.
Returns:
A `tf.train.Saver` which saves object-based checkpoints for the object graph
frozen at the time `frozen_saver` was called.
"""
return CheckpointableSaver(root_checkpointable).freeze()
@tf_export("train.Checkpoint")
class Checkpoint(tracking.Checkpointable):
"""Groups checkpointable objects, saving and restoring them.

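A short restore-side sketch for frozen_saver above (not part of the commit, and essentially what testFreezing below exercises): a frozen saver restores only into an identical object graph, while the same checkpoint file still loads through the fuzzy object-based path because the object graph proto is embedded in it. The internal import path is an assumption, and the checkpoint prefix is the hypothetical one used in the sketch near the top of this page.

# Sketch only: two ways to load a checkpoint written by a frozen saver.
import tensorflow as tf
from tensorflow.python.training.checkpointable import util as checkpointable_utils

with tf.Graph().as_default(), tf.Session() as session:
  v = tf.get_variable("v", initializer=0.0, use_resource=True)
  root = tf.train.Checkpoint(v=v)
  # Option 1: another frozen saver, valid because the object graph is identical.
  checkpointable_utils.frozen_saver(root).restore(session, "/tmp/object_based_ckpt")
  # Option 2: object-based restore of the same file via tf.train.Checkpoint.
  root.restore("/tmp/object_based_ckpt").run_restore_ops(session)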
View File

@@ -559,6 +559,46 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(training_continuation + 1,
self.evaluate(root.save_counter))
@test_util.run_in_graph_and_eager_modes
def testFreezing(self):
with self.cached_session(use_gpu=True) as session:
# Save an object-based checkpoint using a frozen saver
directory = self.get_temp_dir()
prefix = os.path.join(directory, "ckpt")
v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
checkpoint = checkpointable_utils.Checkpoint(v=v)
self.evaluate(v.assign(3))
# Create the save counter so assert_consumed doesn't complain about it not
# existing in the checkpoint on restore.
self.evaluate(checkpoint.save_counter.assign(12))
saver = checkpointable_utils.frozen_saver(checkpoint)
save_path = saver.save(session, prefix)
self.evaluate(v.assign(10))
# Use the frozen saver to restore the same object graph
saver.restore(session, save_path)
self.assertEqual(3, self.evaluate(v))
# Restore using another frozen saver on an identical object graph
del v, checkpoint, saver
v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
checkpoint = checkpointable_utils.Checkpoint(v=v)
saver = checkpointable_utils.frozen_saver(checkpoint)
saver.restore(session, save_path)
self.assertEqual(3, self.evaluate(v))
# Restore as an object-based checkpoint
del v, checkpoint, saver
checkpoint = checkpointable_utils.Checkpoint()
status = checkpoint.restore(save_path)
v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
if context.executing_eagerly():
self.assertEqual(12, self.evaluate(checkpoint.save_counter))
self.assertEqual(0, self.evaluate(v))
checkpoint.v = v
status.assert_consumed().run_restore_ops()
self.assertEqual(3, self.evaluate(v))
self.assertEqual(12, self.evaluate(checkpoint.save_counter))
@test_util.run_in_graph_and_eager_modes
def testCustomNumbering(self):
directory = self.get_temp_dir()