Stop running variable restoration twice when loading SavedModels

Even worse, the second pass split the restore into one op per variable, which is very slow for high-latency filesystems.

This was an unfortunate side-effect of a 1.x compatibility case. Adds a unit test asserting that only one restore op is run.

PiperOrigin-RevId: 351231481
Change-Id: I43cb7a140f10ad743a64bd7f6893abc77477a90f
This commit is contained in:
Allen Lavoie 2021-01-11 14:05:41 -08:00 committed by TensorFlower Gardener
parent 86069bb1cb
commit 756840185f
2 changed files with 45 additions and 24 deletions

View File

@@ -488,31 +488,32 @@ class Loader(object):
load_status.assert_existing_objects_matched() load_status.assert_existing_objects_matched()
checkpoint = load_status._checkpoint checkpoint = load_status._checkpoint
# When running in eager mode, the `restore` call above has already run and if not context.executing_eagerly():
# restored the state of trackables, call `position.restore_ops()` will # When running in eager mode, the `restore` call above has already run and
# return an empty list as there is nothing left to do. In graph mode, that # restored the state of trackables, and calling `position.restore_ops()`
# will return the list of ops that must run to restore the object on that # would re-run the restore. In graph mode, that will return a cached list
# position. We have to wire them in the initializers of the objects so that # of ops that must run to restore the object on that position. We have to
# they get initialized properly when using common practices (e.g. the ones # wire them in the initializers of the objects so that they get
# used by ManagedSession) without further user action. # initialized properly when using common practices (e.g. the ones used by
for object_id, obj in dict(checkpoint.object_by_proto_id).items(): # ManagedSession) without further user action.
position = base.CheckpointPosition(checkpoint=checkpoint, for object_id, obj in dict(checkpoint.object_by_proto_id).items():
proto_id=object_id) position = base.CheckpointPosition(checkpoint=checkpoint,
restore_ops = position.restore_ops() proto_id=object_id)
if restore_ops: restore_ops = position.restore_ops()
if resource_variable_ops.is_resource_variable(obj): if restore_ops:
if len(restore_ops) == 1: if resource_variable_ops.is_resource_variable(obj):
obj._initializer_op = restore_ops[0] if len(restore_ops) == 1:
obj._initializer_op = restore_ops[0]
else:
obj._initializer_op = control_flow_ops.group(*restore_ops)
elif isinstance(obj, lookup_ops.LookupInterface):
# We don't need to check for eager execution here, since this code
# path should only be taken if we are restoring in graph mode.
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
else: else:
obj._initializer_op = control_flow_ops.group(*restore_ops) raise NotImplementedError(
elif isinstance(obj, lookup_ops.LookupInterface): ("Missing functionality to restore state of object "
# We don't need to check for eager execution here, since this code "%r from the checkpoint." % obj))
# path should only be taken if we are restoring in graph mode.
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
else:
raise NotImplementedError(
("Missing functionality to restore state of object "
"%r from the checkpoint." % obj))
def adjust_debug_info_func_names(self, debug_info): def adjust_debug_info_func_names(self, debug_info):
"""Rewrite func names in the debug info by using the concrete func names.""" """Rewrite func names in the debug info by using the concrete func names."""

View File

@@ -37,6 +37,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import function as framework_function from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
@@ -2007,6 +2008,25 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
load.load(path, tags=tag_constants.SERVING) load.load(path, tags=tag_constants.SERVING)
load.load(path, tags=set([tag_constants.SERVING])) load.load(path, tags=set([tag_constants.SERVING]))
def test_single_restore_op_used(self):
  """Loading a SavedModel should issue exactly one RestoreV2 op.

  Regression test: a duplicated 1.x-compatibility restore path used to
  re-run the restore with one op per variable. Saves a module holding
  three variables, then counts RestoreV2 ops created while loading it.
  """
  root = module.Module()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  root.v3 = variables.Variable(3.)
  path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, path)
  # Record every RestoreV2 op created during the load; the callback fires
  # once per op instantiation.
  observed_restores = []
  def _record_restores(op_type, *unused_args, **unused_kwargs):
    if op_type == b"RestoreV2":
      observed_restores.append(op_type)
  op_callbacks.add_op_callback(_record_restores)
  load.load(path)
  op_callbacks.remove_op_callback(_record_restores)
  # All three variables must be restored by a single batched RestoreV2.
  self.assertEqual(1, len(observed_restores))
def test_docstring_examples(self): def test_docstring_examples(self):
path = tempfile.mkdtemp(prefix=self.get_temp_dir()) path = tempfile.mkdtemp(prefix=self.get_temp_dir())
exported = util.Checkpoint(v=variables.Variable(3.)) exported = util.Checkpoint(v=variables.Variable(3.))