Stop running variable restoration twice when loading SavedModels

Even worse, the second time split the restore into one op per variable, which is very slow for high-latency filesystems.

This was an unfortunate side-effect of a 1.x compatibility case. Adds a unit test asserting that only one restore op is run.

PiperOrigin-RevId: 351231481
Change-Id: I43cb7a140f10ad743a64bd7f6893abc77477a90f
This commit is contained in:
Allen Lavoie 2021-01-11 14:05:41 -08:00 committed by TensorFlower Gardener
parent 86069bb1cb
commit 756840185f
2 changed files with 45 additions and 24 deletions

View File

@ -488,31 +488,32 @@ class Loader(object):
load_status.assert_existing_objects_matched()
checkpoint = load_status._checkpoint
# When running in eager mode, the `restore` call above has already run and
# restored the state of trackables, call `position.restore_ops()` will
# return an empty list as there is nothing left to do. In graph mode, that
# will return the list of ops that must run to restore the object on that
# position. We have to wire them in the initializers of the objects so that
# they get initialized properly when using common practices (e.g. the ones
# used by ManagedSession) without further user action.
for object_id, obj in dict(checkpoint.object_by_proto_id).items():
position = base.CheckpointPosition(checkpoint=checkpoint,
proto_id=object_id)
restore_ops = position.restore_ops()
if restore_ops:
if resource_variable_ops.is_resource_variable(obj):
if len(restore_ops) == 1:
obj._initializer_op = restore_ops[0]
if not context.executing_eagerly():
# When running in eager mode, the `restore` call above has already run and
# restored the state of trackables, and calling `position.restore_ops()`
# would re-run the restore. In graph mode, that will return a cached list
# of ops that must run to restore the object on that position. We have to
# wire them in the initializers of the objects so that they get
# initialized properly when using common practices (e.g. the ones used by
# ManagedSession) without further user action.
for object_id, obj in dict(checkpoint.object_by_proto_id).items():
position = base.CheckpointPosition(checkpoint=checkpoint,
proto_id=object_id)
restore_ops = position.restore_ops()
if restore_ops:
if resource_variable_ops.is_resource_variable(obj):
if len(restore_ops) == 1:
obj._initializer_op = restore_ops[0]
else:
obj._initializer_op = control_flow_ops.group(*restore_ops)
elif isinstance(obj, lookup_ops.LookupInterface):
# We don't need to check for eager execution here, since this code
# path should only be taken if we are restoring in graph mode.
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
else:
obj._initializer_op = control_flow_ops.group(*restore_ops)
elif isinstance(obj, lookup_ops.LookupInterface):
# We don't need to check for eager execution here, since this code
# path should only be taken if we are restoring in graph mode.
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
else:
raise NotImplementedError(
("Missing functionality to restore state of object "
"%r from the checkpoint." % obj))
raise NotImplementedError(
("Missing functionality to restore state of object "
"%r from the checkpoint." % obj))
def adjust_debug_info_func_names(self, debug_info):
"""Rewrite func names in the debug info by using the concrete func names."""

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
@ -2007,6 +2008,25 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
load.load(path, tags=tag_constants.SERVING)
load.load(path, tags=set([tag_constants.SERVING]))
def test_single_restore_op_used(self):
root = module.Module()
root.v1 = variables.Variable(1.)
root.v2 = variables.Variable(2.)
root.v3 = variables.Variable(3.)
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
save.save(root, path)
restore_count = 0
def _count_restores(op_type, *unused_args, **unused_kwargs):
nonlocal restore_count
if op_type == b"RestoreV2":
restore_count += 1
op_callbacks.add_op_callback(_count_restores)
load.load(path)
op_callbacks.remove_op_callback(_count_restores)
self.assertEqual(1, restore_count)
def test_docstring_examples(self):
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
exported = util.Checkpoint(v=variables.Variable(3.))