Refactor to allow creation of an object graph proto with no variable values

Will be useful for creating a SavedModel object proto.

PiperOrigin-RevId: 222430965
parent 95d7bbb2fc
commit c16394423c
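For orientation, a simplified sketch of how the pieces introduced below compose. The stage names are taken from the hunks themselves, but the `_sketch_save_path` helper is hypothetical and not part of the change; it only illustrates why the object graph proto can now be built without reading any variable values:

    # Simplified sketch, not part of this change: how the refactored stages
    # fit together. Names come from the diff below.
    def _sketch_save_path(root_checkpointable):
      # Stage 1: discover and number every checkpoint dependency of the root.
      objects, node_ids, slot_variables = _find_objects(root_checkpointable)
      # Stage 2: record only the topology (nodes, children, slot-variable links).
      object_graph_proto = _make_object_graph_proto(
          objects, node_ids, slot_variables)
      # Stage 3, which creates SaveableObjects and reads variable values, is the
      # part that make_object_graph_without_attributes skips:
      #   _add_attributes_to_object_graph(...)
      return object_graph_proto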
@@ -152,7 +152,7 @@ py_test(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:function",
+        "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras:engine",
         "//tensorflow/python/keras:layers",
@@ -549,13 +549,11 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
   return slot_variables


-def _serialize_checkpointables(
-    checkpointable_objects, node_ids, object_names, slot_variables,
+def _add_attributes_to_object_graph(
+    checkpointable_objects, object_graph_proto, node_ids, object_names,
     saveables_cache, object_map):
-  """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
-  object_graph_proto = (
-      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
-  named_saveables = []
+  """Create SaveableObjects and corresponding SerializedTensor protos."""
+  named_saveable_objects = []
   if saveables_cache is None:
     # No SaveableObject caching. Either we're executing eagerly, or building a
     # static save which is specialized to the current Python state.
@@ -564,10 +562,9 @@ def _serialize_checkpointables(
     # If we are caching SaveableObjects, we need to build up a feed_dict with
     # functions computing volatile Python state to be saved with the checkpoint.
     feed_additions = {}
-  for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+  for checkpoint_id, (checkpointable, object_proto) in enumerate(
+      zip(checkpointable_objects, object_graph_proto.nodes)):
     assert node_ids[checkpointable] == checkpoint_id
-    object_proto = object_graph_proto.nodes.add()
-    object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
     object_name = object_names[checkpointable]
     if object_map:
       object_to_save = object_map.get(checkpointable, checkpointable)
@@ -645,14 +642,24 @@ def _serialize_checkpointables(
                    "value.")
                   % (checkpointable, new_feed_key))
           feed_additions.update(saveable_feed_dict)
-      named_saveables.append(saveable)
+      named_saveable_objects.append(saveable)

+  return named_saveable_objects, feed_additions
+
+
+def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables):
+  """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
+  object_graph_proto = (
+      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+  for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+    assert node_ids[checkpointable] == checkpoint_id
+    object_proto = object_graph_proto.nodes.add()
+    object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
     for child in checkpointable._checkpoint_dependencies:  # pylint: disable=protected-access
       child_proto = object_proto.children.add()
       child_proto.node_id = node_ids[child.ref]
       child_proto.local_name = child.name
-
-  return named_saveables, object_graph_proto, feed_additions
+  return object_graph_proto


 def _serialize_gathered_objects(
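Since `_make_object_graph_proto` records only topology, the resulting proto can be inspected without touching any variable values. A small sketch of walking it; `print_object_graph_edges` is a hypothetical helper, not part of the change, and it uses only the fields set in the hunk above (`nodes`, `children`, `node_id`, `local_name`):

    def print_object_graph_edges(object_graph_proto):
      # Hypothetical helper: dump the parent -> child edges recorded by
      # _make_object_graph_proto. No SaveableObjects or tensor values involved.
      for parent_id, node in enumerate(object_graph_proto.nodes):
        for child in node.children:
          print("%d -> %d (%s)" % (parent_id, child.node_id, child.local_name))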
@@ -668,13 +675,18 @@ def _serialize_gathered_objects(
       checkpointable_objects=checkpointable_objects,
       node_ids=node_ids,
       object_names=object_names)
-  return _serialize_checkpointables(
+  object_graph_proto = _make_object_graph_proto(
       checkpointable_objects=checkpointable_objects,
       node_ids=node_ids,
+      slot_variables=slot_variables)
+  named_saveable_objects, feed_additions = _add_attributes_to_object_graph(
+      checkpointable_objects=checkpointable_objects,
+      object_graph_proto=object_graph_proto,
+      node_ids=node_ids,
       object_names=object_names,
-      slot_variables=slot_variables,
       saveables_cache=saveables_cache,
       object_map=object_map)
+  return named_saveable_objects, object_graph_proto, feed_additions


 def _serialize_object_graph(root_checkpointable, saveables_cache):
@@ -716,6 +728,23 @@ def named_saveables(root_checkpointable):
   return _serialize_object_graph(root_checkpointable, None)[0]


+def _find_objects(root_checkpointable):
+  """Find and number objects which are dependencies of `root_checkpointable`."""
+  checkpointable_objects, path_to_root = (
+      _breadth_first_checkpointable_traversal(root_checkpointable))
+  object_names = _ObjectIdentityDictionary()
+  for obj, path in path_to_root.items():
+    object_names[obj] = _object_prefix_from_path(path)
+  node_ids = _ObjectIdentityDictionary()
+  for node_id, node in enumerate(checkpointable_objects):
+    node_ids[node] = node_id
+  slot_variables = _serialize_slot_variables(
+      checkpointable_objects=checkpointable_objects,
+      node_ids=node_ids,
+      object_names=object_names)
+  return checkpointable_objects, node_ids, slot_variables
+
+
 def list_objects(root_checkpointable):
   """Traverse the object graph and list all accessible objects.

@@ -730,23 +759,18 @@ def list_objects(root_checkpointable):
   Returns:
     A flat list of objects.
   """
-  # TODO(allenl): Extract out gathering logic so the naming logic doesn't have
-  # to run.
-  checkpointable_objects, path_to_root = (
-      _breadth_first_checkpointable_traversal(root_checkpointable))
-  object_names = _ObjectIdentityDictionary()
-  for obj, path in path_to_root.items():
-    object_names[obj] = _object_prefix_from_path(path)
-  node_ids = _ObjectIdentityDictionary()
-  for node_id, node in enumerate(checkpointable_objects):
-    node_ids[node] = node_id
-  _serialize_slot_variables(
-      checkpointable_objects=checkpointable_objects,
-      node_ids=node_ids,
-      object_names=object_names)
+  checkpointable_objects, _, _ = _find_objects(root_checkpointable)
   return checkpointable_objects


+def make_object_graph_without_attributes(root_checkpointable):
+  """Construct a CheckpointableObjectGraph proto with no variable values."""
+  checkpointable_objects, node_ids, slot_variables = _find_objects(
+      root_checkpointable)
+  return _make_object_graph_proto(
+      checkpointable_objects, node_ids, slot_variables)
+
+
 def gather_initializers(root_checkpointable):
   """Traverse the object graph and find initialization ops.

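A hedged usage sketch of the new entry point, mirroring the test added in the test-file hunks below. The import paths are assumptions inferred from the imports visible in this diff, not something the change itself states:

    # Usage sketch (assumed import paths, inferred from the test file's imports).
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.training.checkpointable import tracking
    from tensorflow.python.training.checkpointable import util as checkpointable_utils

    root = tracking.Checkpointable()
    root.v = resource_variable_ops.ResourceVariable(1.)
    object_graph = checkpointable_utils.make_object_graph_without_attributes(root)
    # Expect two nodes (the root and `v`); only the dependency topology is
    # recorded, no variable values are read or serialized.
    assert len(object_graph.nodes) == 2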
@@ -26,7 +26,7 @@ from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -44,6 +44,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import adam
 from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import momentum
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpointable import base
@@ -198,6 +199,17 @@ class InterfaceTests(test.TestCase):
     with self.assertRaises(NotImplementedError):
       checkpoint_reversed.save(prefix)

+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def test_object_graph_no_attributes(self):
+    root = tracking.Checkpointable()
+    root.v = resource_variable_ops.ResourceVariable(1.)
+    root.opt = momentum.MomentumOptimizer(0.01, 0.5)
+    root.opt.minimize(root.v.read_value)
+    object_graph = checkpointable_utils.make_object_graph_without_attributes(
+        root)
+    # Four objects: Root, v, opt, and a slot variable for v
+    self.assertEqual(4, len(object_graph.nodes))
+

 class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject):

@@ -632,7 +644,7 @@ class CheckpointingTests(test.TestCase):
         checkpoint_directory)
     status = root.restore(save_path=checkpoint_path)
     def train_fn():
-      @function.defun
+      @def_function.function
       def _call_model(x):
         return model(x)
       with backprop.GradientTape() as tape: