diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9189d8f3e8f..095a90ddd4f 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -17,11 +17,14 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
 import functools
 import json
 import weakref
 
+import six
+
 from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
@@ -91,7 +94,45 @@ class CheckpointInitialValue(ops.Tensor):
     return self._checkpoint_position
 
 
-class PythonStringStateSaveable(saveable_object.SaveableObject):
+class NoRestoreSaveable(saveable_object.SaveableObject):
+  """Embeds a tensor in a checkpoint with no restore ops."""
+
+  def __init__(self, tensor, name, dtype=None):
+    spec = saveable_object.SaveSpec(tensor, "", name, dtype=dtype)
+    super(NoRestoreSaveable, self).__init__(tensor, [spec], name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    return control_flow_ops.no_op()
+
+
+@six.add_metaclass(abc.ABCMeta)
+class PythonStateSaveable(saveable_object.SaveableObject):
+  """An interface for saving/restoring volatile Python state."""
+
+  @abc.abstractmethod
+  def feed_dict_additions(self):
+    """When running a graph, indicates fresh state to feed.
+
+    Returns:
+      A dictionary mapping `Tensor`s to current Python state.
+    """
+    pass
+
+  @abc.abstractmethod
+  def freeze(self):
+    """Create a new `SaveableObject` which freezes current state as a constant.
+
+    Used when executing eagerly to embed the current state as a constant, or
+    when creating a static tf.train.Saver with the frozen current Python state.
+
+    Returns:
+      A `SaveableObject` which is not a `PythonStateSaveable` instance (i.e. has
+      no Python state associated with it).
+    """
+    pass
+
+
+class PythonStringStateSaveable(PythonStateSaveable):
   """Saves Python state in a checkpoint."""
 
   def __init__(self, name, state_callback, restore_callback=None):
@@ -104,19 +145,26 @@ class PythonStringStateSaveable(saveable_object.SaveableObject):
       restore_callback: A function taking a Python string, used to restore
         state. Optional; defaults to doing nothing.
""" + self._state_callback = state_callback self._restore_callback = restore_callback - if context.executing_eagerly(): - self._save_string = ( - lambda: constant_op.constant(state_callback(), dtype=dtypes.string)) - else: + with ops.device("/cpu:0"): self._save_string = constant_op.constant("", dtype=dtypes.string) - self.feed_dict_additions = ( - lambda: {self._save_string: state_callback()}) spec = saveable_object.SaveSpec( self._save_string, "", name, dtype=dtypes.string) super(PythonStringStateSaveable, self).__init__( self._save_string, [spec], name) + def feed_dict_additions(self): + """When running a graph, indicates fresh state to feed.""" + return {self._save_string: self._state_callback()} + + def freeze(self): + """Create a frozen `SaveableObject` which saves the current state.""" + return NoRestoreSaveable( + tensor=self._state_callback, + dtype=dtypes.string, + name=self.name) + def python_restore(self, restored_strings): """Called to restore Python state.""" if self._restore_callback: @@ -309,7 +357,7 @@ class _CheckpointPosition(object): if self._checkpoint.saveable_object_cache is not None: self._checkpoint.saveable_object_cache.setdefault( self.checkpointable, {})[serialized_tensor.name] = [saveable] - if isinstance(saveable, PythonStringStateSaveable): + if isinstance(saveable, PythonStateSaveable): python_saveables.append(saveable) else: named_saveables[serialized_tensor.checkpoint_key] = saveable @@ -819,7 +867,7 @@ class CheckpointableBase(object): def _state_callback(): dereferenced_self = weak_self() if dereferenced_self: - return json.dumps(self, + return json.dumps(dereferenced_self, default=serialization.get_json_type, sort_keys=True).encode("utf8") else: diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 13dddd37ac7..56c4043d9d4 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -32,7 +32,6 @@ from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_io_ops as io_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope @@ -557,7 +556,14 @@ def _serialize_checkpointables( object_graph_proto = ( checkpointable_object_graph_pb2.CheckpointableObjectGraph()) named_saveables = [] - feed_additions = {} + if saveables_cache is None: + # No SaveableObject caching. Either we're executing eagerly, or building a + # static save which is specialized to the current Python state. + feed_additions = None + else: + # If we are caching SaveableObjects, we need to build up a feed_dict with + # functions computing volatile Python state to be saved with the checkpoint. 
+    feed_additions = {}
   for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
     assert node_ids[checkpointable] == checkpoint_id
     object_proto = object_graph_proto.nodes.add()
@@ -616,18 +622,25 @@
       for saveable in saveables:
         if hasattr(saveable, "full_name"):
           attribute.full_name = saveable.full_name
-        saveable_feed_dict_fn = getattr(saveable, "feed_dict_additions", None)
-        if saveable_feed_dict_fn is not None:
-          saveable_feed_dict = saveable_feed_dict_fn()  # pylint: disable=not-callable
-          for new_feed_key in saveable_feed_dict.keys():
-            if new_feed_key in feed_additions:
-              raise AssertionError(
-                  ("The object %s tried to feed a value for the Tensor %s "
-                   "when saving, but another object is already feeding a "
-                   "value.")
-                  % (checkpointable, new_feed_key))
-          feed_additions.update(saveable_feed_dict)
-      named_saveables.extend(saveables)
+        if isinstance(saveable, base.PythonStateSaveable):
+          if feed_additions is None:
+            assert saveables_cache is None
+            # If we're not caching saveables, then we're either executing
+            # eagerly or building a static save/restore (e.g. for a
+            # SavedModel). In either case, we should embed the current Python
+            # state in the graph rather than relying on a feed dict.
+            saveable = saveable.freeze()
+          else:
+            saveable_feed_dict = saveable.feed_dict_additions()
+            for new_feed_key in saveable_feed_dict.keys():
+              if new_feed_key in feed_additions:
+                raise AssertionError(
+                    ("The object %s tried to feed a value for the Tensor %s "
+                     "when saving, but another object is already feeding a "
+                     "value.")
+                    % (checkpointable, new_feed_key))
+            feed_additions.update(saveable_feed_dict)
+        named_saveables.append(saveable)
 
     for child in checkpointable._checkpoint_dependencies:  # pylint: disable=protected-access
       child_proto = object_proto.children.add()
@@ -827,16 +840,6 @@ def capture_dependencies(template):
   yield
 
 
-class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
-
-  def __init__(self, tensor, name):
-    spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name)
-    super(_NoRestoreSaveable, self).__init__(tensor, [spec], name)
-
-  def restore(self, restored_tensors, restored_shapes):
-    return control_flow_ops.no_op()
-
-
 class _LoadStatus(object):
   """Abstract base for load status callbacks."""
 
@@ -1241,6 +1244,78 @@ class CheckpointableSaver(object):
     else:
       return self._root_checkpointable_ref
 
+  def _gather_saveables(
+      self, object_graph_tensor=None, saveable_object_cache=None):
+    """Wraps _serialize_object_graph to include the object graph proto."""
+    assert ((object_graph_tensor is None and saveable_object_cache is None)
+            or (object_graph_tensor is not None
+                and saveable_object_cache is not None))
+    (named_saveable_objects, graph_proto,
+     feed_additions) = _serialize_object_graph(
+         self._root_checkpointable,
+         saveables_cache=saveable_object_cache)
+    if object_graph_tensor is None:
+      with ops.device("/cpu:0"):
+        object_graph_tensor = constant_op.constant(
+            graph_proto.SerializeToString(), dtype=dtypes.string)
+    else:
+      feed_additions.update(
+          {object_graph_tensor: graph_proto.SerializeToString()})
+    assert base.OBJECT_GRAPH_PROTO_KEY not in named_saveable_objects
+    named_saveable_objects.append(
+        base.NoRestoreSaveable(
+            tensor=object_graph_tensor,
+            name=base.OBJECT_GRAPH_PROTO_KEY))
+    return named_saveable_objects, graph_proto, feed_additions
+
+  def freeze(self):
+    """Creates a `tf.train.Saver` with the current object graph frozen."""
+    named_saveable_objects, _, _ = self._gather_saveables(
+        object_graph_tensor=None, saveable_object_cache=None)
+    return saver_lib.Saver(
+        var_list=named_saveable_objects, max_to_keep=None)
+
+  def _prepare_save(self,
+                    object_graph_tensor=None,
+                    saveable_object_cache=None):
+    """Create or retrieve save ops.
+
+    When graph building, `saveable_object_cache` will typically be non-`None`,
+    meaning that existing `SaveableObject`s are re-used across calls to
+    `_prepare_save` even if the object graph has grown. This avoids
+    unnecessarily re-creating save ops.
+
+    Args:
+      object_graph_tensor: A `Tensor` to which the current object graph will be
+        fed.
+      saveable_object_cache: A dictionary; if specified, used to cache
+        `SaveableObject`s.
+
+    Returns:
+      A two-element tuple with a `tf.train.Saver` and a feed_dict of `Tensor`s
+      to feed when running save ops. The feed dict contains the current object
+      graph and any Python state to be saved in the checkpoint.
+    """
+    (named_saveable_objects, graph_proto,
+     feed_additions) = self._gather_saveables(
+         object_graph_tensor=object_graph_tensor,
+         saveable_object_cache=saveable_object_cache)
+    if (self._last_save_object_graph != graph_proto
+        # When executing eagerly, we need to re-create SaveableObjects each time
+        # save() is called so they pick up new Tensors passed to their
+        # constructors. That means the Saver needs to be copied with a new
+        # var_list.
+        or context.executing_eagerly()):
+      if self._last_save_object_graph is not None:
+        self._last_save_saver = _copy_saver_with_new_var_list(
+            old_saver=self._last_save_saver,
+            new_var_list=named_saveable_objects)
+      else:
+        self._last_save_saver = saver_lib.Saver(
+            var_list=named_saveable_objects, max_to_keep=None)
+      self._last_save_object_graph = graph_proto
+    return self._last_save_saver, feed_additions
+
   def save(self, file_prefix, checkpoint_number=None, session=None):
     """Save a training checkpoint.
 
@@ -1263,44 +1338,29 @@ class CheckpointableSaver(object):
     Returns:
       The full path to the checkpoint.
     """
-    named_variables, graph_proto, feed_additions = _serialize_object_graph(
-        self._root_checkpointable,
-        saveables_cache=self._saveable_object_cache)
-    if not context.executing_eagerly():
-      if session is None:
-        session = ops.get_default_session()
+    feed_additions = {}
+    graph_building = not context.executing_eagerly()
+    if graph_building:
       if self._object_graph_feed_tensor is None:
         with ops.device("/cpu:0"):
           self._object_graph_feed_tensor = constant_op.constant(
               "", dtype=dtypes.string)
       object_graph_tensor = self._object_graph_feed_tensor
-      feed_additions.update(
-          {object_graph_tensor: graph_proto.SerializeToString()})
     else:
+      object_graph_tensor = None
+
+    saver, new_feed_additions = self._prepare_save(
+        object_graph_tensor=object_graph_tensor,
+        saveable_object_cache=self._saveable_object_cache)
+    if new_feed_additions:
+      feed_additions.update(new_feed_additions)
+    if not graph_building:
       session = None
-      with ops.device("/cpu:0"):
-        object_graph_tensor = constant_op.constant(
-            graph_proto.SerializeToString(), dtype=dtypes.string)
-    assert base.OBJECT_GRAPH_PROTO_KEY not in named_variables
-    named_variables.append(
-        _NoRestoreSaveable(
-            tensor=object_graph_tensor,
-            name=base.OBJECT_GRAPH_PROTO_KEY))
-    if (self._last_save_object_graph != graph_proto
-        # When executing eagerly, we need to re-create SaveableObjects each time
-        # save() is called so they pick up new Tensors passed to their
-        # constructors. That means the Saver needs to be copied with a new
-        # var_list.
-        or context.executing_eagerly()):
-      if self._last_save_object_graph is not None:
-        self._last_save_saver = _copy_saver_with_new_var_list(
-            old_saver=self._last_save_saver, new_var_list=named_variables)
-      else:
-        self._last_save_saver = saver_lib.Saver(
-            var_list=named_variables, max_to_keep=None)
-      self._last_save_object_graph = graph_proto
+    elif session is None:
+      session = ops.get_default_session()
+
     with ops.device("/cpu:0"):
-      save_path = self._last_save_saver.save(
+      save_path = saver.save(
           sess=_SessionWithFeedDictAdditions(
               session=session, feed_additions=feed_additions),
           save_path=file_prefix,
@@ -1422,6 +1482,30 @@ class CheckpointableSaver(object):
     return load_status
 
 
+def frozen_saver(root_checkpointable):
+  """Creates a static `tf.train.Saver` from a checkpointable object.
+
+  The returned `Saver` saves object-based checkpoints, but these checkpoints
+  will no longer reflect structural changes to the object graph, only changes to
+  the values of `Variable`s added as dependencies of the root object before
+  `freeze` was called.
+
+  `restore` works on the returned `Saver`, but requires that the object graph of
+  the checkpoint being loaded exactly matches the object graph when `freeze` was
+  called. This is in contrast to the object-based restore performed by
+  `tf.train.Checkpoint`, which attempts a fuzzy matching between a checkpoint's
+  object graph and the current Python object graph.
+
+  Args:
+    root_checkpointable: A checkpointable object to save.
+
+  Returns:
+    A `tf.train.Saver` which saves object-based checkpoints for the object graph
+    frozen at the time `frozen_saver` was called.
+  """
+  return CheckpointableSaver(root_checkpointable).freeze()
+
+
 @tf_export("train.Checkpoint")
 class Checkpoint(tracking.Checkpointable):
   """Groups checkpointable objects, saving and restoring them.
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index bef4bf2a16a..0d32d214267 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -559,6 +559,46 @@ class CheckpointingTests(test.TestCase):
           self.assertEqual(training_continuation + 1,
                            self.evaluate(root.save_counter))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testFreezing(self):
+    with self.cached_session(use_gpu=True) as session:
+      # Save an object-based checkpoint using a frozen saver
+      directory = self.get_temp_dir()
+      prefix = os.path.join(directory, "ckpt")
+      v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+      checkpoint = checkpointable_utils.Checkpoint(v=v)
+      self.evaluate(v.assign(3))
+      # Create the save counter so assert_consumed doesn't complain about it not
+      # existing in the checkpoint on restore.
+      self.evaluate(checkpoint.save_counter.assign(12))
+      saver = checkpointable_utils.frozen_saver(checkpoint)
+      save_path = saver.save(session, prefix)
+      self.evaluate(v.assign(10))
+      # Use the frozen saver to restore the same object graph
+      saver.restore(session, save_path)
+      self.assertEqual(3, self.evaluate(v))
+
+      # Restore using another frozen saver on an identical object graph
+      del v, checkpoint, saver
+      v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+      checkpoint = checkpointable_utils.Checkpoint(v=v)
+      saver = checkpointable_utils.frozen_saver(checkpoint)
+      saver.restore(session, save_path)
+      self.assertEqual(3, self.evaluate(v))
+
+      # Restore as an object-based checkpoint
+      del v, checkpoint, saver
+      checkpoint = checkpointable_utils.Checkpoint()
+      status = checkpoint.restore(save_path)
+      v = resource_variable_ops.ResourceVariable(0, dtype=dtypes.int64)
+      if context.executing_eagerly():
+        self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+        self.assertEqual(0, self.evaluate(v))
+      checkpoint.v = v
+      status.assert_consumed().run_restore_ops()
+      self.assertEqual(3, self.evaluate(v))
+      self.assertEqual(12, self.evaluate(checkpoint.save_counter))
+
   @test_util.run_in_graph_and_eager_modes
   def testCustomNumbering(self):
     directory = self.get_temp_dir()
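Note for reviewers: below is a minimal usage sketch of the new frozen_saver utility, adapted from testFreezing above. The variable value, the "/tmp/ckpt" prefix, and the use of a plain graph-mode tf.Session are illustrative only; frozen_saver stays internal to the checkpointable package (imported here via its module path, as in the test) rather than being exported under tf.train.

# Sketch only: mirrors util_test.testFreezing under the assumptions stated above.
import tensorflow as tf

from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import util as checkpointable_utils

v = resource_variable_ops.ResourceVariable(3, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(v=v)
# Freeze the current object graph into a static tf.train.Saver. Structural
# changes made to `checkpoint` after this point are not reflected in `saver`.
saver = checkpointable_utils.frozen_saver(checkpoint)

with tf.Session() as session:
  session.run(v.initializer)
  save_path = saver.save(session, "/tmp/ckpt")  # illustrative prefix
  session.run(v.assign(10))
  # Restoring requires an object graph identical to the one that was frozen.
  saver.restore(session, save_path)
  assert session.run(v) == 3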