Refactor to allow creation of an object graph proto with no variable values

Will be useful for creating a SavedModel object proto

PiperOrigin-RevId: 222430965
Allen Lavoie 2018-11-21 10:48:49 -08:00 committed by TensorFlower Gardener
parent 95d7bbb2fc
commit c16394423c
3 changed files with 67 additions and 31 deletions
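
For context, a minimal usage sketch of the new entry point this change adds (illustrative, not part of the commit); the import paths are assumed to match the ones used by the updated test file below.

# Illustrative sketch only; module paths assumed from the TF code of this era.
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils

root = tracking.Checkpointable()
root.v = resource_variable_ops.ResourceVariable(1.)

# Builds a CheckpointableObjectGraph proto describing the dependency structure
# (nodes, children, slot variables) without creating SaveableObjects or
# reading any variable values.
object_graph = checkpointable_utils.make_object_graph_without_attributes(root)
print(len(object_graph.nodes))  # 2 nodes: root and root.v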

View File

@@ -152,7 +152,7 @@ py_test(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
-        "//tensorflow/python/eager:function",
+        "//tensorflow/python/eager:def_function",
         "//tensorflow/python/eager:test",
         "//tensorflow/python/keras:engine",
         "//tensorflow/python/keras:layers",

View File

@@ -549,13 +549,11 @@ def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
   return slot_variables


-def _serialize_checkpointables(
-    checkpointable_objects, node_ids, object_names, slot_variables,
+def _add_attributes_to_object_graph(
+    checkpointable_objects, object_graph_proto, node_ids, object_names,
     saveables_cache, object_map):
-  """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
-  object_graph_proto = (
-      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
-  named_saveables = []
+  """Create SaveableObjects and corresponding SerializedTensor protos."""
+  named_saveable_objects = []
   if saveables_cache is None:
     # No SaveableObject caching. Either we're executing eagerly, or building a
     # static save which is specialized to the current Python state.
@@ -564,10 +562,9 @@ def _serialize_checkpointables(
     # If we are caching SaveableObjects, we need to build up a feed_dict with
     # functions computing volatile Python state to be saved with the checkpoint.
     feed_additions = {}
-  for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+  for checkpoint_id, (checkpointable, object_proto) in enumerate(
+      zip(checkpointable_objects, object_graph_proto.nodes)):
     assert node_ids[checkpointable] == checkpoint_id
-    object_proto = object_graph_proto.nodes.add()
-    object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
     object_name = object_names[checkpointable]
     if object_map:
       object_to_save = object_map.get(checkpointable, checkpointable)
@@ -645,14 +642,24 @@ def _serialize_checkpointables(
                      "value.")
                     % (checkpointable, new_feed_key))
             feed_additions.update(saveable_feed_dict)
-        named_saveables.append(saveable)
+        named_saveable_objects.append(saveable)
+  return named_saveable_objects, feed_additions
+
+
+def _make_object_graph_proto(checkpointable_objects, node_ids, slot_variables):
+  """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
+  object_graph_proto = (
+      checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+  for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+    assert node_ids[checkpointable] == checkpoint_id
+    object_proto = object_graph_proto.nodes.add()
+    object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
     for child in checkpointable._checkpoint_dependencies:  # pylint: disable=protected-access
       child_proto = object_proto.children.add()
       child_proto.node_id = node_ids[child.ref]
       child_proto.local_name = child.name
-  return named_saveables, object_graph_proto, feed_additions
+  return object_graph_proto


 def _serialize_gathered_objects(
@@ -668,13 +675,18 @@ def _serialize_gathered_objects(
       checkpointable_objects=checkpointable_objects,
       node_ids=node_ids,
       object_names=object_names)
-  return _serialize_checkpointables(
+  object_graph_proto = _make_object_graph_proto(
       checkpointable_objects=checkpointable_objects,
       node_ids=node_ids,
+      slot_variables=slot_variables)
+  named_saveable_objects, feed_additions = _add_attributes_to_object_graph(
+      checkpointable_objects=checkpointable_objects,
+      object_graph_proto=object_graph_proto,
+      node_ids=node_ids,
       object_names=object_names,
-      slot_variables=slot_variables,
       saveables_cache=saveables_cache,
       object_map=object_map)
+  return named_saveable_objects, object_graph_proto, feed_additions


 def _serialize_object_graph(root_checkpointable, saveables_cache):
@@ -716,6 +728,23 @@ def named_saveables(root_checkpointable):
   return _serialize_object_graph(root_checkpointable, None)[0]


+def _find_objects(root_checkpointable):
+  """Find and number objects which are dependencies of `root_checkpointable`."""
+  checkpointable_objects, path_to_root = (
+      _breadth_first_checkpointable_traversal(root_checkpointable))
+  object_names = _ObjectIdentityDictionary()
+  for obj, path in path_to_root.items():
+    object_names[obj] = _object_prefix_from_path(path)
+  node_ids = _ObjectIdentityDictionary()
+  for node_id, node in enumerate(checkpointable_objects):
+    node_ids[node] = node_id
+  slot_variables = _serialize_slot_variables(
+      checkpointable_objects=checkpointable_objects,
+      node_ids=node_ids,
+      object_names=object_names)
+  return checkpointable_objects, node_ids, slot_variables
+
+
 def list_objects(root_checkpointable):
   """Traverse the object graph and list all accessible objects.
@@ -730,23 +759,18 @@ def list_objects(root_checkpointable):
   Returns:
     A flat list of objects.
   """
-  # TODO(allenl): Extract out gathering logic so the naming logic doesn't have
-  # to run.
-  checkpointable_objects, path_to_root = (
-      _breadth_first_checkpointable_traversal(root_checkpointable))
-  object_names = _ObjectIdentityDictionary()
-  for obj, path in path_to_root.items():
-    object_names[obj] = _object_prefix_from_path(path)
-  node_ids = _ObjectIdentityDictionary()
-  for node_id, node in enumerate(checkpointable_objects):
-    node_ids[node] = node_id
-  _serialize_slot_variables(
-      checkpointable_objects=checkpointable_objects,
-      node_ids=node_ids,
-      object_names=object_names)
+  checkpointable_objects, _, _ = _find_objects(root_checkpointable)
   return checkpointable_objects


+def make_object_graph_without_attributes(root_checkpointable):
+  """Construct a CheckpointableObjectGraph proto with no variable values."""
+  checkpointable_objects, node_ids, slot_variables = _find_objects(
+      root_checkpointable)
+  return _make_object_graph_proto(
+      checkpointable_objects, node_ids, slot_variables)
+
+
 def gather_initializers(root_checkpointable):
   """Traverse the object graph and find initialization ops.

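To make the "no variable values" part concrete, here is an illustrative walk over the proto returned by the new function (again not part of the commit; the module paths and the `attributes` field name are assumed from the checkpointable proto of this era).

# Illustrative sketch only; not part of this commit.
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils

root = tracking.Checkpointable()
root.child = tracking.Checkpointable()
root.child.v = resource_variable_ops.ResourceVariable(2.)

graph = checkpointable_utils.make_object_graph_without_attributes(root)
for node_id, node in enumerate(graph.nodes):
  children = ", ".join("%s -> node %d" % (c.local_name, c.node_id)
                       for c in node.children)
  print("node %d: children=[%s]" % (node_id, children))
  # Only structure is recorded; no SaveableObjects were created, so no
  # serialized attributes (checkpoint keys / tensor values) are attached.
  assert not node.attributes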
View File

@@ -26,7 +26,7 @@ from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
-from tensorflow.python.eager import function
+from tensorflow.python.eager import def_function
 from tensorflow.python.eager import test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -44,6 +44,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import adam
 from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import momentum
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import training_util
 from tensorflow.python.training.checkpointable import base
@@ -198,6 +199,17 @@ class InterfaceTests(test.TestCase):
     with self.assertRaises(NotImplementedError):
       checkpoint_reversed.save(prefix)

+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def test_object_graph_no_attributes(self):
+    root = tracking.Checkpointable()
+    root.v = resource_variable_ops.ResourceVariable(1.)
+    root.opt = momentum.MomentumOptimizer(0.01, 0.5)
+    root.opt.minimize(root.v.read_value)
+    object_graph = checkpointable_utils.make_object_graph_without_attributes(
+        root)
+    # Four objects: Root, v, opt, and a slot variable for v
+    self.assertEqual(4, len(object_graph.nodes))
+

 class _MirroringSaveable(saver_lib.BaseSaverBuilder.SaveableObject):
@@ -632,7 +644,7 @@ class CheckpointingTests(test.TestCase):
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       def train_fn():
-        @function.defun
+        @def_function.function
         def _call_model(x):
           return model(x)
         with backprop.GradientTape() as tape: