Add SaveableObjects to SavedModel.

When objects are loaded from the SavedModel, they don't retain their `_gather_saveables_for_checkpoint` functions, which can result in values not being loaded from the checkpoint.

This CL adds a field in the SavedModel proto that stores a save and restore function for each SaveableObject in each node. When loading into Python, the SaveableObjects are restored using the functions.

PiperOrigin-RevId: 318549786
Change-Id: I688c72d7658e1bca98abf373a13a0e15a7fb83e2
This commit is contained in:
Katherine Wu 2020-06-26 14:58:58 -07:00 committed by Kathy Wu
parent 99fea8da0d
commit a3e64f721c
9 changed files with 274 additions and 20 deletions

View File

@ -11,6 +11,10 @@
existing C++ kernel `ExtractGlimpse` does not change as well, so saved
models will not be impacted.
## Bug Fixes and Other Changes
* Mutable tables now restore checkpointed values when loaded from SavedModel.
# Release 2.1.1
## Bug Fixes and Other Changes

View File

@ -61,6 +61,8 @@ message SavedObject {
SavedConstant constant = 9;
SavedResource resource = 10;
}
map<string, SaveableObject> saveable_objects = 11;
}
// A SavedUserObject is an object (in the object-oriented language of the
@ -162,3 +164,9 @@ message SavedResource {
// device.
string device = 1;
}
message SaveableObject {
  // Node ids of concrete functions for saving and loading from a checkpoint.
  // NOTE(review): field number 1 is skipped here -- presumably reserved for a
  // future field; confirm against the full saved_object_graph.proto.
  // Node id of the traced save function. The function takes the checkpoint
  // key and returns the SaveSpec entries (name/tensor/slice_spec) to save.
  int32 save_function = 2;
  // Node id of the traced restore function, which consumes the restored
  // tensors and writes them back into the object.
  int32 restore_function = 3;
}

View File

@ -1870,25 +1870,27 @@ class MutableHashTable(LookupInterface):
return {
"table":
functools.partial(
MutableHashTable._Saveable, table=self, name=self._name)
MutableHashTable._Saveable, table=self, name=self._name,
table_name=self._name)
}
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for MutableHashTable."""
"""SaveableObject implementation for DenseHashTable."""
def __init__(self, table, name):
def __init__(self, table, name, table_name=None):
tensors = table.export()
specs = [
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
]
self.table_name = table_name or name
# pylint: disable=protected-access
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
def restore(self, restored_tensors, restored_shapes, name=None):
def restore(self, restored_tensors, restored_shapes):
del restored_shapes # unused
# pylint: disable=protected-access
with ops.name_scope(name, "%s_table_restore" % self.name):
with ops.name_scope("%s_table_restore" % self.table_name):
with ops.colocate_with(self.op.resource_handle):
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
restored_tensors[0],
@ -2166,25 +2168,27 @@ class DenseHashTable(LookupInterface):
return {
"table":
functools.partial(
DenseHashTable._Saveable, table=self, name=self._name)
DenseHashTable._Saveable, table=self, name=self._name,
table_name=self._name)
}
class _Saveable(BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for DenseHashTable."""
def __init__(self, table, name):
def __init__(self, table, name, table_name=None):
tensors = table.export()
specs = [
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
]
self.table_name = table_name or name
# pylint: disable=protected-access
super(DenseHashTable._Saveable, self).__init__(table, specs, name)
def restore(self, restored_tensors, restored_shapes, name=None):
def restore(self, restored_tensors, restored_shapes):
del restored_shapes # unused
# pylint: disable=protected-access
with ops.name_scope(name, "%s_table_restore" % self.name):
with ops.name_scope("%s_table_restore" % self.table_name):
with ops.colocate_with(self.op.resource_handle):
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
restored_tensors[0],

View File

@ -45,6 +45,7 @@ from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
@ -146,6 +147,18 @@ class Loader(object):
self._setup_functions_structures()
self._setup_functions_captures()
self._create_saveable_object_factories()
def _create_saveable_object_factories(self):
  """Attaches restored SaveableObject factories to every revived node.

  For each node in the SavedModel proto, builds factories from the traced
  save/restore concrete functions recorded in `saveable_objects`, so that
  checkpoint values can be restored after loading.
  """
  for index, node_proto in enumerate(self._proto.nodes):
    revived = self.get(index)
    # pylint: disable=protected-access
    revived._self_saveable_object_factories = {
        local_name: saveable_object_util.restored_saved_object_factory(
            self.get(saveable_proto.save_function),
            self.get(saveable_proto.restore_function))
        for local_name, saveable_proto in node_proto.saveable_objects.items()
    }
    # pylint: enable=protected-access
def _load_edges(self):
"""Adds edges from objects to other objects and functions."""
for node_id, object_proto in enumerate(self._proto.nodes):

View File

@ -1795,6 +1795,22 @@ class LoadTest(test.TestCase, parameterized.TestCase):
options = load_options.LoadOptions(experimental_io_device="/job:localhost")
self.assertEqual("/job:localhost", options.experimental_io_device)
def test_load_custom_saveable_object(self, cycles):
  """Checks that MutableHashTable contents survive SavedModel round-trips."""
  root = tracking.AutoTrackable()
  root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
  # Insert an entry before saving; it must be restored from the checkpoint
  # via the table's SaveableObject save/restore functions.
  root.table.insert("foo", 15)

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
  def lookup(key):
    return root.table.lookup(key)

  root.lookup = lookup

  imported = cycle(root, cycles)
  # Checkpointed entry is restored; a missing key yields the default (-1).
  self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
  self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
class SingleCycleTests(test.TestCase, parameterized.TestCase):

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import functools
import os
from tensorflow.core.framework import versions_pb2
@ -53,6 +54,7 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.saving import checkpoint_options
from tensorflow.python.training.saving import functional_saver
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
@ -136,12 +138,15 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
self._serialization_cache)
def list_functions(self, obj):
def list_functions(self, obj, extra_functions=None):
  """Returns `obj`'s serializable functions, cached, merged with extras.

  The per-object result is computed once and memoized in `self._functions`.
  When `extra_functions` is supplied, the cached dict is copied before the
  merge so the cache entry itself is never mutated.
  """
  cached = self._functions.get(obj)
  if cached is None:
    cached = obj._list_functions_for_serialization(  # pylint: disable=protected-access
        self._serialization_cache)
    self._functions[obj] = cached
  if not extra_functions:
    return cached
  merged = cached.copy()
  merged.update(extra_functions)
  return merged
@ -177,6 +182,12 @@ class _SaveableView(object):
self.slot_variables = slot_variables
self.concrete_functions = []
self.saveable_objects_for_node, all_saveable_functions = (
self._add_saveable_objects())
saveable_object_functions = {
"__SAVEABLE_FUNCTION_{}".format(n): fn
for n, fn in enumerate(all_saveable_functions)}
# Maps functions -> wrapped functions that capture variables
self.wrapped_functions = wrapped_functions or {}
# Maps names of concrete functions in the object to names of wrapped
@ -190,7 +201,8 @@ class _SaveableView(object):
nodes_without_functions = list(self.nodes)
seen_function_names = set()
for node in nodes_without_functions:
for function in checkpoint_view.list_functions(node).values():
for function in checkpoint_view.list_functions(
node, saveable_object_functions).values():
if function not in self.node_ids:
self.node_ids[function] = len(self.nodes)
self.nodes.append(function)
@ -209,6 +221,25 @@ class _SaveableView(object):
seen_function_names.add(concrete_function.name)
self.concrete_functions.append(concrete_function)
def _add_saveable_objects(self):
  """Traces save/restore functions for each node's SaveableObjects.

  Returns:
    A tuple `(objects_map, traced_functions)` where `objects_map` maps
    node -> {local name: (save function, restore function)} and
    `traced_functions` is a flat list of every traced concrete function.
  """
  objects_map = object_identity.ObjectIdentityDictionary()
  traced_functions = []
  for node in self.nodes:
    # Resource (and TPU/Mirrored) variables are automatically revived with
    # their saveables defined, so there is no need to trace the save
    # and restore functions.
    if resource_variable_ops.is_resource_variable(node):
      continue
    traced = saveable_object_util.trace_save_restore_functions(node)
    if not traced:
      continue
    objects_map[node] = traced
    for fn_pair in traced.values():
      traced_functions.extend(fn_pair)
  return objects_map, traced_functions
@property
def root(self):
  """Root object of the serialized object graph (node id 0)."""
  return self.nodes[0]
@ -233,6 +264,15 @@ class _SaveableView(object):
child_proto.node_id = self.node_ids[ref_function]
child_proto.local_name = local_name
if node not in self.saveable_objects_for_node:
continue
for local_name, (save_fn, restore_fn) in (
self.saveable_objects_for_node[node].items()):
saveable_object_proto = object_proto.saveable_objects[local_name]
saveable_object_proto.save_function = self.node_ids[save_fn]
saveable_object_proto.restore_function = self.node_ids[restore_fn]
def map_resources(self):
"""Makes new resource handle ops corresponding to existing resource tensors.
@ -605,7 +645,9 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
# the exported graph (thus the `to_graph` argument).
saver = functional_saver.MultiDeviceSaver(
saveable_view.checkpoint_view.frozen_saveable_objects(
object_map=object_map, to_graph=exported_graph))
object_map=object_map, to_graph=exported_graph,
call_with_mapped_captures=functools.partial(
_call_function_with_mapped_captures, resource_map=resource_map)))
with exported_graph.as_default():
signatures = _generate_signatures(signature_functions, resource_map)

View File

@ -17,15 +17,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import six
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
@ -279,7 +290,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
raise ValueError(
("Two different ResourceVariable objects with the same "
"shared_name '%s' were passed to the Saver. This likely means "
"that they were created in different Graphs or isolation "
"that they were created in different Graphs or isoWlation "
"contexts, and may not be checkpointed together.") %
(var._shared_name,))
else:
@ -349,3 +360,147 @@ def validate_and_slice_inputs(names_to_saveables):
for converted_saveable_object in saveable_objects_for_op(op, name):
_add_saveable(saveables, seen_ops, converted_saveable_object)
return saveables
def trace_save_restore_functions(object_to_save):
  """Gathers all SaveableObjects and traces the save and restore ops."""
  traced = {}  # Maps local name -> (save concrete fn, restore concrete fn).
  factories = object_to_save._gather_saveables_for_checkpoint()  # pylint: disable=protected-access
  for name, factory in factories.items():
    if not callable(factory):
      # Factories should be callables; plain SaveableObjects cannot be
      # re-created at load time, so skip them (with a debug note).
      if isinstance(factory, saveable_object.SaveableObject):
        logging.debug(
            "Trackable {} should return callable factories, not SaveableObjects"
            " in `_gather_saveables_for_checkpoint`. This could lead to "
            "problems loading the SavedModel back into Python."
            .format(object_to_save))
      continue
    if is_factory_for_restored_saveable_object(factory):
      # Already restored from a SavedModel; reuse its traced functions.
      traced[name] = (factory.keywords["save_function"],
                      factory.keywords["restore_function"])
      continue
    save_fn, restore_fn = _trace_save_and_restore_function(
        factory, object_to_save)
    if save_fn is not None:
      traced[name] = (save_fn, restore_fn)
  return traced
def _trace_save_and_restore_function(saveable_factory, object_to_save):
  """Traces the save and restore concrete functions."""
  # Populated as a side effect when `save_fn` is traced below; inspected
  # afterwards to decide whether the saveables are restorable.
  saveables = []

  @def_function.function(
      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
  def save_fn(checkpoint_key):
    # The factory may return a single SaveableObject or a list of them;
    # normalize to a list and stash it in the enclosing `saveables`.
    maybe_saveable = saveable_factory(name=checkpoint_key)
    if isinstance(maybe_saveable, saveable_object.SaveableObject):
      maybe_saveable = [maybe_saveable]
    saveables[:] = maybe_saveable
    # Return list of all SaveSpecs created by the factory.
    ret = []
    for saveable in saveables:
      for spec in saveable.specs:
        ret.append({"name": spec.name, "tensor": spec.tensor,
                    "slice_spec": spec.slice_spec})
    return ret

  # Tracing the function runs `save_fn` once, filling in `saveables`.
  concrete_save_fn = save_fn.get_concrete_function()
  if any(isinstance(saveable, trackable.PythonStateSaveable)
         for saveable in saveables):
    # Python-state saveables hold values outside the graph; a traced
    # restore function cannot bring them back, so bail out.
    logging.warn(
        "Note that object {} stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.".format(object_to_save))
    return None, None
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return None, None

  # Build the flat input signature for the restore function, plus a parallel
  # nested structure (per-saveable lists of spec names) used to regroup the
  # flat restored tensors.
  restored_type_specs = []
  tensor_structure = []
  for saveable in saveables:
    saveable_tensor_structure = []
    tensor_structure.append(saveable_tensor_structure)
    for spec in saveable.specs:
      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
      saveable_tensor_structure.append(spec.name)

  @def_function.function(input_signature=restored_type_specs)
  def restore_fn(*restored_tensors):
    structured_restored_tensors = nest.pack_sequence_as(
        tensor_structure, restored_tensors)
    for saveable, restored_tensors in zip(saveables,
                                          structured_restored_tensors):
      saveable.restore(restored_tensors, restored_shapes=None)
    # NOTE(review): dummy return value; presumably the traced function needs
    # a non-None output -- confirm.
    return 1

  concrete_restore_fn = restore_fn.get_concrete_function()
  return concrete_save_fn, concrete_restore_fn
class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore."""

  def __init__(self, save_function, restore_function, name):
    self.save_function = save_function
    self.restore_function = restore_function
    # The traced save function expects the checkpoint key as a tensor.
    if tensor_util.is_tensor(name):
      key_tensor = name
    else:
      with ops.init_scope():
        key_tensor = constant_op.constant(name)
    spec_dicts = save_function(key_tensor)
    specs = []
    for spec_dict in spec_dicts:
      specs.append(saveable_object.SaveSpec(
          spec_dict["tensor"], spec_dict["slice_spec"], spec_dict["name"]))
    super(RestoredSaveableObject, self).__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Runs the traced restore function on the restored tensors."""
    del restored_shapes  # unused
    expected = restored_tensors[:len(self.specs)]
    return self.restore_function(*expected)
def restored_saved_object_factory(save_function, restore_function):
  """Returns a factory building RestoredSaveableObjects for a checkpoint key.

  The result is deliberately a `functools.partial` bound to
  `RestoredSaveableObject`, so the save/restore functions remain recoverable
  from its `.keywords` attribute.
  """
  return functools.partial(
      RestoredSaveableObject,
      save_function=save_function,
      restore_function=restore_function)
def create_saveable_object(factory, name, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  When creating the frozen saver for SavedModel, the save and restore ops are
  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
  save and restore, the function captures must be mapped to the new graph.

  Args:
    factory: Factory method for creating the SaveableObject.
    name: Checkpoint key of this SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures.

  Returns:
    a SaveableObject.
  """
  # Without a capture-mapping helper, or for ordinary factories, no graph
  # remapping is needed.
  if call_with_mapped_captures is None:
    return factory(name=name)
  if not is_factory_for_restored_saveable_object(factory):
    return factory(name=name)

  traced_save = factory.keywords["save_function"]
  traced_restore = factory.keywords["restore_function"]

  def mapped_save(name):
    return call_with_mapped_captures(traced_save, [name])

  def mapped_restore(*restored_tensors):
    return call_with_mapped_captures(traced_restore, restored_tensors)

  return factory(
      save_function=mapped_save, restore_function=mapped_restore, name=name)
def is_factory_for_restored_saveable_object(factory):
  """Checks whether `factory` wraps traced SavedModel save/restore functions.

  Factories produced by `restored_saved_object_factory` are
  `functools.partial` objects bound to `RestoredSaveableObject`.
  """
  if not isinstance(factory, functools.partial):
    return False
  return factory.func is RestoredSaveableObject

View File

@ -611,6 +611,12 @@ class Trackable(object):
# building.
self._self_name_based_restores = set()
# Dictionary of SaveableObjects factories. This dictionary is defined when
# the object is loaded from the SavedModel. When writing a custom class,
# prefer overriding "_gather_saveables_from_checkpoint" to using this
# attribute.
self._self_saveable_object_factories = {}
@property
def _object_identifier(self):
"""String used to identify this object in a SavedModel.
@ -972,7 +978,7 @@ class Trackable(object):
lambda name="global_name_for_this_object":
SaveableObject(name=name, ...)}
"""
return {}
return self._self_saveable_object_factories
def _list_extra_dependencies_for_serialization(self, serialization_cache):
"""Lists extra dependencies to serialize.

View File

@ -208,7 +208,7 @@ class ObjectGraphView(object):
def _add_attributes_to_object_graph(
self, trackable_objects, object_graph_proto, node_ids, object_names,
object_map):
object_map, call_with_mapped_captures):
"""Create SaveableObjects and corresponding SerializedTensor protos."""
named_saveable_objects = []
if self._saveables_cache is None:
@ -253,7 +253,9 @@ class ObjectGraphView(object):
break
if saveables is None:
if callable(saveable_factory):
maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
maybe_saveable = saveable_object_util.create_saveable_object(
saveable_factory, attribute.checkpoint_key,
call_with_mapped_captures)
else:
maybe_saveable = saveable_factory
if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
@ -332,7 +334,8 @@ class ObjectGraphView(object):
return object_graph_proto
def _serialize_gathered_objects(self, trackable_objects, path_to_root,
object_map=None):
object_map=None,
call_with_mapped_captures=None):
"""Create SaveableObjects and protos for gathered objects."""
object_names = object_identity.ObjectIdentityDictionary()
for obj, path in path_to_root.items():
@ -354,7 +357,8 @@ class ObjectGraphView(object):
object_graph_proto=object_graph_proto,
node_ids=node_ids,
object_names=object_names,
object_map=object_map))
object_map=object_map,
call_with_mapped_captures=call_with_mapped_captures))
return named_saveable_objects, object_graph_proto, feed_additions
def serialize_object_graph(self):
@ -382,7 +386,8 @@ class ObjectGraphView(object):
return self._serialize_gathered_objects(
trackable_objects, path_to_root)
def frozen_saveable_objects(self, object_map=None, to_graph=None):
def frozen_saveable_objects(self, object_map=None, to_graph=None,
call_with_mapped_captures=None):
"""Creates SaveableObjects with the current object graph frozen."""
trackable_objects, path_to_root = self._breadth_first_traversal()
if to_graph:
@ -393,7 +398,8 @@ class ObjectGraphView(object):
named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
trackable_objects,
path_to_root,
object_map)
object_map,
call_with_mapped_captures)
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)