Clean up initialization of _SaveableView
[no functional change].
PiperOrigin-RevId: 340680049 Change-Id: I60ee42e726ea8a18b2e977818938a5cb35c4da61
This commit is contained in:
parent
7a050b85d8
commit
2ed1873d25
@ -145,15 +145,12 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
|
||||
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
|
||||
self._serialization_cache)
|
||||
|
||||
def list_functions(self, obj, extra_functions=None):
|
||||
def list_functions(self, obj):
|
||||
obj_functions = self._functions.get(obj, None)
|
||||
if obj_functions is None:
|
||||
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
|
||||
self._serialization_cache)
|
||||
self._functions[obj] = obj_functions
|
||||
if extra_functions:
|
||||
obj_functions = obj_functions.copy()
|
||||
obj_functions.update(extra_functions)
|
||||
return obj_functions
|
||||
|
||||
|
||||
@ -181,85 +178,100 @@ class _SaveableView(object):
|
||||
wrapped_functions: Dictionary that maps concrete functions to functions
|
||||
that do not capture cached variable values.
|
||||
"""
|
||||
self.options = options
|
||||
|
||||
self.checkpoint_view = checkpoint_view
|
||||
trackable_objects, path_to_root, node_ids, slot_variables = (
|
||||
self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
|
||||
self.node_paths = path_to_root
|
||||
self.nodes = trackable_objects
|
||||
self.node_ids = node_ids
|
||||
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
|
||||
self.slot_variables = slot_variables
|
||||
self.concrete_functions = []
|
||||
self.untraced_functions = []
|
||||
|
||||
self.saveable_objects_for_node, all_saveable_functions = (
|
||||
self._add_saveable_objects())
|
||||
saveable_object_functions = {
|
||||
"__SAVEABLE_FUNCTION_{}".format(n): fn
|
||||
for n, fn in enumerate(all_saveable_functions)}
|
||||
|
||||
self._options = options
|
||||
# Maps functions -> wrapped functions that capture variables
|
||||
self.wrapped_functions = wrapped_functions or {}
|
||||
self._wrapped_functions = wrapped_functions or {}
|
||||
# Run through the nodes in the object graph first for side effects of
|
||||
# creating variables.
|
||||
self._trace_all_concrete_functions()
|
||||
|
||||
(self._trackable_objects, self.node_paths, self._node_ids,
|
||||
self._slot_variables) = (
|
||||
self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
|
||||
self._initialize_nodes_and_concrete_functions()
|
||||
|
||||
# Maps names of concrete functions in the object to names of wrapped
|
||||
# functions. When writing the SavedFunction protos, the names of the
|
||||
# wrapped functions should be used in place of the original functions.
|
||||
self.function_name_map = {
|
||||
compat.as_text(original.name): compat.as_text(wrapped.name)
|
||||
for original, wrapped in self.wrapped_functions.items()}
|
||||
for original, wrapped in self._wrapped_functions.items()}
|
||||
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
|
||||
|
||||
# Also add `Function`s as nodes.
|
||||
nodes_without_functions = list(self.nodes)
|
||||
seen_function_names = set()
|
||||
for node in nodes_without_functions:
|
||||
for function in checkpoint_view.list_functions(
|
||||
node, saveable_object_functions).values():
|
||||
if function not in self.node_ids:
|
||||
self.node_ids[function] = len(self.nodes)
|
||||
self.nodes.append(function)
|
||||
if isinstance(function, def_function.Function):
|
||||
# Force listing the concrete functions for the side effects:
|
||||
# - populate the cache for functions that have an input_signature
|
||||
# and have not been called.
|
||||
# - force side effects of creation of concrete functions, e.g. create
|
||||
# variables on first run.
|
||||
concrete_functions = (
|
||||
function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access
|
||||
else:
|
||||
concrete_functions = [function]
|
||||
if not concrete_functions:
|
||||
self.untraced_functions.append(function._name)
|
||||
def _initialize_nodes_and_concrete_functions(self):
|
||||
"""Creates graph with nodes for trackable objects and functions.
|
||||
|
||||
for concrete_function in concrete_functions:
|
||||
if concrete_function.name not in seen_function_names:
|
||||
seen_function_names.add(concrete_function.name)
|
||||
self.concrete_functions.append(concrete_function)
|
||||
if self.untraced_functions:
|
||||
Adds functions for each trackable object to `self.nodes` and associated
|
||||
concrete functions to `self.concrete_functions` for serialization. Also adds
|
||||
the object's save and restore functions for loading values from checkpoint.
|
||||
"""
|
||||
self.nodes = list(self._trackable_objects)
|
||||
self.concrete_functions = []
|
||||
self._seen_function_names = set()
|
||||
self._untraced_functions = []
|
||||
# Maps node -> local name -> (save function, restore function)
|
||||
self._saveable_objects_map = object_identity.ObjectIdentityDictionary()
|
||||
|
||||
for obj in self._trackable_objects:
|
||||
for function in self.checkpoint_view.list_functions(obj).values():
|
||||
self._add_function_to_graph(function)
|
||||
# Resource (and TPU/Mirrored) variables are automatically revived with
|
||||
# their saveables defined, so there is no need to trace the save
|
||||
# and restore functions.
|
||||
if resource_variable_ops.is_resource_variable(obj):
|
||||
continue
|
||||
# Trace object save and restore functions to populate `saveables_map`
|
||||
# field in the SavedModel proto.
|
||||
saveable_map = saveable_object_util.trace_save_restore_functions(obj)
|
||||
if saveable_map:
|
||||
for save_fn, restore_fn in saveable_map.values():
|
||||
self._add_function_to_graph(save_fn)
|
||||
self._add_function_to_graph(restore_fn)
|
||||
self._saveable_objects_map[obj] = saveable_map
|
||||
|
||||
if self._untraced_functions:
|
||||
logging.warning(
|
||||
"Found untraced functions such as %s while saving (showing %d of %d)."
|
||||
" These functions will not be directly callable after loading.",
|
||||
", ".join(self.untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
|
||||
min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self.untraced_functions)),
|
||||
len(self.untraced_functions))
|
||||
", ".join(self._untraced_functions[:_NUM_DISPLAY_UNTRACED_FUNCTIONS]),
|
||||
min(_NUM_DISPLAY_UNTRACED_FUNCTIONS, len(self._untraced_functions)),
|
||||
len(self._untraced_functions))
|
||||
|
||||
def _add_saveable_objects(self):
|
||||
"""Retrieves SaveablesObjects and traces their save/restore functions."""
|
||||
# Maps node -> local name -> (save function, restore function)
|
||||
saveable_objects_map = object_identity.ObjectIdentityDictionary()
|
||||
all_saveable_functions = []
|
||||
for node in self.nodes:
|
||||
if resource_variable_ops.is_resource_variable(node):
|
||||
# Resource (and TPU/Mirrored) variables are automatically revived with
|
||||
# their saveables defined, so there is no need to trace the save
|
||||
# and restore functions.
|
||||
continue
|
||||
saveable_map = saveable_object_util.trace_save_restore_functions(node)
|
||||
if saveable_map:
|
||||
saveable_objects_map[node] = saveable_map
|
||||
for save_fn, restore_fn in saveable_map.values():
|
||||
all_saveable_functions.append(save_fn)
|
||||
all_saveable_functions.append(restore_fn)
|
||||
return saveable_objects_map, all_saveable_functions
|
||||
def _add_function_to_graph(self, function):
|
||||
"""Adds function to serialize to graph."""
|
||||
# Updates self.nodes, self._node_ids, self.concrete_functions,
|
||||
# and self._untraced_functions.
|
||||
if function not in self._node_ids:
|
||||
self._node_ids[function] = len(self.nodes)
|
||||
# Add the function to nodes as well.
|
||||
self.nodes.append(function)
|
||||
if isinstance(function, def_function.Function):
|
||||
concrete_functions = (
|
||||
function._list_all_concrete_functions_for_serialization()) # pylint: disable=protected-access
|
||||
else:
|
||||
concrete_functions = [function]
|
||||
if not concrete_functions:
|
||||
self._untraced_functions.append(function._name) # pylint: disable=protected-access
|
||||
for concrete_function in concrete_functions:
|
||||
if concrete_function.name not in self._seen_function_names:
|
||||
self.concrete_functions.append(concrete_function)
|
||||
self._seen_function_names.add(concrete_function.name)
|
||||
|
||||
def _trace_all_concrete_functions(self):
|
||||
"""Trace concrete functions to force side-effects.
|
||||
|
||||
Lists the concrete functions in order to:
|
||||
- populate the cache for functions that have an input_signature
|
||||
and have not been called
|
||||
- force side effects of creation of concrete functions, e.g. create
|
||||
variables on first run.
|
||||
"""
|
||||
for obj in self.checkpoint_view.list_objects():
|
||||
for function in self.checkpoint_view.list_functions(obj).values():
|
||||
if isinstance(function, def_function.Function):
|
||||
function._list_all_concrete_functions_for_serialization() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def root(self):
|
||||
@ -268,31 +280,31 @@ class _SaveableView(object):
|
||||
def fill_object_graph_proto(self, proto):
|
||||
"""Populate the nodes, children and slot_variables of a SavedObjectGraph."""
|
||||
for node_id, node in enumerate(self.nodes):
|
||||
assert self.node_ids[node] == node_id
|
||||
assert self._node_ids[node] == node_id
|
||||
object_proto = proto.nodes.add()
|
||||
object_proto.slot_variables.extend(self.slot_variables.get(node, ()))
|
||||
object_proto.slot_variables.extend(self._slot_variables.get(node, ()))
|
||||
if isinstance(
|
||||
node,
|
||||
(def_function.Function, defun.ConcreteFunction, _CapturedConstant)):
|
||||
continue
|
||||
for child in self.checkpoint_view.list_dependencies(node):
|
||||
child_proto = object_proto.children.add()
|
||||
child_proto.node_id = self.node_ids[child.ref]
|
||||
child_proto.node_id = self._node_ids[child.ref]
|
||||
child_proto.local_name = child.name
|
||||
for local_name, ref_function in (
|
||||
self.checkpoint_view.list_functions(node).items()):
|
||||
child_proto = object_proto.children.add()
|
||||
child_proto.node_id = self.node_ids[ref_function]
|
||||
child_proto.node_id = self._node_ids[ref_function]
|
||||
child_proto.local_name = local_name
|
||||
|
||||
if node not in self.saveable_objects_for_node:
|
||||
if node not in self._saveable_objects_map:
|
||||
continue
|
||||
|
||||
for local_name, (save_fn, restore_fn) in (
|
||||
self.saveable_objects_for_node[node].items()):
|
||||
self._saveable_objects_map[node].items()):
|
||||
saveable_object_proto = object_proto.saveable_objects[local_name]
|
||||
saveable_object_proto.save_function = self.node_ids[save_fn]
|
||||
saveable_object_proto.restore_function = self.node_ids[restore_fn]
|
||||
saveable_object_proto.save_function = self._node_ids[save_fn]
|
||||
saveable_object_proto.restore_function = self._node_ids[restore_fn]
|
||||
|
||||
def map_resources(self):
|
||||
"""Makes new resource handle ops corresponding to existing resource tensors.
|
||||
@ -328,7 +340,7 @@ class _SaveableView(object):
|
||||
_process_asset(obj, asset_info, resource_map)
|
||||
self.captured_tensor_node_ids[obj.asset_path] = node_id
|
||||
elif isinstance(obj, base.Trackable):
|
||||
node_object_map, node_resource_map = obj._map_resources(self.options) # pylint: disable=protected-access
|
||||
node_object_map, node_resource_map = obj._map_resources(self._options) # pylint: disable=protected-access
|
||||
for capturable in node_resource_map.keys():
|
||||
self.captured_tensor_node_ids[capturable] = node_id
|
||||
object_map.update(node_object_map)
|
||||
@ -350,8 +362,8 @@ class _SaveableView(object):
|
||||
capture.dtype not in _UNCOPIABLE_DTYPES and
|
||||
capture not in self.captured_tensor_node_ids):
|
||||
if hasattr(capture, "_cached_variable"):
|
||||
if concrete_function not in self.wrapped_functions:
|
||||
wrapped = self.wrapped_functions[concrete_function] = (
|
||||
if concrete_function not in self._wrapped_functions:
|
||||
wrapped = self._wrapped_functions[concrete_function] = (
|
||||
function_serialization.wrap_cached_variables(
|
||||
concrete_function))
|
||||
self.function_name_map[compat.as_text(concrete_function.name)] = (
|
||||
@ -366,13 +378,13 @@ class _SaveableView(object):
|
||||
node = _CapturedConstant(
|
||||
eager_tensor=capture, graph_tensor=copied_tensor)
|
||||
self.nodes.append(node)
|
||||
self.node_ids[capture] = node_id
|
||||
self.node_ids[node] = node_id
|
||||
self._node_ids[capture] = node_id
|
||||
self._node_ids[node] = node_id
|
||||
self.captured_tensor_node_ids[capture] = node_id
|
||||
resource_map[capture] = copied_tensor
|
||||
|
||||
self.concrete_functions = [
|
||||
self.wrapped_functions.get(x, x) for x in self.concrete_functions
|
||||
self._wrapped_functions.get(x, x) for x in self.concrete_functions
|
||||
if x not in bad_functions
|
||||
]
|
||||
return object_map, resource_map, asset_info
|
||||
@ -1178,9 +1190,6 @@ def _build_meta_graph_impl(obj,
|
||||
subgraph_root=signature_map)
|
||||
|
||||
# Use _SaveableView to provide a frozen listing of properties and functions.
|
||||
# Note we run this twice since, while constructing the view the first time
|
||||
# there can be side effects of creating variables.
|
||||
_ = _SaveableView(checkpoint_graph_view, options)
|
||||
saveable_view = _SaveableView(checkpoint_graph_view, options,
|
||||
wrapped_functions)
|
||||
object_saver = util.TrackableSaver(checkpoint_graph_view)
|
||||
|
Loading…
Reference in New Issue
Block a user