(rollforward of cl/337218666) Add method to partially load a SavedModel.
PiperOrigin-RevId: 337950896 Change-Id: Idd0a9e963b34671bdf1d7b87389e2325848e5eea
This commit is contained in:
parent
c9849876cd
commit
63f17d0fe1
tensorflow/python
@ -118,7 +118,7 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
||||
# TODO(kathywu): Add code to load from objects that contain all endpoints
|
||||
|
||||
model = tf_load.load_internal(
|
||||
path, options=options, loader_cls=KerasObjectLoader)
|
||||
path, options=options, loader_cls=KerasObjectLoader)['root']
|
||||
|
||||
# pylint: disable=protected-access
|
||||
if isinstance(model, training_lib.Model) and compile:
|
||||
|
@ -48,6 +48,7 @@ from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training.saving import checkpoint_options
|
||||
from tensorflow.python.training.saving import saveable_object_util
|
||||
from tensorflow.python.training.tracking import base
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.training.tracking import util
|
||||
@ -119,7 +120,7 @@ class Loader(object):
|
||||
"""Helper class to load an object-based SavedModel."""
|
||||
|
||||
def __init__(self, object_graph_proto, saved_model_proto, export_dir,
|
||||
ckpt_options):
|
||||
ckpt_options, filters):
|
||||
meta_graph = saved_model_proto.meta_graphs[0]
|
||||
self._asset_file_def = meta_graph.asset_file_def
|
||||
self._operation_attributes = {
|
||||
@ -131,6 +132,26 @@ class Loader(object):
|
||||
meta_graph.graph_def.library))
|
||||
self._checkpoint_options = ckpt_options
|
||||
|
||||
# Stores user-defined node_filters argument.
|
||||
self._node_filters = filters
|
||||
# Stores map of string paths to integers.
|
||||
self._node_path_to_id = self._convert_node_paths_to_ints()
|
||||
self._loaded_nodes = {}
|
||||
if isinstance(filters, dict):
|
||||
# If node_filters is a dict, then the values may contain already created
|
||||
# trackable objects. In this case, create a dictionary mapping node IDs to
|
||||
# the already created nodes. This dict will be updated in
|
||||
# `_retrieve_all_filtered_nodes` with tracked dependencies.
|
||||
for node_path, node in filters.items():
|
||||
if isinstance(node, tuple):
|
||||
self._loaded_nodes[self._node_path_to_id[node_path]] = node
|
||||
else:
|
||||
self._loaded_nodes[self._node_path_to_id[node_path]] = (node, setattr)
|
||||
|
||||
# Get a list of all integer node ids to load, or None if all nodes should be
|
||||
# loaded. This list includes ids of child nodes.
|
||||
self._filtered_nodes = self._retrieve_all_filtered_nodes()
|
||||
|
||||
for name, concrete_function in self._concrete_functions.items():
|
||||
# Wrap all the concrete function so that they are capable of dealing with
|
||||
# both in replica and cross replica cases.
|
||||
@ -145,6 +166,91 @@ class Loader(object):
|
||||
if not context.executing_eagerly():
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
|
||||
|
||||
def _convert_node_paths_to_ints(self):
  """Maps all string node paths in node_filters to the int node ids.

  Every path must start with "root", which corresponds to node id 0. Each
  following name is resolved one level at a time with
  `self._find_node_child`.

  Returns:
    None when `self._node_filters` is None; otherwise a dict mapping each
    path string found in `self._node_filters` to its integer node id.

  Raises:
    TypeError: If an element of `self._node_filters` is not a string.
    ValueError: If a path does not start with "root".
  """
  if self._node_filters is None:
    return None

  path_to_int = {}
  for node_id in self._node_filters:
    if not isinstance(node_id, str):
      raise TypeError("Elements in node_filters must be strings.")

    node_path = node_id.split(".")
    if node_path[0] != "root":
      raise ValueError(
          "When passing string identifiers to node_filters, the first name"
          " must be root.")

    int_node_id = 0
    # `depth` is the number of leading path components that identify `name`;
    # that prefix is only used to build a readable lookup error.
    for depth, name in enumerate(node_path[1:], start=2):
      int_node_id = self._find_node_child(
          int_node_id, name, ".".join(node_path[:depth]))
    path_to_int[node_id] = int_node_id
  return path_to_int
|
||||
|
||||
def _retrieve_all_filtered_nodes(self):
|
||||
"""Traverses through the object graph to get the IDs of all nodes to load.
|
||||
|
||||
As a side-effect, if node_filters is a dictionary that contains already-
|
||||
created objects, then the dependencies tracked by those objects will be
|
||||
added to node_filters.
|
||||
|
||||
Returns:
|
||||
List of all nodes to load, or None if all nodes should be loaded.
|
||||
|
||||
"""
|
||||
# NOTE(review): the value built and returned below is a `set` of integer node
# ids (`all_filtered_nodes = set()`), not a list as the docstring states.
# NOTE(review): indentation is lost in this rendering, so the nesting of the
# `self._loaded_nodes[reference.node_id] = ...` assignment further down cannot
# be determined from here; the code is deliberately left untouched.
if self._node_filters is None:
|
||||
return None # All nodes should be loaded.
|
||||
|
||||
all_filtered_nodes = set()
|
||||
nodes_to_visit = list(self._node_filters)
|
||||
|
||||
# Breadth-first walk: paths are taken from the front of the work list and the
# paths of child references are appended to the back.
while nodes_to_visit:
|
||||
node_path = nodes_to_visit.pop(0)
|
||||
node_id = self._node_path_to_id[node_path]
|
||||
# Skip ids that were already collected (reached through another path).
if node_id in all_filtered_nodes:
|
||||
continue
|
||||
all_filtered_nodes.add(node_id)
|
||||
|
||||
# `_loaded_nodes` maps node ids to (object, setter) pairs; `node` is None
# when no already-created object is registered for this id.
node, setter = self._loaded_nodes.get(node_id, (None, None))
|
||||
if node is not None:
|
||||
if not isinstance(node, base.Trackable):
|
||||
# NOTE(review): the adjacent string literals below join as
# "...nodes_to_load.Object at ..." with no space after the period.
raise TypeError(
|
||||
"Error when processing dictionary values passed to nodes_to_load."
|
||||
"Object at {} is expected to be a checkpointable TensorFlow "
|
||||
"object (e.g. tf.Variable, tf.Module or Keras layer)."
|
||||
.format(node_path))
|
||||
node._maybe_initialize_trackable() # pylint: disable=protected-access
|
||||
|
||||
for reference in self._proto.nodes[node_id].children:
|
||||
child_object, _ = self._loaded_nodes.get(
|
||||
reference.node_id, (None, None))
|
||||
|
||||
# See if node already tracks the child reference, in which case add the
|
||||
# child to the loaded_nodes dict.
|
||||
if child_object is None and node is not None:
|
||||
child_object = node._lookup_dependency(reference.local_name) # pylint: disable=protected-access
|
||||
if isinstance(child_object, data_structures.TrackableDataStructure):
|
||||
# Make setattr a noop to avoid overwriting already existing data
|
||||
# structures.
|
||||
setter = lambda *args: None
|
||||
|
||||
self._loaded_nodes[reference.node_id] = (child_object, setter)
|
||||
|
||||
# Register the child's dotted path so it can be resolved to its node id when
# it is popped from the work list.
child_path = "{}.{}".format(node_path, reference.local_name)
|
||||
self._node_path_to_id[child_path] = reference.node_id
|
||||
nodes_to_visit.append(child_path)
|
||||
|
||||
# Node id 0 is the root; when it was selected the method reports "load
# everything" (None) rather than an explicit id set -- presumably because
# every saved node hangs off the root.
if 0 in all_filtered_nodes:
|
||||
return None
|
||||
return all_filtered_nodes
|
||||
|
||||
def _find_node_child(self, node_id, child_name, path):
  """Returns the node id of the child called `child_name` under `node_id`.

  Args:
    node_id: Integer index into `self._proto.nodes` of the parent node.
    child_name: `local_name` of the child reference to look for.
    path: Dotted path being resolved; only used in the error message.

  Returns:
    The `node_id` of the first child reference whose `local_name` matches.

  Raises:
    ValueError: If the parent has no child reference named `child_name`.
  """
  for reference in self._proto.nodes[node_id].children:
    if reference.local_name == child_name:
      return reference.node_id
  raise ValueError("unable to find node {}".format(path))
|
||||
|
||||
def _load_all(self):
|
||||
"""Loads all nodes and functions from the SavedModel and their edges."""
|
||||
self._load_nodes()
|
||||
@ -159,7 +265,7 @@ class Loader(object):
|
||||
self._create_saveable_object_factories()
|
||||
|
||||
def _create_saveable_object_factories(self):
|
||||
for node_id, proto in enumerate(self._proto.nodes):
|
||||
for node_id, proto in self._iter_all_nodes():
|
||||
node = self.get(node_id)
|
||||
node._self_saveable_object_factories = {} # pylint: disable=protected-access
|
||||
for name, saveable_object_proto in proto.saveable_objects.items():
|
||||
@ -170,9 +276,24 @@ class Loader(object):
|
||||
|
||||
def _load_edges(self):
|
||||
"""Adds edges from objects to other objects and functions."""
|
||||
for node_id, object_proto in enumerate(self._proto.nodes):
|
||||
for node_id, object_proto in self._iter_all_nodes():
|
||||
self._add_object_graph_edges(object_proto, node_id)
|
||||
|
||||
# If root object isn't loaded, then create edges from the root for
|
||||
# checkpoint compatibility.
|
||||
if self._filtered_nodes is not None and 0 not in self._filtered_nodes:
|
||||
root = self.get(0)
|
||||
for node_path in self._node_filters:
|
||||
loaded_node = self._nodes[self._node_path_to_id[node_path]]
|
||||
path = node_path.split(".")
|
||||
current_node = root
|
||||
for name in path[1:-1]:
|
||||
if not hasattr(current_node, name):
|
||||
setattr(current_node, name, self._recreate_base_user_object()[0])
|
||||
current_node = getattr(current_node, name)
|
||||
if not hasattr(current_node, path[-1]):
|
||||
setattr(current_node, path[-1], loaded_node)
|
||||
|
||||
def _add_object_graph_edges(self, proto, node_id):
|
||||
"""Adds edges from an object to its children."""
|
||||
obj = self._nodes[node_id]
|
||||
@ -214,7 +335,7 @@ class Loader(object):
|
||||
for name, proto in concrete_functions:
|
||||
concrete_function = self._concrete_functions[name]
|
||||
bound_inputs = [
|
||||
self._get_tensor_from_node(node_id)
|
||||
self._get_tensor_from_node(node_id, name)
|
||||
for node_id in proto.bound_inputs]
|
||||
bound_variables = [
|
||||
self._nodes[node_id]
|
||||
@ -251,8 +372,14 @@ class Loader(object):
|
||||
# placeholder for this input.
|
||||
concrete_function.graph.capture(bound_input)
|
||||
|
||||
def _get_tensor_from_node(self, node_id):
|
||||
def _get_tensor_from_node(self, node_id, fn_name):
|
||||
"""Resolves a node id into a tensor to be captured for a function."""
|
||||
if self._node_filters is not None and self._nodes[node_id] is None:
|
||||
raise ValueError(
|
||||
"Error when processing nodes_to_load. Function \"{}\" requires "
|
||||
"inputs/variables that are not loaded when nodes_to_load={}"
|
||||
.format(fn_name, self._node_filters))
|
||||
|
||||
with ops.init_scope():
|
||||
obj = self._nodes[node_id]
|
||||
if distribute_utils.is_distributed_variable(obj):
|
||||
@ -268,24 +395,39 @@ class Loader(object):
|
||||
return obj.resource_handle
|
||||
raise ValueError("Can't convert node %s to tensor" % (type(obj)))
|
||||
|
||||
def _initialize_loaded_nodes(self):
  """Splits `self._loaded_nodes` into separate node and setter dicts.

  Returns:
    A `(nodes, node_setters)` tuple of new dicts, both keyed by node id:
    `nodes` holds the first element and `node_setters` the second element of
    every `(node, setter)` pair stored in `self._loaded_nodes`.
  """
  nodes = {}
  node_setters = {}
  for node_id, (node, setter) in self._loaded_nodes.items():
    nodes[node_id] = node
    node_setters[node_id] = setter
  return nodes, node_setters
|
||||
|
||||
def _iter_all_nodes(self):
  """Returns `(node_id, proto)` pairs for every node that should be loaded.

  When no filter is active (`self._filtered_nodes` is None) this covers all of
  `self._proto.nodes`. Otherwise only the filtered ids are produced. They are
  visited in ascending id order, so the filtered pass sees nodes in the same
  relative order as the unfiltered `enumerate` pass rather than in set
  iteration order.
  """
  if self._filtered_nodes is None:
    return enumerate(self._proto.nodes)
  return [(node_id, self._proto.nodes[node_id])
          for node_id in sorted(self._filtered_nodes)]
|
||||
|
||||
def _load_nodes(self):
|
||||
"""Load all saved objects."""
|
||||
# Maps from node ids to recreated objects
|
||||
nodes = {}
|
||||
# Maps from node ids to setter functions (same signature as setattr) for
|
||||
# setting dependencies.
|
||||
node_setters = {}
|
||||
# `nodes` maps from node ids to recreated objects
|
||||
# `node_setters` maps from node ids to setter functions
|
||||
# (same signature as setattr) for setting dependencies.
|
||||
nodes, node_setters = self._initialize_loaded_nodes()
|
||||
|
||||
# Figure out which objects are slot variables. These objects are created
|
||||
# with Optimizer.add_slot rather than _recreate_variable.
|
||||
slot_variable_node_ids = set()
|
||||
for proto in self._proto.nodes:
|
||||
|
||||
for _, proto in self._iter_all_nodes():
|
||||
for slot_variable_proto in proto.slot_variables:
|
||||
slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)
|
||||
|
||||
# Re-create everything except slot variables.
|
||||
for node_id, proto in enumerate(self._proto.nodes):
|
||||
if node_id in slot_variable_node_ids:
|
||||
for node_id, proto in self._iter_all_nodes():
|
||||
if node_id in slot_variable_node_ids or nodes.get(node_id) is not None:
|
||||
# Defer recreating slot variables so we can use the public Optimizer
|
||||
# interface.
|
||||
continue
|
||||
@ -295,7 +437,7 @@ class Loader(object):
|
||||
|
||||
# Now that we have created the variables being optimized, we have enough
|
||||
# information to re-create slot variables for them.
|
||||
for node_id, proto in enumerate(self._proto.nodes):
|
||||
for node_id, proto in self._iter_all_nodes():
|
||||
optimizer_object = nodes[node_id]
|
||||
for slot_variable_proto in proto.slot_variables:
|
||||
optimized_variable = nodes[
|
||||
@ -306,7 +448,13 @@ class Loader(object):
|
||||
nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
|
||||
node_setters[slot_variable_proto.slot_variable_node_id] = setattr
|
||||
|
||||
self._nodes = [nodes[node_id] for node_id in range(len(self._proto.nodes))]
|
||||
# If root object is not loaded, add a dummy root object for checkpoint
|
||||
# compatibility.
|
||||
if 0 not in nodes:
|
||||
nodes[0] = self._recreate_base_user_object()[0]
|
||||
|
||||
self._nodes = [nodes.get(node_id)
|
||||
for node_id in range(len(self._proto.nodes))]
|
||||
self._node_setters = node_setters
|
||||
|
||||
@property
|
||||
@ -380,6 +528,8 @@ class Loader(object):
|
||||
return output_debug_info
|
||||
|
||||
def get(self, node_id):
  """Returns the loaded object stored for `node_id`.

  Args:
    node_id: Either an integer index into `self._nodes`, or a string node
      path, which is first translated through `self._node_path_to_id`.

  Returns:
    The entry of `self._nodes` at the resolved integer id.
  """
  if isinstance(node_id, str):
    node_id = self._node_path_to_id[node_id]
  return self._nodes[node_id]
|
||||
|
||||
def _recreate(self, proto, node_id):
|
||||
@ -408,7 +558,7 @@ class Loader(object):
|
||||
return self._recreate_base_user_object(proto, node_id)
|
||||
return looked_up
|
||||
|
||||
def _recreate_base_user_object(self, proto, node_id):
|
||||
def _recreate_base_user_object(self, proto=None, node_id=None):
|
||||
del proto, node_id
|
||||
# Note: each user object has its own class. This allows making each one
|
||||
# individually callable by adding a `__call__` method to the classes of
|
||||
@ -518,6 +668,103 @@ def _call_attribute(instance, *args, **kwargs):
|
||||
return instance.__call__(*args, **kwargs)
|
||||
|
||||
|
||||
def load_partial(export_dir, filters, tags=None, options=None):
  """Partially load a SavedModel (saved from V2).

  Similar to `tf.saved_model.load`, but with an additional argument that
  lets you specify which nodes to load.
  `tf.saved_model.load_partial(export_dir, ["root"])["root"]` and
  `tf.saved_model.load(export_dir)` are equivalent.

  Note: This only works for SavedModels saved with TensorFlow V2 from
  `tf.saved_model.save` or Keras. This will not load SavedModels saved from
  the Estimator API.

  In TensorFlow V2, SavedModel stores the **object graph** of the saved object.
  The graph contains nodes (`tf.Module`, `tf.Variable`, `tf.function`, Keras
  layers, etc.) and edges that are the name of the attributes connecting the
  objects.

  *Example 1*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.child_layer', 'root.child_layer.v'])
  >>> loaded['root.child_layer'].v.numpy()
  5.
  >>> loaded['root.child_layer'].v is loaded['root.child_layer.v']
  True
  ```

  *Example 2*

  ```
  >>> model = tf.Module()
  >>> model.child_layer = tf.Module()
  >>> model.child_layer.v = tf.Variable(5.)
  >>>
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Create a variable
  >>> new_variable = tf.Variable(0.)
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   {'root.child_layer': None, 'root.child_layer.v': new_variable})
  >>> loaded['root.child_layer'].v.numpy()
  5.
  >>> new_variable.numpy()
  5.
  ```

  **Loading under different distribution strategies**
  You can load different parts of the model under different distribution
  strategies. Note that this is very experimental so use with care.

  ```
  >>> model = tf.Module()
  >>> model.layer_1 = tf.Module()
  >>> model.layer_1.v = tf.Variable(5.)
  >>> model.layer_2 = tf.Module()
  >>> model.layer_2.v = tf.Variable(7.)
  >>> tf.saved_model.save(model, '/tmp/model')
  >>> # Load with no strategy
  >>> loaded = tf.__internal__.saved_model.load_partial(
  ...   '/tmp/model',
  ...   ['root.layer_1'])
  >>> loaded['root.layer_1'].v
  <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>
  >>> strategy = tf.distribute.MirroredStrategy()
  >>> with strategy.scope():
  ...   loaded2 = tf.__internal__.saved_model.load_partial(
  ...     '/tmp/model',
  ...     ['root.layer_2'])
  >>> loaded2['root.layer_2'].v
  MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=7.0>
  }
  ```

  Args:
    export_dir: The SavedModel directory to load from.
    filters: A list or dictionary where each element or key is a string
      path to nodes that should be loaded. Node paths consist of all the child
      attribute names to reach that node in the form: `root.{attribute_name}`.
      The loader will load all of the specified nodes and their recursive
      descendants. When this option is defined, the loader will return a
      dictionary mapping the node paths to the loaded objects. In the
      dictionary form, a value may be `None` or an already created object to
      load into (see Example 2).
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.
    options: `tf.saved_model.LoadOptions` object that specifies options for
      loading.

  Returns:
    A dictionary mapping node paths from the filter to loaded objects.
  """
  return load_internal(export_dir, tags, options, filters=filters)
|
||||
|
||||
|
||||
@tf_export("saved_model.load", v1=["saved_model.load_v2"])
|
||||
def load(export_dir, tags=None, options=None):
|
||||
"""Load a SavedModel from `export_dir`.
|
||||
@ -597,8 +844,8 @@ def load(export_dir, tags=None, options=None):
|
||||
tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
|
||||
if the SavedModel contains a single MetaGraph, as for those exported from
|
||||
`tf.saved_model.save`.
|
||||
options: Optional, `tf.saved_model.LoadOptions` object that specifies
|
||||
options for loading.
|
||||
options: `tf.saved_model.LoadOptions` object that specifies options for
|
||||
loading.
|
||||
|
||||
Returns:
|
||||
A trackable object with a `signatures` attribute mapping from signature
|
||||
@ -609,10 +856,11 @@ def load(export_dir, tags=None, options=None):
|
||||
Raises:
|
||||
ValueError: If `tags` don't match a MetaGraph in the SavedModel.
|
||||
"""
|
||||
return load_internal(export_dir, tags, options)
|
||||
return load_internal(export_dir, tags, options)["root"]
|
||||
|
||||
|
||||
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader):
|
||||
def load_internal(export_dir, tags=None, options=None, loader_cls=Loader,
|
||||
filters=None):
|
||||
"""Loader implementation."""
|
||||
options = options or load_options.LoadOptions()
|
||||
if tags is not None and not isinstance(tags, set):
|
||||
@ -639,7 +887,7 @@ def load_internal(export_dir, tags=None, options=None, loader_cls=Loader):
|
||||
with ops.init_scope():
|
||||
try:
|
||||
loader = loader_cls(object_graph_proto, saved_model_proto, export_dir,
|
||||
ckpt_options)
|
||||
ckpt_options, filters)
|
||||
except errors.NotFoundError as err:
|
||||
raise FileNotFoundError(
|
||||
str(err) + "\n If trying to load on a different device from the "
|
||||
@ -654,7 +902,14 @@ def load_internal(export_dir, tags=None, options=None, loader_cls=Loader):
|
||||
root.tensorflow_git_version = (
|
||||
meta_graph_def.meta_info_def.tensorflow_git_version)
|
||||
else:
|
||||
if filters:
|
||||
raise ValueError("SavedModels saved from Tensorflow V1 or Estimator (any "
|
||||
"version) cannot be loaded with node filters.")
|
||||
with ops.init_scope():
|
||||
root = load_v1_in_v2.load(export_dir, tags)
|
||||
root.graph_debug_info = debug_info
|
||||
return root
|
||||
|
||||
if filters:
|
||||
return {node_id: loader.get(node_id) for node_id in filters}
|
||||
else:
|
||||
return {"root": root}
|
||||
|
@ -2028,6 +2028,34 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
tensor_spec.TensorSpec(shape=[], dtype=dtypes.float32)])
|
||||
cycle(root, 1)
|
||||
|
||||
def test_load_partial_object(self):
  """Checks partial loading of a function together with its variable."""
  root = module.Module()
  root.variables_holder = module.Module()
  root.variables_holder.v = variables.Variable(1.)

  class Adder(module.Module):

    @def_function.function(input_signature=[tensor_spec.TensorSpec(shape=[])])
    def __call__(self, y):
      root.variables_holder.v.assign_add(y)
      return 1

  root.adder = Adder()

  save_dir = os.path.join(self.get_temp_dir(), "saved_model")
  save.save(root, save_dir)

  # Loading the function together with the variable it updates keeps the two
  # connected: calling the loaded adder changes the loaded variable.
  imported = load.load_partial(save_dir,
                               ["root.variables_holder.v", "root.adder"])
  v = imported["root.variables_holder.v"]
  adder = imported["root.adder"]
  self.assertEqual(self.evaluate(v), 1)
  adder(5)
  self.assertEqual(self.evaluate(v), 6)

  # Loading the function without the variable it needs must fail.
  with self.assertRaisesRegex(ValueError, "requires inputs/variables"):
    load.load_partial(save_dir, ["root.adder"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user