Refactor to allow creation of an object graph proto with no variable values
Will be useful for creating a SavedModel object proto PiperOrigin-RevId: 222430965
This commit is contained in:
parent
95d7bbb2fc
commit
c16394423c
@ -152,7 +152,7 @@ py_test(
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:backprop",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:function",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/keras:engine",
|
||||
"//tensorflow/python/keras:layers",
|
||||
|
@ -549,13 +549,11 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
|
||||
return slot_variables
|
||||
|
||||
|
||||
def _serialize_checkpointables(
|
||||
checkpointable_objects, node_ids, object_names, slot_variables,
|
||||
def _add_attributes_to_object_graph(
|
||||
checkpointable_objects, object_graph_proto, node_ids, object_names,
|
||||
saveables_cache, object_map):
|
||||
"""Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
|
||||
object_graph_proto = (
|
||||
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
|
||||
named_saveables = []
|
||||
"""Create SaveableObjects and corresponding SerializedTensor protos."""
|
||||
named_saveable_objects = []
|
||||
if saveables_cache is None:
|
||||
# No SaveableObject caching. Either we're executing eagerly, or building a
|
||||
# static save which is specialized to the current Python state.
|
||||
@ -564,10 +562,9 @@ def _serialize_checkpointables(
|
||||
# If we are caching SaveableObjects, we need to build up a feed_dict with
|
||||
# functions computing volatile Python state to be saved with the checkpoint.
|
||||
feed_additions = {}
|
||||
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
|
||||
for checkpoint_id, (checkpointable, object_proto) in enumerate(
|
||||
zip(checkpointable_objects, object_graph_proto.nodes)):
|
||||
assert node_ids[checkpointable] == checkpoint_id
|
||||
object_proto = object_graph_proto.nodes.add()
|
||||
object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
|
||||
object_name = object_names[checkpointable]
|
||||
if object_map:
|
||||
object_to_save = object_map.get(checkpointable, checkpointable)
|
||||
@ -645,14 +642,24 @@ def _serialize_checkpointables(
|
||||
"value.")
|
||||
% (checkpointable, new_feed_key))
|
||||
feed_additions.update(saveable_feed_dict)
|
||||
named_saveables.append(saveable)
|
||||
named_saveable_objects.append(saveable)
|
||||
|
||||
return named_saveable_objects, feed_additions
|
||||
|
||||
|
||||
def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables):
|
||||
"""Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
|
||||
object_graph_proto = (
|
||||
checkpointable_object_graph_pb2.CheckpointableObjectGraph())
|
||||
for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
|
||||
assert node_ids[checkpointable] == checkpoint_id
|
||||
object_proto = object_graph_proto.nodes.add()
|
||||
object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
|
||||
for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
|
||||
child_proto = object_proto.children.add()
|
||||
child_proto.node_id = node_ids[child.ref]
|
||||
child_proto.local_name = child.name
|
||||
|
||||
return named_saveables, object_graph_proto, feed_additions
|
||||
return object_graph_proto
|
||||
|
||||
|
||||
def _serialize_gathered_objects(
|
||||
@ -668,13 +675,18 @@ def _serialize_gathered_objects(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names)
|
||||
return _serialize_checkpointables(
|
||||
object_graph_proto = _make_object_graph_proto(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
node_ids=node_ids,
|
||||
slot_variables=slot_variables)
|
||||
named_saveable_objects, feed_additions = _add_attributes_to_object_graph(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
object_graph_proto=object_graph_proto,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names,
|
||||
slot_variables=slot_variables,
|
||||
saveables_cache=saveables_cache,
|
||||
object_map=object_map)
|
||||
return named_saveable_objects, object_graph_proto, feed_additions
|
||||
|
||||
|
||||
def _serialize_object_graph(root_checkpointable, saveables_cache):
|
||||
@ -716,6 +728,23 @@ def named_saveables(root_checkpointable):
|
||||
return _serialize_object_graph(root_checkpointable, None)[0]
|
||||
|
||||
|
||||
def _find_objects(root_checkpointable):
|
||||
"""Find and number objects which are dependencies of `root_checkpointable`."""
|
||||
checkpointable_objects, path_to_root = (
|
||||
_breadth_first_checkpointable_traversal(root_checkpointable))
|
||||
object_names = _ObjectIdentityDictionary()
|
||||
for obj, path in path_to_root.items():
|
||||
object_names[obj] = _object_prefix_from_path(path)
|
||||
node_ids = _ObjectIdentityDictionary()
|
||||
for node_id, node in enumerate(checkpointable_objects):
|
||||
node_ids[node] = node_id
|
||||
slot_variables = _serialize_slot_variables(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names)
|
||||
return checkpointable_objects, node_ids, slot_variables
|
||||
|
||||
|
||||
def list_objects(root_checkpointable):
|
||||
"""Traverse the object graph and list all accessible objects.
|
||||
|
||||
@ -730,23 +759,18 @@ def list_objects(root_checkpointable):
|
||||
Returns:
|
||||
A flat list of objects.
|
||||
"""
|
||||
# TODO(allenl): Extract out gathering logic so the naming logic doesn't have
|
||||
# to run.
|
||||
checkpointable_objects, path_to_root = (
|
||||
_breadth_first_checkpointable_traversal(root_checkpointable))
|
||||
object_names = _ObjectIdentityDictionary()
|
||||
for obj, path in path_to_root.items():
|
||||
object_names[obj] = _object_prefix_from_path(path)
|
||||
node_ids = _ObjectIdentityDictionary()
|
||||
for node_id, node in enumerate(checkpointable_objects):
|
||||
node_ids[node] = node_id
|
||||
_serialize_slot_variables(
|
||||
checkpointable_objects=checkpointable_objects,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names)
|
||||
checkpointable_objects, _, _ = _find_objects(root_checkpointable)
|
||||
return checkpointable_objects
|
||||
|
||||
|
||||
def make_object_graph_without_attributes(root_checkpointable):
|
||||
"""Construct a CheckpointableObjectGraph proto with no variable values."""
|
||||
checkpointable_objects, node_ids, slot_variables = _find_objects(
|
||||
root_checkpointable)
|
||||
return _make_object_graph_proto(
|
||||
checkpointable_objects, node_ids, slot_variables)
|
||||
|
||||
|
||||
def gather_initializers(root_checkpointable):
|
||||
"""Traverse the object graph and find initialization ops.
|
||||
|
||||
|
@ -26,7 +26,7 @@ from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.client import session as session_lib
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -44,6 +44,7 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.training import adam
|
||||
from tensorflow.python.training import checkpoint_management
|
||||
from tensorflow.python.training import momentum
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.training.checkpointable import base
|
||||
@ -198,6 +199,17 @@ class InterfaceTests(test.TestCase):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
checkpoint_reversed.save(prefix)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def test_object_graph_no_attributes(self):
|
||||
root = tracking.Checkpointable()
|
||||
root.v = resource_variable_ops.ResourceVariable(1.)
|
||||
root.opt = momentum.MomentumOptimizer(0.01, 0.5)
|
||||
root.opt.minimize(root.v.read_value)
|
||||
object_graph = checkpointable_utils.make_object_graph_without_attributes(
|
||||
root)
|
||||
# Four objects: Root, v, opt, and a slot variable for v
|
||||
self.assertEqual(4, len(object_graph.nodes))
|
||||
|
||||
|
||||
class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
|
||||
|
||||
@ -632,7 +644,7 @@ class CheckpointingTests(test.TestCase):
|
||||
checkpoint_directory)
|
||||
status = root.restore(save_path=checkpoint_path)
|
||||
def train_fn():
|
||||
@function.defun
|
||||
@def_function.function
|
||||
def _call_model(x):
|
||||
return model(x)
|
||||
with backprop.GradientTape() as tape:
|
||||
|
Loading…
Reference in New Issue
Block a user