diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py index fd6459cb26a..591bbe3bd26 100644 --- a/tensorflow/python/saved_model/load.py +++ b/tensorflow/python/saved_model/load.py @@ -488,31 +488,32 @@ class Loader(object): load_status.assert_existing_objects_matched() checkpoint = load_status._checkpoint - # When running in eager mode, the `restore` call above has already run and - # restored the state of trackables, call `position.restore_ops()` will - # return an empty list as there is nothing left to do. In graph mode, that - # will return the list of ops that must run to restore the object on that - # position. We have to wire them in the initializers of the objects so that - # they get initialized properly when using common practices (e.g. the ones - # used by ManagedSession) without further user action. - for object_id, obj in dict(checkpoint.object_by_proto_id).items(): - position = base.CheckpointPosition(checkpoint=checkpoint, - proto_id=object_id) - restore_ops = position.restore_ops() - if restore_ops: - if resource_variable_ops.is_resource_variable(obj): - if len(restore_ops) == 1: - obj._initializer_op = restore_ops[0] + if not context.executing_eagerly(): + # When running in eager mode, the `restore` call above has already run and + # restored the state of trackables, and calling `position.restore_ops()` + # would re-run the restore. In graph mode, that will return a cached list + # of ops that must run to restore the object on that position. We have to + # wire them in the initializers of the objects so that they get + # initialized properly when using common practices (e.g. the ones used by + # ManagedSession) without further user action. + for object_id, obj in dict(checkpoint.object_by_proto_id).items(): + position = base.CheckpointPosition(checkpoint=checkpoint, + proto_id=object_id) + restore_ops = position.restore_ops() + if restore_ops: + if resource_variable_ops.is_resource_variable(obj): + if len(restore_ops) == 1: + obj._initializer_op = restore_ops[0] + else: + obj._initializer_op = control_flow_ops.group(*restore_ops) + elif isinstance(obj, lookup_ops.LookupInterface): + # We don't need to check for eager execution here, since this code + # path should only be taken if we are restoring in graph mode. + ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops) else: - obj._initializer_op = control_flow_ops.group(*restore_ops) - elif isinstance(obj, lookup_ops.LookupInterface): - # We don't need to check for eager execution here, since this code - # path should only be taken if we are restoring in graph mode. - ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops) - else: - raise NotImplementedError( - ("Missing functionality to restore state of object " - "%r from the checkpoint." % obj)) + raise NotImplementedError( + ("Missing functionality to restore state of object " + "%r from the checkpoint." % obj)) def adjust_debug_info_func_names(self, debug_info): """Rewrite func names in the debug info by using the concrete func names.""" diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py index 5c67dce9134..d9b86303831 100644 --- a/tensorflow/python/saved_model/load_test.py +++ b/tensorflow/python/saved_model/load_test.py @@ -37,6 +37,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function as framework_function +from tensorflow.python.framework import op_callbacks from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_spec @@ -2007,6 +2008,25 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase): load.load(path, tags=tag_constants.SERVING) load.load(path, tags=set([tag_constants.SERVING])) + def test_single_restore_op_used(self): + root = module.Module() + root.v1 = variables.Variable(1.) + root.v2 = variables.Variable(2.) + root.v3 = variables.Variable(3.) + path = tempfile.mkdtemp(prefix=self.get_temp_dir()) + save.save(root, path) + restore_count = 0 + + def _count_restores(op_type, *unused_args, **unused_kwargs): + nonlocal restore_count + if op_type == b"RestoreV2": + restore_count += 1 + + op_callbacks.add_op_callback(_count_restores) + load.load(path) + op_callbacks.remove_op_callback(_count_restores) + self.assertEqual(1, restore_count) + def test_docstring_examples(self): path = tempfile.mkdtemp(prefix=self.get_temp_dir()) exported = util.Checkpoint(v=variables.Variable(3.))