diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index 4a44dd30e2f..39e6a915379 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -497,23 +497,25 @@ def load(export_dir, tags=None): _Importing SavedModels from TensorFlow 1.x_ SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat - graph instead of `tf.function` objects. These SavedModels will have functions - corresponding to their signatures in the `.signatures` attribute, but also - have a `.prune` method which allows you to extract functions for new - subgraphs. This is equivalent to importing the SavedModel and naming feeds and - fetches in a Session from TensorFlow 1.x. + graph instead of `tf.function` objects. These SavedModels will be loaded with + the following attributes: - ```python - imported = tf.saved_model.load(path_to_v1_saved_model) - pruned = imported.prune("x:0", "out:0") - pruned(tf.ones([])) - ``` + * `.signatures`: A dictionary mapping signature names to functions. + * `.prune(feeds, fetches) `: A method which allows you to extract + functions for new subgraphs. This is equivalent to importing the SavedModel + and naming feeds and fetches in a Session from TensorFlow 1.x. - See `tf.compat.v1.wrap_function` for details. These SavedModels also have a - `.variables` attribute containing imported variables, and a `.graph` attribute - representing the whole imported graph. For SavedModels exported from - `tf.saved_model.save`, variables are instead assigned to whichever attributes - they were assigned before export. + ```python + imported = tf.saved_model.load(path_to_v1_saved_model) + pruned = imported.prune("x:0", "out:0") + pruned(tf.ones([])) + ``` + + See `tf.compat.v1.wrap_function` for details. + * `.variables`: A list of imported variables. + * `.graph`: The whole imported graph. + * `.restore(save_path)`: A function that restores variables from a checkpoint + saved from `tf.compat.v1.Saver`. _Consuming SavedModels asynchronously_ diff --git a/tensorflow/python/saved_model/load_v1_in_v2.py b/tensorflow/python/saved_model/load_v1_in_v2.py index 8cbabf7bcf2..ede91da168c 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2.py +++ b/tensorflow/python/saved_model/load_v1_in_v2.py @@ -91,19 +91,24 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): # pylint: enable=protected-access returns[0] = saver - def restore_variables(self, wrapped, saver): + def _extract_saver_restore(self, wrapped, saver): + if saver is None: + return None + saver_def = saver.saver_def + filename_tensor = wrapped.graph.as_graph_element( + saver_def.filename_tensor_name) + # We both feed and fetch filename_tensor so we have an operation to use to + # feed into variable initializers (only relevant for v1 graph building). + return wrapped.prune( + feeds=[filename_tensor], + fetches=[filename_tensor, + wrapped.graph.as_graph_element(saver_def.restore_op_name)]) + + def restore_variables(self, wrapped, restore_from_saver): """Restores variables from the checkpoint.""" - if saver is not None: - saver_def = saver.saver_def - filename_tensor = wrapped.graph.as_graph_element( - saver_def.filename_tensor_name) - # We both feed and fetch filename_tensor so we have an operation to use to - # feed into variable initializers (only relevant for v1 graph building). - restore_fn = wrapped.prune( - feeds=[filename_tensor], - fetches=[filename_tensor, - wrapped.graph.as_graph_element(saver_def.restore_op_name)]) - initializer, _ = restore_fn(constant_op.constant(self._variables_path)) + if restore_from_saver is not None: + initializer, _ = restore_from_saver( + constant_op.constant(self._variables_path)) if not ops.executing_eagerly_outside_functions(): # Add the initialization operation to the table initializers collection # in case we don't have any lifted variables to attach it to. There @@ -203,7 +208,8 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): functools.partial(self.load_graph, load_graph_returns, meta_graph_def), signature=[]) saver, = load_graph_returns - self.restore_variables(wrapped, saver) + restore_from_saver = self._extract_saver_restore(wrapped, saver) + self.restore_variables(wrapped, restore_from_saver) with wrapped.graph.as_default(): init_op = loader_impl.get_init_op( meta_graph_def) or monitored_session.Scaffold.default_local_init_op() @@ -211,6 +217,9 @@ class _EagerSavedModelLoader(loader_impl.SavedModelLoader): init_anchor = constant_op.constant(0., name="dummy_fetch") root = tracking.AutoTrackable() + if restore_from_saver is not None: + root.restore = ( + lambda path: restore_from_saver(constant_op.constant(path))) asset_feed_tensors = [] asset_paths = [] for tensor_name, value in loader_impl.get_asset_tensors( diff --git a/tensorflow/python/saved_model/load_v1_in_v2_test.py b/tensorflow/python/saved_model/load_v1_in_v2_test.py index f02ab14b21c..37b439fe649 100644 --- a/tensorflow/python/saved_model/load_v1_in_v2_test.py +++ b/tensorflow/python/saved_model/load_v1_in_v2_test.py @@ -37,9 +37,12 @@ from tensorflow.python.framework import versions from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import load @@ -48,6 +51,7 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import simple_save from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import utils_impl +from tensorflow.python.training import saver class LoadTest(test.TestCase): @@ -594,6 +598,38 @@ class LoadTest(test.TestCase): forty_two = constant_op.constant([42], dtype=dtypes.int64) self.assertEqual([45], imported_fn(forty_two)["output"].numpy()) + def test_load_and_restore_partitioned_variables(self): + export_graph = ops.Graph() + with export_graph.as_default(): + partitioned_var = variable_scope.get_variable( + "a", shape=[6], initializer=init_ops.constant_initializer(13), + partitioner=partitioned_variables.fixed_size_partitioner(2), + use_resource=True) + x = array_ops.placeholder(shape=[], dtype=dtypes.float32) + y = x * partitioned_var + with session_lib.Session() as session: + session.run(variables.global_variables_initializer()) + path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid())) + simple_save.simple_save(session, path, + inputs={"x": x}, outputs={"y": y}) + + # Create a name-based checkpoint with different values. + session.run(partitioned_var.assign([[5, 4, 3], [2, 1, 0]])) + ckpt_path = os.path.join(self.get_temp_dir(), "restore_ckpt") + saver.Saver().save(session, ckpt_path) + + imported = load.load(path) + self.assertAllClose(self.evaluate(imported.variables), + [[13, 13, 13], [13, 13, 13]]) + + self.evaluate(imported.restore(ckpt_path)) + self.assertAllClose(self.evaluate(imported.variables), + [[5, 4, 3], [2, 1, 0]]) + self.assertAllClose( + self.evaluate( + imported.signatures["serving_default"](constant_op.constant(2.))), + {"y": [10, 8, 6, 4, 2, 0]}) + if __name__ == "__main__": test.main()