When restoring a variable with an initializer, pass through the restore metadata rather than discarding it.
This avoids 2x memory usage when restoring with a distribution strategy, since otherwise variables are restored twice (with two live copies the second time). PiperOrigin-RevId: 311778129 Change-Id: I60c1c23d0b554d30e3913f588e6f11a7c430fe71
This commit is contained in:
parent
d968853cc6
commit
985275ea27
@ -32,6 +32,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.optimizer_v2 import adam
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.training import adam as adam_v1
|
||||
@ -96,6 +97,41 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual((training_continuation + 1) * num_training_steps,
|
||||
root.optimizer_step.numpy())
|
||||
|
||||
  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ],
          mode=["eager"]))
  def testInitializeFromCheckpoint(self, distribution):
    """A variable initialized from a checkpoint is tagged as already restored.

    Saves a checkpoint containing a single variable `v`, then — inside the
    distribution strategy's scope — builds a new variable whose initializer
    comes from `Checkpoint._preload_simple_restoration`. The variable must
    pick up the checkpoint's restore metadata (`_update_uid > 0`) so that
    assigning it to the restoring `Checkpoint` afterwards does NOT trigger a
    second restore pass; per this change's intent, a second restore would
    temporarily hold two live copies of the value (2x memory).
    """
    variable_shape = [5]
    # Save a checkpoint whose single entry "v" holds a vector of ones.
    save_checkpoint = trackable_utils.Checkpoint(v=variables_lib.Variable(
        array_ops.ones(variable_shape)))
    save_path = save_checkpoint.save(
        os.path.join(self.get_temp_dir(), "checkpoint"))
    with distribution.scope():
      restore_checkpoint = trackable_utils.Checkpoint()
      restore_checkpoint.restore(save_path)
      # NOTE: `_preload_simple_restoration` is a private Checkpoint API that
      # returns an initializer carrying the restore metadata for key "v".
      initial_value = restore_checkpoint._preload_simple_restoration(
          "v", variable_shape)
      v = variables_lib.Variable(initial_value)
      # Check that the variable is now tagged as restored. `Checkpoint` then
      # knows it doesn't have to restore `v`'s value when it's assigned to an
      # object.
      self.assertGreater(v._update_uid, 0)
      self.assertAllClose(array_ops.ones(variable_shape), v)
      v.assign(array_ops.zeros(variable_shape))
      # Assignment to an object should not trigger restoration, since we already
      # restored the object through an initializer. This wouldn't be a
      # correctness issue, but it would mean that models would use twice as much
      # memory when loading (the buffer already assigned to the variable, and
      # the new restoration).
      restore_checkpoint.v = v
      # `v` keeps the zeros we assigned — proof no second restore happened.
      self.assertAllClose(array_ops.zeros(variable_shape), v)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
|
@ -1772,13 +1772,25 @@ class StrategyExtendedV2(object):
|
||||
kwargs["distribute_strategy"] = strategy
|
||||
|
||||
# Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
|
||||
# dereferencing a `Tensor` that is without a `name`.
|
||||
# TODO(b/138130844): Revisit the following check once
|
||||
# `CheckpointInitialValue` class is removed.
|
||||
# dereferencing a `Tensor` that is without a `name`. We still need to
|
||||
# propagate the metadata it's holding.
|
||||
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
|
||||
checkpoint_restore_uid = kwargs[
|
||||
"initial_value"].checkpoint_position.restore_uid
|
||||
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
|
||||
else:
|
||||
checkpoint_restore_uid = None
|
||||
|
||||
return self._create_variable(next_creator, **kwargs)
|
||||
created = self._create_variable(next_creator, **kwargs)
|
||||
|
||||
if checkpoint_restore_uid is not None:
|
||||
# pylint: disable=protected-access
|
||||
# Let the checkpointing infrastructure know that the variable was
|
||||
# already restored so it doesn't waste memory loading the value again.
|
||||
created._maybe_initialize_trackable()
|
||||
created._update_uid = checkpoint_restore_uid
|
||||
# pylint: enable=protected-access
|
||||
return created
|
||||
|
||||
def distributed_getter(getter, *args, **kwargs):
|
||||
if not self._allow_variable_partition():
|
||||
|
Loading…
Reference in New Issue
Block a user