diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD index d26932c1aae..f97f42a6593 100644 --- a/tensorflow/python/training/checkpointable/BUILD +++ b/tensorflow/python/training/checkpointable/BUILD @@ -152,7 +152,7 @@ py_test( "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", - "//tensorflow/python/eager:function", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", "//tensorflow/python/keras:engine", "//tensorflow/python/keras:layers", diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index f45f7445f13..85844393f38 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -549,13 +549,11 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): return slot_variables -def _serialize_checkpointables( - checkpointable_objects, node_ids, object_names, slot_variables, +def _add_attributes_to_object_graph( + checkpointable_objects, object_graph_proto, node_ids, object_names, saveables_cache, object_map): - """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - named_saveables = [] + """Create SaveableObjects and corresponding SerializedTensor protos.""" + named_saveable_objects = [] if saveables_cache is None: # No SaveableObject caching. Either we're executing eagerly, or building a # static save which is specialized to the current Python state. @@ -564,10 +562,9 @@ def _serialize_checkpointables( # If we are caching SaveableObjects, we need to build up a feed_dict with # functions computing volatile Python state to be saved with the checkpoint. feed_additions = {} - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + for checkpoint_id, (checkpointable, object_proto) in enumerate( + zip(checkpointable_objects, object_graph_proto.nodes)): assert node_ids[checkpointable] == checkpoint_id - object_proto = object_graph_proto.nodes.add() - object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) object_name = object_names[checkpointable] if object_map: object_to_save = object_map.get(checkpointable, checkpointable) @@ -645,14 +642,24 @@ def _serialize_checkpointables( "value.") % (checkpointable, new_feed_key)) feed_additions.update(saveable_feed_dict) - named_saveables.append(saveable) + named_saveable_objects.append(saveable) + return named_saveable_objects, feed_additions + + +def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables): + """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + assert node_ids[checkpointable] == checkpoint_id + object_proto = object_graph_proto.nodes.add() + object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access child_proto = object_proto.children.add() child_proto.node_id = node_ids[child.ref] child_proto.local_name = child.name - - return named_saveables, object_graph_proto, feed_additions + return object_graph_proto def _serialize_gathered_objects( @@ -668,13 +675,18 @@ def _serialize_gathered_objects( checkpointable_objects=checkpointable_objects, node_ids=node_ids, object_names=object_names) - return _serialize_checkpointables( + object_graph_proto = _make_object_graph_proto( checkpointable_objects=checkpointable_objects, node_ids=node_ids, + slot_variables=slot_variables) + named_saveable_objects, feed_additions = _add_attributes_to_object_graph( + checkpointable_objects=checkpointable_objects, + object_graph_proto=object_graph_proto, + node_ids=node_ids, object_names=object_names, - slot_variables=slot_variables, saveables_cache=saveables_cache, object_map=object_map) + return named_saveable_objects, object_graph_proto, feed_additions def _serialize_object_graph(root_checkpointable, saveables_cache): @@ -716,6 +728,23 @@ def named_saveables(root_checkpointable): return _serialize_object_graph(root_checkpointable, None)[0] +def _find_objects(root_checkpointable): + """Find and number objects which are dependencies of `root_checkpointable`.""" + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_names = _ObjectIdentityDictionary() + for obj, path in path_to_root.items(): + object_names[obj] = _object_prefix_from_path(path) + node_ids = _ObjectIdentityDictionary() + for node_id, node in enumerate(checkpointable_objects): + node_ids[node] = node_id + slot_variables = _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return checkpointable_objects, node_ids, slot_variables + + def list_objects(root_checkpointable): """Traverse the object graph and list all accessible objects. @@ -730,23 +759,18 @@ def list_objects(root_checkpointable): Returns: A flat list of objects. """ - # TODO(allenl): Extract out gathering logic so the naming logic doesn't have - # to run. - checkpointable_objects, path_to_root = ( - _breadth_first_checkpointable_traversal(root_checkpointable)) - object_names = _ObjectIdentityDictionary() - for obj, path in path_to_root.items(): - object_names[obj] = _object_prefix_from_path(path) - node_ids = _ObjectIdentityDictionary() - for node_id, node in enumerate(checkpointable_objects): - node_ids[node] = node_id - _serialize_slot_variables( - checkpointable_objects=checkpointable_objects, - node_ids=node_ids, - object_names=object_names) + checkpointable_objects, _, _ = _find_objects(root_checkpointable) return checkpointable_objects +def make_object_graph_without_attributes(root_checkpointable): + """Construct a CheckpointableObjectGraph proto with no variable values.""" + checkpointable_objects, node_ids, slot_variables = _find_objects( + root_checkpointable) + return _make_object_graph_proto( + checkpointable_objects, node_ids, slot_variables) + + def gather_initializers(root_checkpointable): """Traverse the object graph and find initialization ops. diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py index 19955140123..de9cac08632 100644 --- a/tensorflow/python/training/checkpointable/util_test.py +++ b/tensorflow/python/training/checkpointable/util_test.py @@ -26,7 +26,7 @@ from tensorflow.python import pywrap_tensorflow from tensorflow.python.client import session as session_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import context -from tensorflow.python.eager import function +from tensorflow.python.eager import def_function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -44,6 +44,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.training import adam from tensorflow.python.training import checkpoint_management +from tensorflow.python.training import momentum from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.training.checkpointable import base @@ -198,6 +199,17 @@ class InterfaceTests(test.TestCase): with self.assertRaises(NotImplementedError): checkpoint_reversed.save(prefix) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_object_graph_no_attributes(self): + root = tracking.Checkpointable() + root.v = resource_variable_ops.ResourceVariable(1.) + root.opt = momentum.MomentumOptimizer(0.01, 0.5) + root.opt.minimize(root.v.read_value) + object_graph = checkpointable_utils.make_object_graph_without_attributes( + root) + # Four objects: Root, v, opt, and a slot variable for v + self.assertEqual(4, len(object_graph.nodes)) + class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject): @@ -632,7 +644,7 @@ class CheckpointingTests(test.TestCase): checkpoint_directory) status = root.restore(save_path=checkpoint_path) def train_fn(): - @function.defun + @def_function.function def _call_model(x): return model(x) with backprop.GradientTape() as tape: