Add restore function to V1 SavedModels loaded using load_v2, which restores other name-based checkpoints to the loaded SavedModel object.

PiperOrigin-RevId: 286688840
Change-Id: I88049e98de254795193366f094a67ffba41c6ba8
This commit is contained in:
Katherine Wu 2019-12-20 23:07:38 -08:00 committed by TensorFlower Gardener
parent 327b2a7740
commit 001dee2ec4
3 changed files with 75 additions and 28 deletions

View File

@ -497,23 +497,25 @@ def load(export_dir, tags=None):
_Importing SavedModels from TensorFlow 1.x_
SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
graph instead of `tf.function` objects. These SavedModels will have functions
corresponding to their signatures in the `.signatures` attribute, but also
have a `.prune` method which allows you to extract functions for new
subgraphs. This is equivalent to importing the SavedModel and naming feeds and
fetches in a Session from TensorFlow 1.x.
graph instead of `tf.function` objects. These SavedModels will be loaded with
the following attributes:
```python
imported = tf.saved_model.load(path_to_v1_saved_model)
pruned = imported.prune("x:0", "out:0")
pruned(tf.ones([]))
```
* `.signatures`: A dictionary mapping signature names to functions.
* `.prune(feeds, fetches) `: A method which allows you to extract
functions for new subgraphs. This is equivalent to importing the SavedModel
and naming feeds and fetches in a Session from TensorFlow 1.x.
See `tf.compat.v1.wrap_function` for details. These SavedModels also have a
`.variables` attribute containing imported variables, and a `.graph` attribute
representing the whole imported graph. For SavedModels exported from
`tf.saved_model.save`, variables are instead assigned to whichever attributes
they were assigned before export.
```python
imported = tf.saved_model.load(path_to_v1_saved_model)
pruned = imported.prune("x:0", "out:0")
pruned(tf.ones([]))
```
See `tf.compat.v1.wrap_function` for details.
* `.variables`: A list of imported variables.
* `.graph`: The whole imported graph.
* `.restore(save_path)`: A function that restores variables from a checkpoint
saved from `tf.compat.v1.Saver`.
_Consuming SavedModels asynchronously_

View File

@ -91,19 +91,24 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
# pylint: enable=protected-access
returns[0] = saver
def restore_variables(self, wrapped, saver):
def _extract_saver_restore(self, wrapped, saver):
if saver is None:
return None
saver_def = saver.saver_def
filename_tensor = wrapped.graph.as_graph_element(
saver_def.filename_tensor_name)
# We both feed and fetch filename_tensor so we have an operation to use to
# feed into variable initializers (only relevant for v1 graph building).
return wrapped.prune(
feeds=[filename_tensor],
fetches=[filename_tensor,
wrapped.graph.as_graph_element(saver_def.restore_op_name)])
def restore_variables(self, wrapped, restore_from_saver):
"""Restores variables from the checkpoint."""
if saver is not None:
saver_def = saver.saver_def
filename_tensor = wrapped.graph.as_graph_element(
saver_def.filename_tensor_name)
# We both feed and fetch filename_tensor so we have an operation to use to
# feed into variable initializers (only relevant for v1 graph building).
restore_fn = wrapped.prune(
feeds=[filename_tensor],
fetches=[filename_tensor,
wrapped.graph.as_graph_element(saver_def.restore_op_name)])
initializer, _ = restore_fn(constant_op.constant(self._variables_path))
if restore_from_saver is not None:
initializer, _ = restore_from_saver(
constant_op.constant(self._variables_path))
if not ops.executing_eagerly_outside_functions():
# Add the initialization operation to the table initializers collection
# in case we don't have any lifted variables to attach it to. There
@ -203,7 +208,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
signature=[])
saver, = load_graph_returns
self.restore_variables(wrapped, saver)
restore_from_saver = self._extract_saver_restore(wrapped, saver)
self.restore_variables(wrapped, restore_from_saver)
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(
meta_graph_def) or monitored_session.Scaffold.default_local_init_op()
@ -211,6 +217,9 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
init_anchor = constant_op.constant(0., name="dummy_fetch")
root = tracking.AutoTrackable()
if restore_from_saver is not None:
root.restore = (
lambda path: restore_from_saver(constant_op.constant(path)))
asset_feed_tensors = []
asset_paths = []
for tensor_name, value in loader_impl.get_asset_tensors(

View File

@ -37,9 +37,12 @@ from tensorflow.python.framework import versions
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import load
@ -48,6 +51,7 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training import saver
class LoadTest(test.TestCase):
@ -594,6 +598,38 @@ class LoadTest(test.TestCase):
forty_two = constant_op.constant([42], dtype=dtypes.int64)
self.assertEqual([45], imported_fn(forty_two)["output"].numpy())
def test_load_and_restore_partitioned_variables(self):
export_graph = ops.Graph()
with export_graph.as_default():
partitioned_var = variable_scope.get_variable(
"a", shape=[6], initializer=init_ops.constant_initializer(13),
partitioner=partitioned_variables.fixed_size_partitioner(2),
use_resource=True)
x = array_ops.placeholder(shape=[], dtype=dtypes.float32)
y = x * partitioned_var
with session_lib.Session() as session:
session.run(variables.global_variables_initializer())
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(session, path,
inputs={"x": x}, outputs={"y": y})
# Create a name-based checkpoint with different values.
session.run(partitioned_var.assign([[5, 4, 3], [2, 1, 0]]))
ckpt_path = os.path.join(self.get_temp_dir(), "restore_ckpt")
saver.Saver().save(session, ckpt_path)
imported = load.load(path)
self.assertAllClose(self.evaluate(imported.variables),
[[13, 13, 13], [13, 13, 13]])
self.evaluate(imported.restore(ckpt_path))
self.assertAllClose(self.evaluate(imported.variables),
[[5, 4, 3], [2, 1, 0]])
self.assertAllClose(
self.evaluate(
imported.signatures["serving_default"](constant_op.constant(2.))),
{"y": [10, 8, 6, 4, 2, 0]})
if __name__ == "__main__":
test.main()