Add a restore function to V1 SavedModels loaded using load_v2, which restores other name-based checkpoints into the loaded SavedModel object.
PiperOrigin-RevId: 286688840
Change-Id: I88049e98de254795193366f094a67ffba41c6ba8
commit 001dee2ec4
parent 327b2a7740
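For orientation, a minimal usage sketch of the behavior this commit adds. The paths are hypothetical and eager execution is assumed; any SavedModel exported with 1.x APIs plus a name-based checkpoint written by `tf.compat.v1.train.Saver` for the same variables would do.

```python
import tensorflow as tf

# Hypothetical paths: a 1.x-style SavedModel and a name-based checkpoint
# written separately by tf.compat.v1.train.Saver for the same variables.
V1_SAVED_MODEL_DIR = "/tmp/v1_saved_model"
NAME_BASED_CKPT = "/tmp/name_based_ckpt"

# Loading a 1.x SavedModel returns an object with .signatures, .variables,
# .graph, .prune, and (after this change) .restore.
imported = tf.saved_model.load(V1_SAVED_MODEL_DIR)
print([v.numpy() for v in imported.variables])  # values stored in the SavedModel

# Overwrite the imported variables with the checkpoint's values.
imported.restore(NAME_BASED_CKPT)
print([v.numpy() for v in imported.variables])  # values from the checkpoint
```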
tensorflow/python/saved_model/load.py
@@ -497,23 +497,25 @@ def load(export_dir, tags=None):
   _Importing SavedModels from TensorFlow 1.x_

   SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
-  graph instead of `tf.function` objects. These SavedModels will have functions
-  corresponding to their signatures in the `.signatures` attribute, but also
-  have a `.prune` method which allows you to extract functions for new
-  subgraphs. This is equivalent to importing the SavedModel and naming feeds and
-  fetches in a Session from TensorFlow 1.x.
+  graph instead of `tf.function` objects. These SavedModels will be loaded with
+  the following attributes:

-  ```python
-  imported = tf.saved_model.load(path_to_v1_saved_model)
-  pruned = imported.prune("x:0", "out:0")
-  pruned(tf.ones([]))
-  ```
+  * `.signatures`: A dictionary mapping signature names to functions.
+  * `.prune(feeds, fetches)`: A method which allows you to extract
+    functions for new subgraphs. This is equivalent to importing the SavedModel
+    and naming feeds and fetches in a Session from TensorFlow 1.x.

-  See `tf.compat.v1.wrap_function` for details. These SavedModels also have a
-  `.variables` attribute containing imported variables, and a `.graph` attribute
-  representing the whole imported graph. For SavedModels exported from
-  `tf.saved_model.save`, variables are instead assigned to whichever attributes
-  they were assigned before export.
+    ```python
+    imported = tf.saved_model.load(path_to_v1_saved_model)
+    pruned = imported.prune("x:0", "out:0")
+    pruned(tf.ones([]))
+    ```
+
+    See `tf.compat.v1.wrap_function` for details.
+  * `.variables`: A list of imported variables.
+  * `.graph`: The whole imported graph.
+  * `.restore(save_path)`: A function that restores variables from a checkpoint
+    saved from `tf.compat.v1.Saver`.

   _Consuming SavedModels asynchronously_

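The docstring above already shows `.prune`; as a complement, a short sketch of the other listed attributes. The path, the `serving_default` signature name, and the `x` input key are assumptions about the exported model, not guarantees.

```python
import tensorflow as tf

imported = tf.saved_model.load("/tmp/v1_saved_model")  # hypothetical path

# .signatures maps signature names to concrete functions; inputs are passed as
# keyword arguments and outputs come back as a dict, keyed by the names chosen
# at export time ("serving_default" and "x" are assumed here).
serving_fn = imported.signatures["serving_default"]
print(serving_fn(x=tf.constant(2.0)))

# .variables and .graph expose the imported state directly.
print(len(imported.variables))  # list of imported tf.Variable objects
print(imported.graph)           # the whole imported tf.Graph
```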
tensorflow/python/saved_model/load_v1_in_v2.py
@@ -91,19 +91,24 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
     # pylint: enable=protected-access
     returns[0] = saver

-  def restore_variables(self, wrapped, saver):
+  def _extract_saver_restore(self, wrapped, saver):
+    if saver is None:
+      return None
+    saver_def = saver.saver_def
+    filename_tensor = wrapped.graph.as_graph_element(
+        saver_def.filename_tensor_name)
+    # We both feed and fetch filename_tensor so we have an operation to use to
+    # feed into variable initializers (only relevant for v1 graph building).
+    return wrapped.prune(
+        feeds=[filename_tensor],
+        fetches=[filename_tensor,
+                 wrapped.graph.as_graph_element(saver_def.restore_op_name)])
+
+  def restore_variables(self, wrapped, restore_from_saver):
     """Restores variables from the checkpoint."""
-    if saver is not None:
-      saver_def = saver.saver_def
-      filename_tensor = wrapped.graph.as_graph_element(
-          saver_def.filename_tensor_name)
-      # We both feed and fetch filename_tensor so we have an operation to use to
-      # feed into variable initializers (only relevant for v1 graph building).
-      restore_fn = wrapped.prune(
-          feeds=[filename_tensor],
-          fetches=[filename_tensor,
-                   wrapped.graph.as_graph_element(saver_def.restore_op_name)])
-      initializer, _ = restore_fn(constant_op.constant(self._variables_path))
+    if restore_from_saver is not None:
+      initializer, _ = restore_from_saver(
+          constant_op.constant(self._variables_path))
       if not ops.executing_eagerly_outside_functions():
         # Add the initialization operation to the table initializers collection
         # in case we don't have any lifted variables to attach it to. There
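The new `_extract_saver_restore` leans on `WrappedFunction.prune`: by feeding the saver's filename tensor and fetching its restore op, the graph-mode restore becomes an ordinary callable that takes a path tensor. Below is a standalone sketch of that prune pattern on a toy wrapped graph; the graph and tensor names are made up for illustration and are not the loader's.

```python
import tensorflow as tf

# A tiny v1-style graph built inside wrap_function; the placeholder and the
# named output play the roles that the saver's filename tensor and restore op
# play in _extract_saver_restore.
def build_graph():
  x = tf.compat.v1.placeholder(tf.float32, shape=[], name="x")
  tf.identity(x * 2.0, name="out")

wrapped = tf.compat.v1.wrap_function(build_graph, signature=[])

# prune() returns a ConcreteFunction wired from the named feed to the named
# fetch, just as the loader wires filename_tensor -> restore_op above.
double = wrapped.prune(
    feeds=wrapped.graph.as_graph_element("x:0"),
    fetches=wrapped.graph.as_graph_element("out:0"))
print(double(tf.constant(3.0)))  # -> 6.0
```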
@@ -203,7 +208,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
         functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
         signature=[])
     saver, = load_graph_returns
-    self.restore_variables(wrapped, saver)
+    restore_from_saver = self._extract_saver_restore(wrapped, saver)
+    self.restore_variables(wrapped, restore_from_saver)
     with wrapped.graph.as_default():
       init_op = loader_impl.get_init_op(
           meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
@@ -211,6 +217,9 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
       init_anchor = constant_op.constant(0., name="dummy_fetch")

     root = tracking.AutoTrackable()
+    if restore_from_saver is not None:
+      root.restore = (
+          lambda path: restore_from_saver(constant_op.constant(path)))
     asset_feed_tensors = []
     asset_paths = []
     for tensor_name, value in loader_impl.get_asset_tensors(
tensorflow/python/saved_model/load_v1_in_v2_test.py
@@ -37,9 +37,12 @@ from tensorflow.python.framework import versions
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.saved_model import builder_impl
 from tensorflow.python.saved_model import load
@@ -48,6 +51,7 @@ from tensorflow.python.saved_model import signature_def_utils
 from tensorflow.python.saved_model import simple_save
 from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.saved_model import utils_impl
+from tensorflow.python.training import saver


 class LoadTest(test.TestCase):
@@ -594,6 +598,38 @@ class LoadTest(test.TestCase):
     forty_two = constant_op.constant([42], dtype=dtypes.int64)
     self.assertEqual([45], imported_fn(forty_two)["output"].numpy())

+  def test_load_and_restore_partitioned_variables(self):
+    export_graph = ops.Graph()
+    with export_graph.as_default():
+      partitioned_var = variable_scope.get_variable(
+          "a", shape=[6], initializer=init_ops.constant_initializer(13),
+          partitioner=partitioned_variables.fixed_size_partitioner(2),
+          use_resource=True)
+      x = array_ops.placeholder(shape=[], dtype=dtypes.float32)
+      y = x * partitioned_var
+      with session_lib.Session() as session:
+        session.run(variables.global_variables_initializer())
+        path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
+        simple_save.simple_save(session, path,
+                                inputs={"x": x}, outputs={"y": y})
+
+        # Create a name-based checkpoint with different values.
+        session.run(partitioned_var.assign([[5, 4, 3], [2, 1, 0]]))
+        ckpt_path = os.path.join(self.get_temp_dir(), "restore_ckpt")
+        saver.Saver().save(session, ckpt_path)
+
+    imported = load.load(path)
+    self.assertAllClose(self.evaluate(imported.variables),
+                        [[13, 13, 13], [13, 13, 13]])
+
+    self.evaluate(imported.restore(ckpt_path))
+    self.assertAllClose(self.evaluate(imported.variables),
+                        [[5, 4, 3], [2, 1, 0]])
+    self.assertAllClose(
+        self.evaluate(
+            imported.signatures["serving_default"](constant_op.constant(2.))),
+        {"y": [10, 8, 6, 4, 2, 0]})
+

 if __name__ == "__main__":
   test.main()