Stop running variable restoration twice when loading SavedModels
Even worse, the second restoration pass split the restore into one op per variable, which is very slow on high-latency filesystems. This was an unfortunate side effect of a 1.x compatibility case. Adds a unit test asserting that only one restore op is run.

PiperOrigin-RevId: 351231481
Change-Id: I43cb7a140f10ad743a64bd7f6893abc77477a90f
This commit is contained in:
parent
86069bb1cb
commit
756840185f
@ -488,31 +488,32 @@ class Loader(object):
|
||||
load_status.assert_existing_objects_matched()
|
||||
checkpoint = load_status._checkpoint
|
||||
|
||||
# When running in eager mode, the `restore` call above has already run and
|
||||
# restored the state of trackables, call `position.restore_ops()` will
|
||||
# return an empty list as there is nothing left to do. In graph mode, that
|
||||
# will return the list of ops that must run to restore the object on that
|
||||
# position. We have to wire them in the initializers of the objects so that
|
||||
# they get initialized properly when using common practices (e.g. the ones
|
||||
# used by ManagedSession) without further user action.
|
||||
for object_id, obj in dict(checkpoint.object_by_proto_id).items():
|
||||
position = base.CheckpointPosition(checkpoint=checkpoint,
|
||||
proto_id=object_id)
|
||||
restore_ops = position.restore_ops()
|
||||
if restore_ops:
|
||||
if resource_variable_ops.is_resource_variable(obj):
|
||||
if len(restore_ops) == 1:
|
||||
obj._initializer_op = restore_ops[0]
|
||||
if not context.executing_eagerly():
|
||||
# When running in eager mode, the `restore` call above has already run and
|
||||
# restored the state of trackables, and calling `position.restore_ops()`
|
||||
# would re-run the restore. In graph mode, that will return a cached list
|
||||
# of ops that must run to restore the object on that position. We have to
|
||||
# wire them in the initializers of the objects so that they get
|
||||
# initialized properly when using common practices (e.g. the ones used by
|
||||
# ManagedSession) without further user action.
|
||||
for object_id, obj in dict(checkpoint.object_by_proto_id).items():
|
||||
position = base.CheckpointPosition(checkpoint=checkpoint,
|
||||
proto_id=object_id)
|
||||
restore_ops = position.restore_ops()
|
||||
if restore_ops:
|
||||
if resource_variable_ops.is_resource_variable(obj):
|
||||
if len(restore_ops) == 1:
|
||||
obj._initializer_op = restore_ops[0]
|
||||
else:
|
||||
obj._initializer_op = control_flow_ops.group(*restore_ops)
|
||||
elif isinstance(obj, lookup_ops.LookupInterface):
|
||||
# We don't need to check for eager execution here, since this code
|
||||
# path should only be taken if we are restoring in graph mode.
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
|
||||
else:
|
||||
obj._initializer_op = control_flow_ops.group(*restore_ops)
|
||||
elif isinstance(obj, lookup_ops.LookupInterface):
|
||||
# We don't need to check for eager execution here, since this code
|
||||
# path should only be taken if we are restoring in graph mode.
|
||||
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, restore_ops)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
("Missing functionality to restore state of object "
|
||||
"%r from the checkpoint." % obj))
|
||||
raise NotImplementedError(
|
||||
("Missing functionality to restore state of object "
|
||||
"%r from the checkpoint." % obj))
|
||||
|
||||
def adjust_debug_info_func_names(self, debug_info):
|
||||
"""Rewrite func names in the debug info by using the concrete func names."""
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function as framework_function
|
||||
from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -2007,6 +2008,25 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
load.load(path, tags=tag_constants.SERVING)
|
||||
load.load(path, tags=set([tag_constants.SERVING]))
|
||||
|
||||
def test_single_restore_op_used(self):
  """Checks that loading a SavedModel runs exactly one `RestoreV2` op.

  Several variables are saved so that a restore split into one op per
  variable would be counted more than once and fail the assertion.
  """
  root = module.Module()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  root.v3 = variables.Variable(3.)
  path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, path)
  restore_count = 0

  def _count_restores(op_type, *unused_args, **unused_kwargs):
    # Op callback: tally every `RestoreV2` op observed while it is registered.
    nonlocal restore_count
    if op_type == b"RestoreV2":
      restore_count += 1

  op_callbacks.add_op_callback(_count_restores)
  try:
    load.load(path)
  finally:
    # Unregister even when `load` raises, so the counting callback does not
    # stay registered after this test finishes.
    op_callbacks.remove_op_callback(_count_restores)
  self.assertEqual(1, restore_count)
|
||||
|
||||
def test_docstring_examples(self):
|
||||
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
|
||||
exported = util.Checkpoint(v=variables.Variable(3.))
|
||||
|
Loading…
x
Reference in New Issue
Block a user