When restoring a variable with an initializer, pass through restore metadata rather than forgetting it

This avoids 2x memory usage when restoring under a distribution strategy: without the metadata, each variable's value is restored twice, and the second restore leaves two live copies of the value in memory (a sketch of the affected scenario follows the change summary below).

PiperOrigin-RevId: 311778129
Change-Id: I60c1c23d0b554d30e3913f588e6f11a7c430fe71
Authored by Allen Lavoie on 2020-05-15 12:15:58 -07:00; committed by TensorFlower Gardener
parent d968853cc6
commit 985275ea27
2 changed files with 52 additions and 4 deletions
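
For context, here is a rough sketch of the scenario this targets, written against the public tf.train.Checkpoint and tf.distribute APIs (the model, shapes, and checkpoint path are illustrative, not taken from this change): when restore() runs before the variables exist, each variable created afterwards takes its initial value straight from the checkpoint, and that value should not be read a second time when the variable is later matched back to the checkpointed object.

import tensorflow as tf

# Save a small model's variables (shapes and path are illustrative).
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
save_path = tf.train.Checkpoint(model=model).save("/tmp/demo_ckpt")

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  restored = tf.keras.Sequential([tf.keras.layers.Dense(4)])
  ckpt = tf.train.Checkpoint(model=restored)
  # Deferred restore: the variables do not exist yet, so restoration is queued.
  ckpt.restore(save_path)
  # Building the layer creates its variables under the strategy. A restore is
  # pending, so their initial values are read from the checkpoint. With this
  # change the strategy's variable creator keeps the associated restore
  # metadata, so the same values are not loaded again when the variables are
  # matched back to `ckpt`.
  restored.build((None, 8))
  # The restored kernel matches the saved one.
  tf.debugging.assert_near(model.layers[0].kernel, restored.layers[0].kernel)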

@@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import adam as adam_v1
@@ -96,6 +97,41 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        root.optimizer_step.numpy())

+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ],
+          mode=["eager"]))
+  def testInitializeFromCheckpoint(self, distribution):
+    variable_shape = [5]
+    save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(
+        array_ops.ones(variable_shape)))
+    save_path = save_checkpoint.save(
+        os.path.join(self.get_temp_dir(), "checkpoint"))
+    with distribution.scope():
+      restore_checkpoint = trackable_utils.Checkpoint()
+      restore_checkpoint.restore(save_path)
+      initial_value = restore_checkpoint._preload_simple_restoration(
+          "v", variable_shape)
+      v = variables_lib.Variable(initial_value)
+      # Check that the variable is now tagged as restored. `Checkpoint` then
+      # knows it doesn't have to restore `v`'s value when it's assigned to an
+      # object.
+      self.assertGreater(v._update_uid, 0)
+      self.assertAllClose(array_ops.ones(variable_shape), v)
+      v.assign(array_ops.zeros(variable_shape))
+      # Assignment to an object should not trigger restoration, since we already
+      # restored the object through an initializer. This wouldn't be a
+      # correctness issue, but it would mean that models would use twice as much
+      # memory when loading (the buffer already assigned to the variable, and
+      # the new restoration).
+      restore_checkpoint.v = v
+      self.assertAllClose(array_ops.zeros(variable_shape), v)
+
   @combinations.generate(
       combinations.combine(
           distribution=[

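The v._update_uid assertions in the test above rely on the checkpointing machinery's restore-UID bookkeeping: each restore pass carries a monotonically increasing UID, and (roughly) a trackable object only applies a restore whose UID is newer than the one already recorded on it. Below is a minimal, library-agnostic sketch of that gating in plain Python, with toy classes rather than TensorFlow's actual ones: once the UID is stamped at creation time, a later attempt to restore the same object is a no-op, which is what the second diff below arranges.

import itertools

_restore_uid_counter = itertools.count(1)


class Restorable:
  """Toy stand-in for a trackable object with UID-gated restoration."""

  def __init__(self, value):
    self.value = value
    self.update_uid = -1  # no restore applied yet

  def maybe_restore(self, restore_uid, restored_value):
    # Only apply a restore that is newer than what this object already has.
    if restore_uid > self.update_uid:
      self.value = restored_value
      self.update_uid = restore_uid
      return True   # value was loaded
    return False    # skipped: already restored by this (or a newer) pass


# One restore pass gets one UID.
uid = next(_restore_uid_counter)

# The variable is created with its initial value taken from the checkpoint;
# the creator stamps the restore UID so the object counts as already restored.
v = Restorable(value=[1.0] * 5)
v.update_uid = uid

# Later, attaching v to the checkpoint tries to restore it again with the
# same UID; the gate turns that into a no-op instead of loading a second copy.
assert not v.maybe_restore(uid, restored_value=[1.0] * 5)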

@@ -1772,13 +1772,25 @@ class StrategyExtendedV2(object):
       kwargs["distribute_strategy"] = strategy

       # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
-      # dereferencing a `Tensor` that is without a `name`.
-      # TODO(b/138130844): Revisit the following check once
-      # `CheckpointInitialValue` class is removed.
+      # dereferencing a `Tensor` that is without a `name`. We still need to
+      # propagate the metadata it's holding.
       if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
+        checkpoint_restore_uid = kwargs[
+            "initial_value"].checkpoint_position.restore_uid
         kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
+      else:
+        checkpoint_restore_uid = None

-      return self._create_variable(next_creator, **kwargs)
+      created = self._create_variable(next_creator, **kwargs)
+
+      if checkpoint_restore_uid is not None:
+        # pylint: disable=protected-access
+        # Let the checkpointing infrastructure know that the variable was
+        # already restored so it doesn't waste memory loading the value again.
+        created._maybe_initialize_trackable()
+        created._update_uid = checkpoint_restore_uid
+        # pylint: enable=protected-access
+      return created

     def distributed_getter(getter, *args, **kwargs):
       if not self._allow_variable_partition():
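
More generally, the fix follows the usual contract for variable creators: a creator that rewrites kwargs["initial_value"] by unwrapping a wrapper object has to carry whatever metadata the wrapper held over to the variable it creates. The sketch below illustrates that pattern with the public tf.variable_creator_scope API; the TaggedValue wrapper and tag attribute are hypothetical stand-ins for CheckpointInitialValue and the restore UID, not part of TensorFlow.

import tensorflow as tf


class TaggedValue:
  """Hypothetical wrapper: a tensor plus a bit of metadata to propagate."""

  def __init__(self, tensor, tag):
    self.tensor = tensor
    self.tag = tag


def unwrapping_creator(next_creator, **kwargs):
  # A creator that unwraps `initial_value` must forward the wrapper's
  # metadata to the created variable rather than forgetting it.
  initial_value = kwargs.get("initial_value")
  tag = None
  if isinstance(initial_value, TaggedValue):
    tag = initial_value.tag
    kwargs["initial_value"] = initial_value.tensor
  created = next_creator(**kwargs)
  if tag is not None:
    created.tag = tag  # propagate instead of dropping
  return created


with tf.variable_creator_scope(unwrapping_creator):
  v = tf.Variable(TaggedValue(tf.ones([5]), tag="from-checkpoint"))

assert v.tag == "from-checkpoint"

In the real change the propagated metadata is checkpoint_position.restore_uid, written to the created variable's _update_uid so the checkpointing machinery treats the variable as already restored.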