Add SaveableObjects to SavedModel.
When objects are loaded from the SavedModel, they don't retain their `_gather_saveables_for_checkpoint` functions, which can result in values not being loaded from the checkpoint. This CL adds a field in the SavedModel proto that stores a save and restore function for each SaveableObject in each node. When loading into Python, the SaveableObjects are restored using the functions. PiperOrigin-RevId: 318549786 Change-Id: I688c72d7658e1bca98abf373a13a0e15a7fb83e2
This commit is contained in:
parent
99fea8da0d
commit
a3e64f721c
|
@ -11,6 +11,10 @@
|
|||
exsiting C++ kernel `ExtractGlimpse` does not change as well, so saved
|
||||
models will not be impacted.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
* Mutable tables now restore checkpointed values when loaded from SavedModel.
|
||||
|
||||
# Release 2.1.1
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
|
|
@ -61,6 +61,8 @@ message SavedObject {
|
|||
SavedConstant constant = 9;
|
||||
SavedResource resource = 10;
|
||||
}
|
||||
|
||||
map<string, SaveableObject> saveable_objects = 11;
|
||||
}
|
||||
|
||||
// A SavedUserObject is an object (in the object-oriented language of the
|
||||
|
@ -162,3 +164,9 @@ message SavedResource {
|
|||
// device.
|
||||
string device = 1;
|
||||
}
|
||||
|
||||
message SaveableObject {
|
||||
// Node ids of concrete functions for saving and loading from a checkpoint.
|
||||
int32 save_function = 2;
|
||||
int32 restore_function = 3;
|
||||
}
|
||||
|
|
|
@ -1870,25 +1870,27 @@ class MutableHashTable(LookupInterface):
|
|||
return {
|
||||
"table":
|
||||
functools.partial(
|
||||
MutableHashTable._Saveable, table=self, name=self._name)
|
||||
MutableHashTable._Saveable, table=self, name=self._name,
|
||||
table_name=self._name)
|
||||
}
|
||||
|
||||
class _Saveable(BaseSaverBuilder.SaveableObject):
|
||||
"""SaveableObject implementation for MutableHashTable."""
|
||||
"""SaveableObject implementation for DenseHashTable."""
|
||||
|
||||
def __init__(self, table, name):
|
||||
def __init__(self, table, name, table_name=None):
|
||||
tensors = table.export()
|
||||
specs = [
|
||||
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
|
||||
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
|
||||
]
|
||||
self.table_name = table_name or name
|
||||
# pylint: disable=protected-access
|
||||
super(MutableHashTable._Saveable, self).__init__(table, specs, name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes, name=None):
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
del restored_shapes # unused
|
||||
# pylint: disable=protected-access
|
||||
with ops.name_scope(name, "%s_table_restore" % self.name):
|
||||
with ops.name_scope("%s_table_restore" % self.table_name):
|
||||
with ops.colocate_with(self.op.resource_handle):
|
||||
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
|
||||
restored_tensors[0],
|
||||
|
@ -2166,25 +2168,27 @@ class DenseHashTable(LookupInterface):
|
|||
return {
|
||||
"table":
|
||||
functools.partial(
|
||||
DenseHashTable._Saveable, table=self, name=self._name)
|
||||
DenseHashTable._Saveable, table=self, name=self._name,
|
||||
table_name=self._name)
|
||||
}
|
||||
|
||||
class _Saveable(BaseSaverBuilder.SaveableObject):
|
||||
"""SaveableObject implementation for DenseHashTable."""
|
||||
|
||||
def __init__(self, table, name):
|
||||
def __init__(self, table, name, table_name=None):
|
||||
tensors = table.export()
|
||||
specs = [
|
||||
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
|
||||
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
|
||||
]
|
||||
self.table_name = table_name or name
|
||||
# pylint: disable=protected-access
|
||||
super(DenseHashTable._Saveable, self).__init__(table, specs, name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes, name=None):
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
del restored_shapes # unused
|
||||
# pylint: disable=protected-access
|
||||
with ops.name_scope(name, "%s_table_restore" % self.name):
|
||||
with ops.name_scope("%s_table_restore" % self.table_name):
|
||||
with ops.colocate_with(self.op.resource_handle):
|
||||
return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
|
||||
restored_tensors[0],
|
||||
|
|
|
@ -45,6 +45,7 @@ from tensorflow.python.saved_model import nested_structure_coder
|
|||
from tensorflow.python.saved_model import revived_types
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training.saving import checkpoint_options
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
@ -146,6 +147,18 @@ class Loader(object):
|
|||
self._setup_functions_structures()
|
||||
self._setup_functions_captures()
|
||||
|
||||
self._create_saveable_object_factories()
|
||||
|
||||
def _create_saveable_object_factories(self):
|
||||
for node_id, proto in enumerate(self._proto.nodes):
|
||||
node = self.get(node_id)
|
||||
node._self_saveable_object_factories = {} # pylint: disable=protected-access
|
||||
for name, saveable_object_proto in proto.saveable_objects.items():
|
||||
node._self_saveable_object_factories[name] = ( # pylint: disable=protected-access
|
||||
saveable_object_util.restored_saved_object_factory(
|
||||
self.get(saveable_object_proto.save_function),
|
||||
self.get(saveable_object_proto.restore_function)))
|
||||
|
||||
def _load_edges(self):
|
||||
"""Adds edges from objects to other objects and functions."""
|
||||
for node_id, object_proto in enumerate(self._proto.nodes):
|
||||
|
|
|
@ -1795,6 +1795,22 @@ class LoadTest(test.TestCase, parameterized.TestCase):
|
|||
options = load_options.LoadOptions(experimental_io_device="/job:localhost")
|
||||
self.assertEqual("/job:localhost", options.experimental_io_device)
|
||||
|
||||
def test_load_custom_saveable_object(self, cycles):
|
||||
root = tracking.AutoTrackable()
|
||||
root.table = lookup_ops.MutableHashTable(dtypes.string, dtypes.float32, -1)
|
||||
root.table.insert("foo", 15)
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec(None, dtypes.string)])
|
||||
def lookup(key):
|
||||
return root.table.lookup(key)
|
||||
|
||||
root.lookup = lookup
|
||||
|
||||
imported = cycle(root, cycles)
|
||||
self.assertEqual(self.evaluate(imported.lookup("foo")), 15)
|
||||
self.assertEqual(self.evaluate(imported.lookup("idk")), -1)
|
||||
|
||||
|
||||
class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import os
|
||||
|
||||
from tensorflow.core.framework import versions_pb2
|
||||
|
@ -53,6 +54,7 @@ from tensorflow.python.saved_model import tag_constants
|
|||
from tensorflow.python.saved_model import utils_impl
|
||||
from tensorflow.python.training.saving import checkpoint_options
|
||||
from tensorflow.python.training.saving import functional_saver
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
@ -136,12 +138,15 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
|
|||
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
|
||||
self._serialization_cache)
|
||||
|
||||
def list_functions(self, obj):
|
||||
def list_functions(self, obj, extra_functions=None):
|
||||
obj_functions = self._functions.get(obj, None)
|
||||
if obj_functions is None:
|
||||
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
|
||||
self._serialization_cache)
|
||||
self._functions[obj] = obj_functions
|
||||
if extra_functions:
|
||||
obj_functions = obj_functions.copy()
|
||||
obj_functions.update(extra_functions)
|
||||
return obj_functions
|
||||
|
||||
|
||||
|
@ -177,6 +182,12 @@ class _SaveableView(object):
|
|||
self.slot_variables = slot_variables
|
||||
self.concrete_functions = []
|
||||
|
||||
self.saveable_objects_for_node, all_saveable_functions = (
|
||||
self._add_saveable_objects())
|
||||
saveable_object_functions = {
|
||||
"__SAVEABLE_FUNCTION_{}".format(n): fn
|
||||
for n, fn in enumerate(all_saveable_functions)}
|
||||
|
||||
# Maps functions -> wrapped functions that capture variables
|
||||
self.wrapped_functions = wrapped_functions or {}
|
||||
# Maps names of concrete functions in the object to names of wrapped
|
||||
|
@ -190,7 +201,8 @@ class _SaveableView(object):
|
|||
nodes_without_functions = list(self.nodes)
|
||||
seen_function_names = set()
|
||||
for node in nodes_without_functions:
|
||||
for function in checkpoint_view.list_functions(node).values():
|
||||
for function in checkpoint_view.list_functions(
|
||||
node, saveable_object_functions).values():
|
||||
if function not in self.node_ids:
|
||||
self.node_ids[function] = len(self.nodes)
|
||||
self.nodes.append(function)
|
||||
|
@ -209,6 +221,25 @@ class _SaveableView(object):
|
|||
seen_function_names.add(concrete_function.name)
|
||||
self.concrete_functions.append(concrete_function)
|
||||
|
||||
def _add_saveable_objects(self):
|
||||
"""Retrieves SaveablesObjects and traces their save/restore functions."""
|
||||
# Maps node -> local name -> (save function, restore function)
|
||||
saveable_objects_map = object_identity.ObjectIdentityDictionary()
|
||||
all_saveable_functions = []
|
||||
for node in self.nodes:
|
||||
if resource_variable_ops.is_resource_variable(node):
|
||||
# Resource (and TPU/Mirrored) variables are automatically revived with
|
||||
# their saveables defined, so there is no need to trace the save
|
||||
# and restore functions.
|
||||
continue
|
||||
saveable_map = saveable_object_util.trace_save_restore_functions(node)
|
||||
if saveable_map:
|
||||
saveable_objects_map[node] = saveable_map
|
||||
for save_fn, restore_fn in saveable_map.values():
|
||||
all_saveable_functions.append(save_fn)
|
||||
all_saveable_functions.append(restore_fn)
|
||||
return saveable_objects_map, all_saveable_functions
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
return self.nodes[0]
|
||||
|
@ -233,6 +264,15 @@ class _SaveableView(object):
|
|||
child_proto.node_id = self.node_ids[ref_function]
|
||||
child_proto.local_name = local_name
|
||||
|
||||
if node not in self.saveable_objects_for_node:
|
||||
continue
|
||||
|
||||
for local_name, (save_fn, restore_fn) in (
|
||||
self.saveable_objects_for_node[node].items()):
|
||||
saveable_object_proto = object_proto.saveable_objects[local_name]
|
||||
saveable_object_proto.save_function = self.node_ids[save_fn]
|
||||
saveable_object_proto.restore_function = self.node_ids[restore_fn]
|
||||
|
||||
def map_resources(self):
|
||||
"""Makes new resource handle ops corresponding to existing resource tensors.
|
||||
|
||||
|
@ -605,7 +645,9 @@ def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
|
|||
# the exported graph (thus the `to_graph` argument).
|
||||
saver = functional_saver.MultiDeviceSaver(
|
||||
saveable_view.checkpoint_view.frozen_saveable_objects(
|
||||
object_map=object_map, to_graph=exported_graph))
|
||||
object_map=object_map, to_graph=exported_graph,
|
||||
call_with_mapped_captures=functools.partial(
|
||||
_call_function_with_mapped_captures, resource_map=resource_map)))
|
||||
|
||||
with exported_graph.as_default():
|
||||
signatures = _generate_signatures(signature_functions, resource_map)
|
||||
|
|
|
@ -17,15 +17,26 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import six
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import def_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import device as pydev
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
|
||||
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
|
@ -279,7 +290,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
|||
raise ValueError(
|
||||
("Two different ResourceVariable objects with the same "
|
||||
"shared_name '%s' were passed to the Saver. This likely means "
|
||||
"that they were created in different Graphs or isolation "
|
||||
"that they were created in different Graphs or isoWlation "
|
||||
"contexts, and may not be checkpointed together.") %
|
||||
(var._shared_name,))
|
||||
else:
|
||||
|
@ -349,3 +360,147 @@ def validate_and_slice_inputs(names_to_saveables):
|
|||
for converted_saveable_object in saveable_objects_for_op(op, name):
|
||||
_add_saveable(saveables, seen_ops, converted_saveable_object)
|
||||
return saveables
|
||||
|
||||
|
||||
def trace_save_restore_functions(object_to_save):
|
||||
"""Gathers all SaveableObjects and traces the save and restore ops."""
|
||||
saveable_map = {} # Maps name -> (save function, restore function)
|
||||
for name, saveable_factory in (
|
||||
object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
|
||||
if not callable(saveable_factory):
|
||||
if isinstance(saveable_factory, saveable_object.SaveableObject):
|
||||
logging.debug(
|
||||
"Trackable {} should return callable factories, not SaveableObjects"
|
||||
" in `_gather_saveables_for_checkpoint`. This could lead to "
|
||||
"problems loading the SavedModel back into Python."
|
||||
.format(object_to_save))
|
||||
continue
|
||||
|
||||
if is_factory_for_restored_saveable_object(saveable_factory):
|
||||
saveable_map[name] = (saveable_factory.keywords["save_function"],
|
||||
saveable_factory.keywords["restore_function"])
|
||||
else:
|
||||
concrete_save_fn, concrete_restore_fn = _trace_save_and_restore_function(
|
||||
saveable_factory, object_to_save)
|
||||
if concrete_save_fn is not None:
|
||||
saveable_map[name] = (concrete_save_fn, concrete_restore_fn)
|
||||
return saveable_map
|
||||
|
||||
|
||||
def _trace_save_and_restore_function(saveable_factory, object_to_save):
|
||||
"""Traces the save and restore concrete functions."""
|
||||
saveables = []
|
||||
|
||||
@def_function.function(
|
||||
input_signature=[tensor_spec.TensorSpec([], dtypes.string)])
|
||||
def save_fn(checkpoint_key):
|
||||
maybe_saveable = saveable_factory(name=checkpoint_key)
|
||||
if isinstance(maybe_saveable, saveable_object.SaveableObject):
|
||||
maybe_saveable = [maybe_saveable]
|
||||
saveables[:] = maybe_saveable
|
||||
|
||||
# Return list of all SaveSpecs created by the factory.
|
||||
ret = []
|
||||
for saveable in saveables:
|
||||
for spec in saveable.specs:
|
||||
ret.append({"name": spec.name, "tensor": spec.tensor,
|
||||
"slice_spec": spec.slice_spec})
|
||||
return ret
|
||||
|
||||
concrete_save_fn = save_fn.get_concrete_function()
|
||||
if any(isinstance(saveable, trackable.PythonStateSaveable)
|
||||
for saveable in saveables):
|
||||
logging.warn(
|
||||
"Note that object {} stores python values into the checkpoint. "
|
||||
"These values will not be restored when loading the SavedModel "
|
||||
"into python.".format(object_to_save))
|
||||
return None, None
|
||||
if any(isinstance(saveable, trackable.NoRestoreSaveable)
|
||||
for saveable in saveables):
|
||||
return None, None
|
||||
|
||||
restored_type_specs = []
|
||||
tensor_structure = []
|
||||
for saveable in saveables:
|
||||
saveable_tensor_structure = []
|
||||
tensor_structure.append(saveable_tensor_structure)
|
||||
for spec in saveable.specs:
|
||||
restored_type_specs.append(type_spec.type_spec_from_value(spec.tensor))
|
||||
saveable_tensor_structure.append(spec.name)
|
||||
|
||||
@def_function.function(input_signature=restored_type_specs)
|
||||
def restore_fn(*restored_tensors):
|
||||
structured_restored_tensors = nest.pack_sequence_as(
|
||||
tensor_structure, restored_tensors)
|
||||
for saveable, restored_tensors in zip(saveables,
|
||||
structured_restored_tensors):
|
||||
saveable.restore(restored_tensors, restored_shapes=None)
|
||||
return 1
|
||||
|
||||
concrete_restore_fn = restore_fn.get_concrete_function()
|
||||
return concrete_save_fn, concrete_restore_fn
|
||||
|
||||
|
||||
class RestoredSaveableObject(saveable_object.SaveableObject):
|
||||
"""SaveableObject restored from SavedModel using the traced save/restore."""
|
||||
|
||||
def __init__(self, save_function, restore_function, name):
|
||||
self.save_function = save_function
|
||||
self.restore_function = restore_function
|
||||
|
||||
if tensor_util.is_tensor(name):
|
||||
name_tensor = name
|
||||
else:
|
||||
with ops.init_scope():
|
||||
name_tensor = constant_op.constant(name)
|
||||
tensors = save_function(name_tensor)
|
||||
specs = [saveable_object.SaveSpec(x["tensor"], x["slice_spec"], x["name"])
|
||||
for x in tensors]
|
||||
super(RestoredSaveableObject, self).__init__(None, specs, name)
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
del restored_shapes # unused
|
||||
return self.restore_function(
|
||||
*[restored_tensors[i] for i in range(len(self.specs))])
|
||||
|
||||
|
||||
def restored_saved_object_factory(save_function, restore_function):
|
||||
return functools.partial(RestoredSaveableObject,
|
||||
save_function=save_function,
|
||||
restore_function=restore_function)
|
||||
|
||||
|
||||
def create_saveable_object(factory, name, call_with_mapped_captures):
|
||||
"""Creates a SaveableObject while potentially in a different graph.
|
||||
|
||||
When creating the frozen saver for SavedModel, the save and restore ops are
|
||||
placed in a separate graph. Since RestoredSaveableObject uses tf.functions to
|
||||
save and restore, the function captures must be mapped to the new graph.
|
||||
|
||||
Args:
|
||||
factory: Factory method for creating the SaveableObject.
|
||||
name: Checkpoint key of this SaveableObject.
|
||||
call_with_mapped_captures: Helper that calls a tf.function while remapping
|
||||
the captures.
|
||||
|
||||
Returns:
|
||||
a SaveableObject.
|
||||
"""
|
||||
if (call_with_mapped_captures is None or
|
||||
not is_factory_for_restored_saveable_object(factory)):
|
||||
return factory(name=name)
|
||||
|
||||
concrete_save_fn = factory.keywords["save_function"]
|
||||
def save_fn(name):
|
||||
return call_with_mapped_captures(concrete_save_fn, [name])
|
||||
|
||||
concrete_restore_fn = factory.keywords["restore_function"]
|
||||
def restore_fn(*restored_tensors):
|
||||
return call_with_mapped_captures(concrete_restore_fn, restored_tensors)
|
||||
|
||||
return factory(save_function=save_fn, restore_function=restore_fn, name=name)
|
||||
|
||||
|
||||
def is_factory_for_restored_saveable_object(factory):
|
||||
return (isinstance(factory, functools.partial) and
|
||||
factory.func is RestoredSaveableObject)
|
||||
|
|
|
@ -611,6 +611,12 @@ class Trackable(object):
|
|||
# building.
|
||||
self._self_name_based_restores = set()
|
||||
|
||||
# Dictionary of SaveableObjects factories. This dictionary is defined when
|
||||
# the object is loaded from the SavedModel. When writing a custom class,
|
||||
# prefer overriding "_gather_saveables_from_checkpoint" to using this
|
||||
# attribute.
|
||||
self._self_saveable_object_factories = {}
|
||||
|
||||
@property
|
||||
def _object_identifier(self):
|
||||
"""String used to identify this object in a SavedModel.
|
||||
|
@ -972,7 +978,7 @@ class Trackable(object):
|
|||
lambda name="global_name_for_this_object":
|
||||
SaveableObject(name=name, ...)}
|
||||
"""
|
||||
return {}
|
||||
return self._self_saveable_object_factories
|
||||
|
||||
def _list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||
"""Lists extra dependencies to serialize.
|
||||
|
|
|
@ -208,7 +208,7 @@ class ObjectGraphView(object):
|
|||
|
||||
def _add_attributes_to_object_graph(
|
||||
self, trackable_objects, object_graph_proto, node_ids, object_names,
|
||||
object_map):
|
||||
object_map, call_with_mapped_captures):
|
||||
"""Create SaveableObjects and corresponding SerializedTensor protos."""
|
||||
named_saveable_objects = []
|
||||
if self._saveables_cache is None:
|
||||
|
@ -253,7 +253,9 @@ class ObjectGraphView(object):
|
|||
break
|
||||
if saveables is None:
|
||||
if callable(saveable_factory):
|
||||
maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
|
||||
maybe_saveable = saveable_object_util.create_saveable_object(
|
||||
saveable_factory, attribute.checkpoint_key,
|
||||
call_with_mapped_captures)
|
||||
else:
|
||||
maybe_saveable = saveable_factory
|
||||
if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
|
||||
|
@ -332,7 +334,8 @@ class ObjectGraphView(object):
|
|||
return object_graph_proto
|
||||
|
||||
def _serialize_gathered_objects(self, trackable_objects, path_to_root,
|
||||
object_map=None):
|
||||
object_map=None,
|
||||
call_with_mapped_captures=None):
|
||||
"""Create SaveableObjects and protos for gathered objects."""
|
||||
object_names = object_identity.ObjectIdentityDictionary()
|
||||
for obj, path in path_to_root.items():
|
||||
|
@ -354,7 +357,8 @@ class ObjectGraphView(object):
|
|||
object_graph_proto=object_graph_proto,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names,
|
||||
object_map=object_map))
|
||||
object_map=object_map,
|
||||
call_with_mapped_captures=call_with_mapped_captures))
|
||||
return named_saveable_objects, object_graph_proto, feed_additions
|
||||
|
||||
def serialize_object_graph(self):
|
||||
|
@ -382,7 +386,8 @@ class ObjectGraphView(object):
|
|||
return self._serialize_gathered_objects(
|
||||
trackable_objects, path_to_root)
|
||||
|
||||
def frozen_saveable_objects(self, object_map=None, to_graph=None):
|
||||
def frozen_saveable_objects(self, object_map=None, to_graph=None,
|
||||
call_with_mapped_captures=None):
|
||||
"""Creates SaveableObjects with the current object graph frozen."""
|
||||
trackable_objects, path_to_root = self._breadth_first_traversal()
|
||||
if to_graph:
|
||||
|
@ -393,7 +398,8 @@ class ObjectGraphView(object):
|
|||
named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
|
||||
trackable_objects,
|
||||
path_to_root,
|
||||
object_map)
|
||||
object_map,
|
||||
call_with_mapped_captures)
|
||||
with ops.device("/cpu:0"):
|
||||
object_graph_tensor = constant_op.constant(
|
||||
graph_proto.SerializeToString(), dtype=dtypes.string)
|
||||
|
|
Loading…
Reference in New Issue