Stop running variable restoration twice when loading SavedModels
Even worse, the second run split the restore into one op per variable, which is very slow on high-latency filesystems. This was an unfortunate side effect of a 1.x compatibility case. Adds a unit test asserting that only one restore op is run.
PiperOrigin-RevId: 351231481
Change-Id: I43cb7a140f10ad743a64bd7f6893abc77477a90f
This commit is contained in:
parent
86069bb1cb
commit
756840185f
@ -488,13 +488,14 @@ class Loader(object):
|
||||
load_status.assert_existing_objects_matched()
|
||||
checkpoint = load_status._checkpoint
|
||||
|
||||
if not context.executing_eagerly():
|
||||
# When running in eager mode, the `restore` call above has already run and
|
||||
# restored the state of trackables, call `position.restore_ops()` will
|
||||
# return an empty list as there is nothing left to do. In graph mode, that
|
||||
# will return the list of ops that must run to restore the object on that
|
||||
# position. We have to wire them in the initializers of the objects so that
|
||||
# they get initialized properly when using common practices (e.g. the ones
|
||||
# used by ManagedSession) without further user action.
|
||||
# restored the state of trackables, and calling `position.restore_ops()`
|
||||
# would re-run the restore. In graph mode, that will return a cached list
|
||||
# of ops that must run to restore the object on that position. We have to
|
||||
# wire them in the initializers of the objects so that they get
|
||||
# initialized properly when using common practices (e.g. the ones used by
|
||||
# ManagedSession) without further user action.
|
||||
for object_id, obj in dict(checkpoint.object_by_proto_id).items():
|
||||
position = base.CheckpointPosition(checkpoint=checkpoint,
|
||||
proto_id=object_id)
|
||||
|
@ -37,6 +37,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import function as framework_function
|
||||
from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
@ -2007,6 +2008,25 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
load.load(path, tags=tag_constants.SERVING)
|
||||
load.load(path, tags=set([tag_constants.SERVING]))
|
||||
|
||||
def test_single_restore_op_used(self):
  """Checks that loading a SavedModel triggers exactly one `RestoreV2` op.

  A module holding three variables is saved and then loaded back while an op
  callback tallies every op whose type equals `b"RestoreV2"`. The test passes
  only if that tally is exactly one after the load.
  """
  root = module.Module()
  root.v1 = variables.Variable(1.)
  root.v2 = variables.Variable(2.)
  root.v3 = variables.Variable(3.)
  path = tempfile.mkdtemp(prefix=self.get_temp_dir())
  save.save(root, path)
  restore_count = 0

  def _count_restores(op_type, *unused_args, **unused_kwargs):
    # Only tallies matching ops; it implicitly returns None.
    nonlocal restore_count
    if op_type == b"RestoreV2":
      restore_count += 1

  op_callbacks.add_op_callback(_count_restores)
  try:
    load.load(path)
  finally:
    # Unregister even when the load raises, so the counting callback does not
    # stay installed after this test exits.
    op_callbacks.remove_op_callback(_count_restores)
  self.assertEqual(1, restore_count)
|
||||
|
||||
def test_docstring_examples(self):
|
||||
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
|
||||
exported = util.Checkpoint(v=variables.Variable(3.))
|
||||
|
Loading…
x
Reference in New Issue
Block a user