From a3e64f721c9a09ace61a30de0587add8a17e50cb Mon Sep 17 00:00:00 2001
From: Katherine Wu
Date: Fri, 26 Jun 2020 14:58:58 -0700
Subject: [PATCH] Add SaveableObjects to SavedModel.

When objects are loaded from the SavedModel, they don't retain their
`_gather_saveables_for_checkpoint` functions, which can result in values not
being loaded from the checkpoint.

This CL adds a field in the SavedModel proto that stores a save and restore
function for each SaveableObject in each node. When loading into Python, the
SaveableObjects are restored using the functions.

PiperOrigin-RevId: 318549786
Change-Id: I688c72d7658e1bca98abf373a13a0e15a7fb83e2
---
 RELEASE.md                                      |   4 +
 .../core/protobuf/saved_object_graph.proto      |   8 +
 tensorflow/python/ops/lookup_ops.py             |  22 ++-
 tensorflow/python/saved_model/load.py           |  13 ++
 tensorflow/python/saved_model/load_test.py      |  16 ++
 tensorflow/python/saved_model/save.py           |  48 +++++-
 .../training/saving/saveable_object_util.py     | 157 +++++++++++++++++-
 tensorflow/python/training/tracking/base.py     |   8 +-
 .../python/training/tracking/graph_view.py      |  18 +-
 9 files changed, 274 insertions(+), 20 deletions(-)

diff --git a/RELEASE.md b/RELEASE.md
index f93626cc876..2218fab5808 100644
--- a/RELEASE.md
+++ b/RELEASE.md
@@ -11,6 +11,10 @@
     existing C++ kernel `ExtractGlimpse` does not change as well, so saved
     models will not be impacted.
 
+## Bug Fixes and Other Changes
+
+* Mutable tables now restore checkpointed values when loaded from SavedModel.
+
 # Release 2.1.1
 
 ## Bug Fixes and Other Changes
diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto
index e794b885dec..981908cfa3c 100644
--- a/tensorflow/core/protobuf/saved_object_graph.proto
+++ b/tensorflow/core/protobuf/saved_object_graph.proto
@@ -61,6 +61,8 @@ message SavedObject {
     SavedConstant constant = 9;
     SavedResource resource = 10;
   }
+
+  map<string, SaveableObject> saveable_objects = 11;
 }
 
 // A SavedUserObject is an object (in the object-oriented language of the
@@ -162,3 +164,9 @@ message SavedResource {
   // device.
   string device = 1;
 }
+
+message SaveableObject {
+  // Node ids of concrete functions for saving and loading from a checkpoint.
+  int32 save_function = 2;
+  int32 restore_function = 3;
+}
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 15c7f12f89c..96f3cf91499 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -1870,25 +1870,27 @@ class MutableHashTable(LookupInterface):
     return {
         "table":
             functools.partial(
-                MutableHashTable._Saveable, table=self, name=self._name)
+                MutableHashTable._Saveable, table=self, name=self._name,
+                table_name=self._name)
     }
 
   class _Saveable(BaseSaverBuilder.SaveableObject):
-    """SaveableObject implementation for DenseHashTable."""
+    """SaveableObject implementation for MutableHashTable."""
 
-    def __init__(self, table, name):
+    def __init__(self, table, name, table_name=None):
       tensors = table.export()
       specs = [
           BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
           BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
       ]
+      self.table_name = table_name or name
       # pylint: disable=protected-access
       super(MutableHashTable._Saveable, self).__init__(table, specs, name)
 
-    def restore(self, restored_tensors, restored_shapes, name=None):
+    def restore(self, restored_tensors, restored_shapes):
       del restored_shapes  # unused
       # pylint: disable=protected-access
-      with ops.name_scope(name, "%s_table_restore" % self.name):
+      with ops.name_scope("%s_table_restore" % self.table_name):
         with ops.colocate_with(self.op.resource_handle):
           return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                        restored_tensors[0],
@@ -2166,25 +2168,27 @@ class DenseHashTable(LookupInterface):
     return {
         "table":
             functools.partial(
-                DenseHashTable._Saveable, table=self, name=self._name)
+                DenseHashTable._Saveable, table=self, name=self._name,
+                table_name=self._name)
     }
 
   class _Saveable(BaseSaverBuilder.SaveableObject):
     """SaveableObject implementation for DenseHashTable."""
 
-    def __init__(self, table, name):
+    def __init__(self, table, name, table_name=None):
       tensors = table.export()
       specs = [
           BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
           BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
       ]
+      self.table_name = table_name or name
       # pylint: disable=protected-access
       super(DenseHashTable._Saveable, self).__init__(table, specs, name)
 
-    def restore(self, restored_tensors, restored_shapes, name=None):
+    def restore(self, restored_tensors, restored_shapes):
       del restored_shapes  # unused
       # pylint: disable=protected-access
-      with ops.name_scope(name, "%s_table_restore" % self.name):
+      with ops.name_scope("%s_table_restore" % self.table_name):
         with ops.colocate_with(self.op.resource_handle):
           return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
                                                        restored_tensors[0],
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index fb2d01cbee2..0835481ab69 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -45,6 +45,7 @@ from tensorflow.python.saved_model import nested_structure_coder
 from tensorflow.python.saved_model import revived_types
 from tensorflow.python.saved_model import utils_impl as saved_model_utils
 from tensorflow.python.training.saving import checkpoint_options
+from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base
 from tensorflow.python.training.tracking import graph_view
 from tensorflow.python.training.tracking import tracking
@@ -146,6 +147,18 @@ class Loader(object):
     self._setup_functions_structures()
     self._setup_functions_captures()
 
+    self._create_saveable_object_factories()
+
+  def _create_saveable_object_factories(self):
+    for node_id, proto in enumerate(self._proto.nodes):
+      node = self.get(node_id)
+      node._self_saveable_object_factories = {}  # pylint: disable=protected-access
+      for name, saveable_object_proto in proto.saveable_objects.items():
+        node._self_saveable_object_factories[name] = (  # pylint: disable=protected-access
+            saveable_object_util.restored_saved_object_factory(
+                self.get(saveable_object_proto.save_function),
+                self.get(saveable_object_proto.restore_function)))
+
   def _load_edges(self):
     """Adds edges from objects to other objects and functions."""
     for node_id, object_proto in enumerate(self._proto.nodes):
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index 5449cc1c9a2..c392c7feb31 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -1795,6 +1795,22 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     options = load_options.LoadOptions(experimental_io_device="/job:localhost")
     self.assertEqual("/job:localhost", options.experimental_io_device)
 
+  def test_load_custom_saveable_object(self, cycles):
+    root = tracking.AutoTrackable()
+    root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
+    root.table.insert("foo", 15)
+
+    @def_function.function(
+        input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
+    def lookup(key):
+      return root.table.lookup(key)
+
+    root.lookup = lookup
+
+    imported = cycle(root, cycles)
+    self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
+    self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
+
 
 class SingleCycleTests(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 802ce1d61b7..deedbd6794e 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import functools
 import os
 
 from tensorflow.core.framework import versions_pb2
@@ -53,6 +54,7 @@ from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.saved_model import utils_impl
 from tensorflow.python.training.saving import checkpoint_options
 from tensorflow.python.training.saving import functional_saver
+from tensorflow.python.training.saving import saveable_object_util
 from tensorflow.python.training.tracking import base
 from tensorflow.python.training.tracking import graph_view
 from tensorflow.python.training.tracking import tracking
@@ -136,12 +138,15 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
     return obj._list_extra_dependencies_for_serialization(  # pylint: disable=protected-access
         self._serialization_cache)
 
-  def list_functions(self, obj):
+  def list_functions(self, obj, extra_functions=None):
     obj_functions = self._functions.get(obj, None)
     if obj_functions is None:
       obj_functions = obj._list_functions_for_serialization(  # pylint: disable=protected-access
           self._serialization_cache)
       self._functions[obj] = obj_functions
+    if extra_functions:
+      obj_functions = obj_functions.copy()
+      obj_functions.update(extra_functions)
     return obj_functions
 
 
@@ -177,6 +182,12 @@ class _SaveableView(object):
     self.slot_variables = slot_variables
     self.concrete_functions = []
 
+    self.saveable_objects_for_node, all_saveable_functions = (
+        self._add_saveable_objects())
+    saveable_object_functions = {
"__SAVEABLE_FUNCTION_{}".format(n): fn + for n, fn in enumerate(all_saveable_functions)} + # Maps functions -> wrapped functions that capture variables self.wrapped_functions = wrapped_functions or {} # Maps names of concrete functions in the object to names of wrapped @@ -190,7 +201,8 @@ class _SaveableView(object): nodes_without_functions = list(self.nodes) seen_function_names = set() for node in nodes_without_functions: - for function in checkpoint_view.list_functions(node).values(): + for function in checkpoint_view.list_functions( + node, saveable_object_functions).values(): if function not in self.node_ids: self.node_ids[function] = len(self.nodes) self.nodes.append(function) @@ -209,6 +221,25 @@ class _SaveableView(object): seen_function_names.add(concrete_function.name) self.concrete_functions.append(concrete_function) + def _add_saveable_objects(self): + """Retrieves SaveablesObjects and traces their save/restore functions.""" + # Maps node -> local name -> (save function, restore function) + saveable_objects_map = object_identity.ObjectIdentityDictionary() + all_saveable_functions = [] + for node in self.nodes: + if resource_variable_ops.is_resource_variable(node): + # Resource (and TPU/Mirrored) variables are automatically revived with + # their saveables defined, so there is no need to trace the save + # and restore functions. + continue + saveable_map = saveable_object_util.trace_save_restore_functions(node) + if saveable_map: + saveable_objects_map[node] = saveable_map + for save_fn, restore_fn in saveable_map.values(): + all_saveable_functions.append(save_fn) + all_saveable_functions.append(restore_fn) + return saveable_objects_map, all_saveable_functions + @property def root(self): return self.nodes[0] @@ -233,6 +264,15 @@ class _SaveableView(object): child_proto.node_id = self.node_ids[ref_function] child_proto.local_name = local_name + if node not in self.saveable_objects_for_node: + continue + + for local_name, (save_fn, restore_fn) in ( + self.saveable_objects_for_node[node].items()): + saveable_object_proto = object_proto.saveable_objects[local_name] + saveable_object_proto.save_function = self.node_ids[save_fn] + saveable_object_proto.restore_function = self.node_ids[restore_fn] + def map_resources(self): """Makes new resource handle ops corresponding to existing resource tensors. @@ -605,7 +645,9 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions, # the exported graph (thus the `to_graph` argument). 
   saver = functional_saver.MultiDeviceSaver(
       saveable_view.checkpoint_view.frozen_saveable_objects(
-          object_map=object_map, to_graph=exported_graph))
+          object_map=object_map, to_graph=exported_graph,
+          call_with_mapped_captures=functools.partial(
+              _call_function_with_mapped_captures, resource_map=resource_map)))
 
   with exported_graph.as_default():
     signatures = _generate_signatures(signature_functions, resource_map)
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 59d65ade573..c3c3570c0f8 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -17,15 +17,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import functools
 import six
 
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import type_spec
+
+
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.saving import saveable_object
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
@@ -279,7 +290,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
           raise ValueError(
               ("Two different ResourceVariable objects with the same "
                "shared_name '%s' were passed to the Saver. This likely means "
-               "that they were created in different Graphs or isoWlation "
+               "that they were created in different Graphs or isolation "
                "contexts, and may not be checkpointed together.") %
               (var._shared_name,))
         else:
@@ -349,3 +360,147 @@ def validate_and_slice_inputs(names_to_saveables):
       for converted_saveable_object in saveable_objects_for_op(op, name):
         _add_saveable(saveables, seen_ops, converted_saveable_object)
   return saveables
+
+
+def trace_save_restore_functions(object_to_save):
+  """Gathers all SaveableObjects and traces the save and restore ops."""
+  saveable_map = {}  # Maps name -> (save function, restore function)
+  for name, saveable_factory in (
+      object_to_save._gather_saveables_for_checkpoint().items()):  # pylint: disable=protected-access
+    if not callable(saveable_factory):
+      if isinstance(saveable_factory, saveable_object.SaveableObject):
+        logging.debug(
+            "Trackable {} should return callable factories, not SaveableObjects"
+            " in `_gather_saveables_for_checkpoint`. This could lead to "
+            "problems loading the SavedModel back into Python."
+            .format(object_to_save))
+      continue
+
+    if is_factory_for_restored_saveable_object(saveable_factory):
+      saveable_map[name] = (saveable_factory.keywords["save_function"],
+                            saveable_factory.keywords["restore_function"])
+    else:
+      concrete_save_fn, concrete_restore_fn = _trace_save_and_restore_function(
+          saveable_factory, object_to_save)
+      if concrete_save_fn is not None:
+        saveable_map[name] = (concrete_save_fn, concrete_restore_fn)
+  return saveable_map
+
+
+def _trace_save_and_restore_function(saveable_factory, object_to_save):
+  """Traces the save and restore concrete functions."""
+  saveables = []
+
+  @def_function.function(
+      input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
+  def save_fn(checkpoint_key):
+    maybe_saveable = saveable_factory(name=checkpoint_key)
+    if isinstance(maybe_saveable, saveable_object.SaveableObject):
+      maybe_saveable = [maybe_saveable]
+    saveables[:] = maybe_saveable
+
+    # Return list of all SaveSpecs created by the factory.
+    ret = []
+    for saveable in saveables:
+      for spec in saveable.specs:
+        ret.append({"name": spec.name, "tensor": spec.tensor,
+                    "slice_spec": spec.slice_spec})
+    return ret
+
+  concrete_save_fn = save_fn.get_concrete_function()
+  if any(isinstance(saveable, trackable.PythonStateSaveable)
+         for saveable in saveables):
+    logging.warn(
+        "Note that object {} stores python values into the checkpoint. "
+        "These values will not be restored when loading the SavedModel "
+        "into python.".format(object_to_save))
+    return None, None
+  if any(isinstance(saveable, trackable.NoRestoreSaveable)
+         for saveable in saveables):
+    return None, None
+
+  restored_type_specs = []
+  tensor_structure = []
+  for saveable in saveables:
+    saveable_tensor_structure = []
+    tensor_structure.append(saveable_tensor_structure)
+    for spec in saveable.specs:
+      restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
+      saveable_tensor_structure.append(spec.name)
+
+  @def_function.function(input_signature=restored_type_specs)
+  def restore_fn(*restored_tensors):
+    structured_restored_tensors = nest.pack_sequence_as(
+        tensor_structure, restored_tensors)
+    for saveable, restored_tensors in zip(saveables,
+                                          structured_restored_tensors):
+      saveable.restore(restored_tensors, restored_shapes=None)
+    return 1
+
+  concrete_restore_fn = restore_fn.get_concrete_function()
+  return concrete_save_fn, concrete_restore_fn
+
+
+class RestoredSaveableObject(saveable_object.SaveableObject):
+  """SaveableObject restored from SavedModel using the traced save/restore."""
+
+  def __init__(self, save_function, restore_function, name):
+    self.save_function = save_function
+    self.restore_function = restore_function
+
+    if tensor_util.is_tensor(name):
+      name_tensor = name
+    else:
+      with ops.init_scope():
+        name_tensor = constant_op.constant(name)
+    tensors = save_function(name_tensor)
+    specs = [saveable_object.SaveSpec(x["tensor"], x["slice_spec"], x["name"])
+             for x in tensors]
+    super(RestoredSaveableObject, self).__init__(None, specs, name)
+
+  def restore(self, restored_tensors, restored_shapes):
+    del restored_shapes  # unused
+    return self.restore_function(
+        *[restored_tensors[i] for i in range(len(self.specs))])
+
+
+def restored_saved_object_factory(save_function, restore_function):
+  return functools.partial(RestoredSaveableObject,
+                           save_function=save_function,
+                           restore_function=restore_function)
+
+
+def create_saveable_object(factory, name, call_with_mapped_captures):
+  """Creates a SaveableObject while potentially in a different graph.
+
+  When creating the frozen saver for SavedModel, the save and restore ops are
+  placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
+  save and restore, the function captures must be mapped to the new graph.
+
+  Args:
+    factory: Factory method for creating the SaveableObject.
+    name: Checkpoint key of this SaveableObject.
+    call_with_mapped_captures: Helper that calls a tf.function while remapping
+      the captures.
+
+  Returns:
+    a SaveableObject.
+  """
+  if (call_with_mapped_captures is None or
+      not is_factory_for_restored_saveable_object(factory)):
+    return factory(name=name)
+
+  concrete_save_fn = factory.keywords["save_function"]
+  def save_fn(name):
+    return call_with_mapped_captures(concrete_save_fn, [name])
+
+  concrete_restore_fn = factory.keywords["restore_function"]
+  def restore_fn(*restored_tensors):
+    return call_with_mapped_captures(concrete_restore_fn, restored_tensors)
+
+  return factory(save_function=save_fn, restore_function=restore_fn, name=name)
+
+
+def is_factory_for_restored_saveable_object(factory):
+  return (isinstance(factory, functools.partial) and
+          factory.func is RestoredSaveableObject)
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index ea76ad8db47..9337adbf88a 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -611,6 +611,12 @@ class Trackable(object):
     # building.
     self._self_name_based_restores = set()
 
+    # Dictionary of SaveableObject factories. This dictionary is defined when
+    # the object is loaded from the SavedModel. When writing a custom class,
+    # prefer overriding "_gather_saveables_for_checkpoint" to using this
+    # attribute.
+    self._self_saveable_object_factories = {}
+
   @property
   def _object_identifier(self):
     """String used to identify this object in a SavedModel.
@@ -972,7 +978,7 @@ class Trackable(object):
             lambda name="global_name_for_this_object":
             SaveableObject(name=name, ...)}
     """
-    return {}
+    return self._self_saveable_object_factories
 
   def _list_extra_dependencies_for_serialization(self, serialization_cache):
     """Lists extra dependencies to serialize.
diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py
index 041ff38eedd..1cf84023b1c 100644
--- a/tensorflow/python/training/tracking/graph_view.py
+++ b/tensorflow/python/training/tracking/graph_view.py
@@ -208,7 +208,7 @@ class ObjectGraphView(object):
 
   def _add_attributes_to_object_graph(
       self, trackable_objects, object_graph_proto, node_ids, object_names,
-      object_map):
+      object_map, call_with_mapped_captures):
     """Create SaveableObjects and corresponding SerializedTensor protos."""
     named_saveable_objects = []
     if self._saveables_cache is None:
@@ -253,7 +253,9 @@ class ObjectGraphView(object):
             break
       if saveables is None:
         if callable(saveable_factory):
-          maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
+          maybe_saveable = saveable_object_util.create_saveable_object(
+              saveable_factory, attribute.checkpoint_key,
+              call_with_mapped_captures)
         else:
           maybe_saveable = saveable_factory
         if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
@@ -332,7 +334,8 @@ class ObjectGraphView(object):
     return object_graph_proto
 
   def _serialize_gathered_objects(self, trackable_objects, path_to_root,
-                                  object_map=None):
+                                  object_map=None,
+                                  call_with_mapped_captures=None):
     """Create SaveableObjects and protos for gathered objects."""
     object_names = object_identity.ObjectIdentityDictionary()
     for obj, path in path_to_root.items():
@@ -354,7 +357,8 @@ class ObjectGraphView(object):
             object_graph_proto=object_graph_proto,
             node_ids=node_ids,
             object_names=object_names,
-            object_map=object_map))
+            object_map=object_map,
+            call_with_mapped_captures=call_with_mapped_captures))
     return named_saveable_objects, object_graph_proto, feed_additions
 
   def serialize_object_graph(self):
@@ -382,7 +386,8 @@ class ObjectGraphView(object):
     return self._serialize_gathered_objects(
         trackable_objects, path_to_root)
 
-  def frozen_saveable_objects(self, object_map=None, to_graph=None):
+  def frozen_saveable_objects(self, object_map=None, to_graph=None,
+                              call_with_mapped_captures=None):
     """Creates SaveableObjects with the current object graph frozen."""
     trackable_objects, path_to_root = self._breadth_first_traversal()
     if to_graph:
@@ -393,7 +398,8 @@ class ObjectGraphView(object):
     named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
         trackable_objects,
         path_to_root,
-        object_map,
+        object_map,
+        call_with_mapped_captures)
     with ops.device("/cpu:0"):
       object_graph_tensor = constant_op.constant(
           graph_proto.SerializeToString(), dtype=dtypes.string)
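Usage note: a minimal sketch of the round trip this change is meant to support, using the public `tf.lookup.experimental.MutableHashTable` and `tf.saved_model` APIs from TF 2.x. The module/class names and the sample keys below are illustrative only, not part of the patch.

import tempfile

import tensorflow as tf


class TableModule(tf.Module):
  """Illustrative module; the class and attribute names are arbitrary."""

  def __init__(self):
    super(TableModule, self).__init__()
    # MutableHashTable exposes its contents through SaveableObjects, which
    # this change records in (and restores from) the SavedModel.
    self.table = tf.lookup.experimental.MutableHashTable(
        key_dtype=tf.string, value_dtype=tf.float32, default_value=-1.0)

  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)])
  def lookup(self, key):
    return self.table.lookup(key)


module = TableModule()
module.table.insert(tf.constant(["foo"]), tf.constant([15.0]))

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)
restored = tf.saved_model.load(export_dir)

# With the save/restore functions stored in the SavedModel, the checkpointed
# table contents come back after loading; previously they were lost.
print(restored.lookup(tf.constant("foo")).numpy())  # 15.0
print(restored.lookup(tf.constant("idk")).numpy())  # -1.0 (default value)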