Add a callable wrapper (CheckpointInitialValueCallable) around CheckpointInitialValue so that we can delay the variable restore until after variable creator scopes have been called.

PiperOrigin-RevId: 329595038
Change-Id: I9983bec354514172573e37d50ed6895a4bafb8dc
Author: Bruce Fontaine 2020-09-01 15:35:04 -07:00
Committed by: TensorFlower Gardener
Parent: 9119dd3fad
Commit: 2c9ffb560c
13 changed files with 132 additions and 71 deletions
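The change in one picture: a CheckpointInitialValue is a tensor, so building it performs the checkpoint read immediately, before any tf.variable_creator_scope gets a chance to run; the new callable wrapper hands the creators a function instead, so the innermost creator decides when and where the read happens. A minimal plain-Python sketch of that ordering (RestoreOnCall and creator_scope are illustrative stand-ins, not TensorFlow APIs):

restore_log = []

def restore(shape=None):
    # Stand-in for the checkpoint read.
    restore_log.append("restore ran")
    return [0.0]

class RestoreOnCall:
    """Like CheckpointInitialValueCallable: defers the read until called."""
    def __call__(self, shape=None, dtype=None):
        return restore(shape)

def creator_scope(initial_value):
    """Like a variable creator: invokes callables under its own scopes."""
    restore_log.append("creator scope entered")
    if callable(initial_value):
        initial_value = initial_value()  # the read happens inside the scope
    return initial_value

creator_scope(restore())           # eager value: read runs before the scope
assert restore_log[0] == "restore ran"
restore_log.clear()
creator_scope(RestoreOnCall())     # callable: scope first, then the read
assert restore_log == ["creator scope entered", "restore ran"]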

View File

@@ -51,7 +51,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
     restore_checkpoint = trackable_utils.Checkpoint()
     restore_checkpoint.restore(save_path)
     initial_value = restore_checkpoint._preload_simple_restoration(
-        "v", variable_shape)
+        "v")
     v = variables_lib.Variable(initial_value)
     # Check that the variable is now tagged as restored. `Checkpoint` then
     # knows it doesn't have to restore `v`'s value when it's assigned to an

View File

@@ -45,6 +45,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.tracking import base
 from tensorflow.python.util.tf_export import tf_export
@@ -436,6 +437,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       initial_value = kwargs["initial_value"]
       if callable(initial_value):
         initial_value = initial_value()
+        if isinstance(initial_value, base.CheckpointInitialValue):
+          initial_value = initial_value.wrapped_value
       assert not callable(initial_value)
       initial_value = ops.convert_to_tensor(
           initial_value, dtype=kwargs.get("dtype", None))
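This is the call-then-unwrap pattern that recurs throughout the commit: the strategy invokes a callable initial_value inside its own creator, and if the result is the checkpoint wrapper, strips it down to the underlying tensor before convert_to_tensor. In isolation (Wrapper below is a simplified stand-in for base.CheckpointInitialValue):

class Wrapper:
    def __init__(self, value):
        self.wrapped_value = value  # the underlying restored value

def resolve_initial_value(initial_value):
    if callable(initial_value):
        initial_value = initial_value()
        if isinstance(initial_value, Wrapper):
            initial_value = initial_value.wrapped_value
    assert not callable(initial_value)
    return initial_value

assert resolve_initial_value(lambda: Wrapper(3.0)) == 3.0
assert resolve_initial_value(7.0) == 7.0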

View File

@@ -2130,6 +2130,10 @@ class StrategyExtendedV2(object):
         checkpoint_restore_uid = kwargs[
             "initial_value"].checkpoint_position.restore_uid
         kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
+      elif isinstance(kwargs["initial_value"],
+                      trackable.CheckpointInitialValueCallable):
+        checkpoint_restore_uid = kwargs[
+            "initial_value"].checkpoint_position.restore_uid
       else:
         checkpoint_restore_uid = None
@@ -2139,6 +2143,9 @@ class StrategyExtendedV2(object):
         # pylint: disable=protected-access
         # Let the checkpointing infrastructure know that the variable was
         # already restored so it doesn't waste memory loading the value again.
+        # In the case of CheckpointInitialValueCallable this may already be
+        # done by the final variable creator, but it doesn't hurt to do it
+        # again.
         created._maybe_initialize_trackable()
         created._update_uid = checkpoint_restore_uid
         # pylint: enable=protected-access
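A self-contained sketch of the bookkeeping in these two hunks, using fake stand-ins for the checkpoint position, the wrapper, and the variable: the restore uid is captured from either wrapper type before it is consumed, and the created variable is tagged so the deferred restore machinery skips a second load.

class FakePosition:
    restore_uid = 42

class FakeWrapper:  # stands in for both wrapper types
    checkpoint_position = FakePosition()

class FakeVariable:
    _update_uid = None

def create_and_tag(initial_value):
    uid = None
    if isinstance(initial_value, FakeWrapper):
        uid = initial_value.checkpoint_position.restore_uid
    created = FakeVariable()  # stand-in for next_creator(**kwargs)
    if uid is not None:
        created._update_uid = uid  # marks the value as already restored
    return created

assert create_and_tag(FakeWrapper())._update_uid == 42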

View File

@@ -217,17 +217,18 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
     if constraint is not None and not callable(constraint):
       raise ValueError("The `constraint` argument must be a callable.")
-    if isinstance(initial_value, trackable.CheckpointInitialValue):
-      self._maybe_initialize_trackable()
-      self._update_uid = initial_value.checkpoint_position.restore_uid
-      initial_value = initial_value.wrapped_value
     with ops.name_scope(name, "Variable", []
                         if init_from_fn else [initial_value]) as scope_name:
       with ops.name_scope("Initializer"):
-        initial_value = ops.convert_to_tensor(
-            initial_value() if init_from_fn else initial_value,
-            name="initial_value", dtype=dtype)
+        if init_from_fn:
+          initial_value = initial_value()
+        if isinstance(initial_value, trackable.CheckpointInitialValue):
+          self._maybe_initialize_trackable()
+          self._update_uid = initial_value.checkpoint_position.restore_uid
+          initial_value = initial_value.wrapped_value
+        initial_value = ops.convert_to_tensor(initial_value,
+                                              name="initial_value", dtype=dtype)
     assert initial_value is not None

     # Don't use `shape or initial_value.shape` since TensorShape has

View File

@@ -142,8 +142,8 @@ class _DelegatingTrackableMixin(object):
     return self._trackable._add_variable_with_custom_getter(
         name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)

-  def _preload_simple_restoration(self, name, shape):
-    return self._trackable._preload_simple_restoration(name, shape)
+  def _preload_simple_restoration(self, name):
+    return self._trackable._preload_simple_restoration(name)

   def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
     return self._trackable._track_trackable(trackable, name, overwrite)

View File

@@ -1292,7 +1292,7 @@ class OptimizerV2(trackable.Trackable):
         # (aside from double initialization), and makes variable creator scopes
         # behave the same way they do when graph building.
         and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
-      initializer = trackable.CheckpointInitialValue(
+      initializer = trackable.CheckpointInitialValueCallable(
           checkpoint_position=slot_variable_position)
       slot_variable = self.add_slot(
           var=variable,

View File

@@ -1686,11 +1686,6 @@ class ResourceVariable(BaseResourceVariable):
     if constraint is not None and not callable(constraint):
       raise ValueError("The `constraint` argument must be a callable.")
-    if isinstance(initial_value, trackable.CheckpointInitialValue):
-      self._maybe_initialize_trackable()
-      self._update_uid = initial_value.checkpoint_position.restore_uid
-      initial_value = initial_value.wrapped_value
     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     with ops.init_scope():
@@ -1719,10 +1714,15 @@ class ResourceVariable(BaseResourceVariable):
                     s=[compat.as_bytes("loc:@%s" % handle_name)]))
           with ops.get_default_graph()._attr_scope({"_class": attr}):
             with ops.name_scope("Initializer"), device_context_manager(None):
-              initial_value = ops.convert_to_tensor(
-                  initial_value() if init_from_fn else initial_value,
-                  name="initial_value",
-                  dtype=dtype)
+              if init_from_fn:
+                initial_value = initial_value()
+              if isinstance(initial_value, trackable.CheckpointInitialValue):
+                self._maybe_initialize_trackable()
+                self._update_uid = initial_value.checkpoint_position.restore_uid
+                initial_value = initial_value.wrapped_value
+              initial_value = ops.convert_to_tensor(initial_value,
+                                                    name="initial_value",
+                                                    dtype=dtype)
           if shape is not None:
             if not initial_value.shape.is_compatible_with(shape):
               raise ValueError(

View File

@@ -270,6 +270,24 @@ def disable_resource_variables():
   _api_usage_gauge.get_cell().set(False)


+def _needs_no_arguments(python_callable):
+  """Returns true if the callable needs no arguments to call."""
+  # TODO(bfontain): Switch to inspect.signature when we are Python 3 only.
+  # signature = inspect.signature(python_callable)
+  # return not [1 for param in signature.parameters.values()
+  #             if param.default == param.empty]
+  num_arguments = len(tf_inspect.getargspec(python_callable).args)
+  if not tf_inspect.isfunction(python_callable) and not isinstance(
+      python_callable, functools.partial):
+    # getargspec includes `self` for callable objects (which aren't
+    # functools.partial). `self` has no default, so we need to remove it;
+    # it is not really an argument, so it's odd that getargspec returns it.
+    # Note that this is fixed with inspect.signature in Python 3.
+    num_arguments -= 1
+  return num_arguments == len(
+      tf_inspect.getargspec(python_callable).defaults or [])
+
+
 class _VariableStore(object):
   """Variable store that carries a number of named Variables.
@@ -905,18 +923,17 @@ class _VariableStore(object):
       # Instantiate initializer if provided initializer is a type object.
       if tf_inspect.isclass(initializer):
         initializer = initializer()
-      if shape is not None and shape.is_fully_defined():
+      if shape.is_fully_defined():
         if "partition_info" in tf_inspect.getargspec(initializer).args:
-          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
-              shape.as_list(),
-              dtype=dtype,
-              partition_info=partition_info)
+          init_val = functools.partial(initializer,
+                                       shape.as_list(),
+                                       dtype=dtype,
+                                       partition_info=partition_info)
         else:
-          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
-              shape.as_list(), dtype=dtype)
+          init_val = functools.partial(initializer,
+                                       shape.as_list(), dtype=dtype)
         variable_dtype = dtype.base_dtype
-      elif len(tf_inspect.getargspec(initializer).args) == len(
-          tf_inspect.getargspec(initializer).defaults or []):
+      elif _needs_no_arguments(initializer):
         init_val = initializer
         variable_dtype = None
       else:
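For reference, the inspect.signature variant sketched in the TODO above would look roughly like this (an illustration, not the shipped helper; it also sidesteps the `self` special-casing, because signature() already excludes bound arguments):

import functools
import inspect

def needs_no_arguments(python_callable):
    """True when every remaining parameter has a default value."""
    signature = inspect.signature(python_callable)
    return all(param.default is not param.empty
               for param in signature.parameters.values())

def initializer(shape=None, dtype=None):  # all-default, like the new wrapper
    return shape

assert needs_no_arguments(initializer)
assert not needs_no_arguments(lambda shape: shape)   # shape is required
assert needs_no_arguments(functools.partial(lambda shape: shape, (2, 3)))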

View File

@@ -1794,8 +1794,13 @@ class RefVariable(VariableV1, core.Tensor):
         # pylint: disable=protected-access
         with ops.get_default_graph()._attr_scope({"_class": attr}):
           with ops.name_scope("Initializer"), ops.device(None):
+            initial_value = initial_value()
+            if isinstance(initial_value, trackable.CheckpointInitialValue):
+              self._maybe_initialize_trackable()
+              self._update_uid = initial_value.checkpoint_position.restore_uid
+              initial_value = initial_value.wrapped_value
             self._initial_value = ops.convert_to_tensor(
-                initial_value(), name="initial_value", dtype=dtype)
+                initial_value, name="initial_value", dtype=dtype)
         if shape is None:
           shape = (
               self._initial_value.get_shape()

View File

@@ -774,19 +774,19 @@ class TPUEmbedding(tracking.AutoTrackable):

     def create_variables(table):
       """Create all variables."""
-      shape = (table.vocabulary_size, table.dim)
+      variable_shape = (table.vocabulary_size, table.dim)

       def getter(name, shape, dtype, initializer, trainable):
-        # TODO(bfontain): make CheckpointInitialValue a callable rather than
-        # something that inherits from tensor.
-        if not isinstance(initializer, base.CheckpointInitialValue):
-          initial_value = functools.partial(initializer, shape, dtype=dtype)
-        else:
-          initial_value = initializer
+        del shape
+        # _add_variable_with_custom_getter clears the shape sometimes, so we
+        # take the global shape from outside the getter.
+        initial_value = functools.partial(initializer, variable_shape,
+                                          dtype=dtype)
         return tf_variables.Variable(
             name=name,
             initial_value=initial_value,
+            shape=variable_shape,
             dtype=dtype,
             trainable=trainable)

       def variable_creator(name, initializer, trainable=True):
@@ -796,7 +796,7 @@ class TPUEmbedding(tracking.AutoTrackable):
         return self._add_variable_with_custom_getter(
             name=name,
             initializer=initializer,
-            shape=shape,
+            shape=variable_shape,
             dtype=dtypes.float32,
             getter=getter,
             trainable=trainable)
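The getter above deliberately ignores the shape it is handed and closes over variable_shape because, per the new comment, _add_variable_with_custom_getter sometimes clears the shape. The resulting calling convention, standalone (zeros_initializer is a made-up stand-in with the usual (shape, dtype=...) signature):

import functools

def zeros_initializer(shape, dtype=None):
    return [[0.0] * shape[1] for _ in range(shape[0])]

variable_shape = (2, 3)  # captured from the enclosing scope, not the getter
initial_value = functools.partial(zeros_initializer, variable_shape,
                                  dtype="float32")
# tf.Variable accepts a zero-argument callable as initial_value:
assert initial_value() == [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
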
@@ -1490,9 +1490,6 @@ def extract_variable_info(kwargs):
     return (kwargs["name"], shape,
             kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]),
             kwargs["initial_value"].func)
-  elif isinstance(kwargs["initial_value"], base.CheckpointInitialValue):
-    return (kwargs["name"], kwargs["initial_value"].shape,
-            kwargs["initial_value"].dtype, kwargs["initial_value"])
   elif "shape" not in kwargs or kwargs["shape"] is None:
     raise ValueError(
         "Unable to extract initializer function and shape from {}. Please "
@@ -1529,26 +1526,21 @@ def make_sharded_variable_creator(hosts):
     partitions = ([rows // num_hosts + 1] * missing + [rows // num_hosts] *
                   (num_hosts - missing))
     variables = []
-    newkwargs = kwargs
-    newkwargs["dtype"] = dtype
+    kwargs["dtype"] = dtype

     # TODO(bfontain): Remove this check once we can pass position and shape of
-    # shards to CheckpointInitialValue.
-    if isinstance(initial_value, base.CheckpointInitialValue) and num_hosts > 1:
+    # shards to CheckpointInitialValueCallable.
+    if isinstance(initial_value,
+                  base.CheckpointInitialValueCallable) and num_hosts > 1:
       raise RuntimeError("Delayed restoration of variables not available when "
                          "there are multiple TPU hosts, please ensure that the "
                          "api object has been built before you restore.")
     for i, p in enumerate(partitions):
       with ops.device(hosts[i]):
-        newkwargs["shape"] = (p, cols)
-        newkwargs["name"] = "{}_{}".format(name, i)
-        if isinstance(initial_value, base.CheckpointInitialValue):
-          # TODO(bfontain): Patch CheckpointInitialValue to take in account the
-          # position and shape of this shard.
-          newkwargs["initial_value"] = initial_value
-        else:
-          newkwargs["initial_value"] = (
-              lambda: initial_value(newkwargs["shape"], dtype=dtype))
+        kwargs["shape"] = (p, cols)
+        kwargs["name"] = "{}_{}".format(name, i)
+        kwargs["initial_value"] = functools.partial(
+            initial_value, kwargs["shape"], dtype=dtype)
         variables.append(next_creator(*args, **kwargs))
     return TPUShardedVariable(variables, name=name)
   return sharded_variable_creator
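The row-partitioning arithmetic in the context lines at the top of this hunk splits the leftover rows one per host; the switch from a lambda to functools.partial also sidesteps the potential late-binding issue where every shard's lambda would read the final kwargs["shape"]. A standalone sketch, assuming (as the sum must work out) that missing is defined above the hunk as the remainder rows % num_hosts:

def partition_rows(rows, num_hosts):
    missing = rows % num_hosts  # hosts that get one extra row
    return ([rows // num_hosts + 1] * missing +
            [rows // num_hosts] * (num_hosts - missing))

assert partition_rows(10, 4) == [3, 3, 2, 2]
assert sum(partition_rows(10, 4)) == 10
assert partition_rows(8, 4) == [2, 2, 2, 2]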

View File

@@ -824,7 +824,7 @@ class Optimizer(
       with distribution_strategy.extended.colocate_vars_with(colocate_with):
         if eager:
           restored_initial_value = self._preload_simple_restoration(
-              name=name, shape=None)
+              name=name)
           if restored_initial_value is not None:
             initial_value = restored_initial_value
           v = variable_scope.variable(
@@ -1213,11 +1213,15 @@ class Optimizer(
         # (aside from double initialization), and makes variable creator scopes
         # behave the same way they do when graph building.
         and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
-      initializer = trackable.CheckpointInitialValue(
+      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
-      slot_variable = self._get_or_make_slot(
+      # CheckpointInitialValueCallable will ignore the shape and dtype
+      # parameters but they must be passed.
+      slot_variable = self._get_or_make_slot_with_initializer(
          var=variable,
-          val=initializer,
+          initializer=initializer,
+          shape=variable.shape,
+          dtype=variable.dtype,
          slot_name=slot_name,
          op_name=self._name)
     # Slot variables are not owned by any one object (because we don't want to
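The new comment is about the initializer protocol: slot creation ultimately binds the initializer as functools.partial(initializer, shape, dtype=dtype) (see the variable_scope.py hunk earlier) and calls the result with no arguments, so CheckpointInitialValueCallable.__call__ must accept both parameters even though it ignores dtype. A plain-Python stand-in:

import functools

class InitialValueCallable:  # simplified CheckpointInitialValueCallable
    def __call__(self, shape=None, dtype=None):
        # dtype is accepted purely for signature compatibility.
        return "restored value for shape {}".format(shape)

init_val = functools.partial(InitialValueCallable(), (4, 2), dtype="float32")
assert init_val() == "restored value for shape (4, 2)"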

View File

@@ -54,6 +54,31 @@ TrackableReference = collections.namedtuple(
     ])


+class CheckpointInitialValueCallable(object):
+  """A callable object that returns a CheckpointInitialValue.
+
+  See CheckpointInitialValue for more information.
+  """
+
+  def __init__(self, checkpoint_position):
+    self._checkpoint_position = checkpoint_position
+
+  @property
+  def checkpoint_position(self):
+    return self._checkpoint_position
+
+  def __call__(self, shape=None, dtype=None):
+    # Note that the signature here is for compatibility with normal callable
+    # initializers which take shape and dtype. Although dtype isn't used, it
+    # will get passed in by a functools.partial wrapper in places like
+    # base_layer_utils.py's make_variable.
+    return CheckpointInitialValue(self._checkpoint_position, shape)
+
+  @property
+  def restore_uid(self):
+    return self._checkpoint_position.restore_uid
+
+
 class CheckpointInitialValue(ops.Tensor):
   """Tensor wrapper for managing update UIDs in `Variables`.
@@ -312,7 +337,7 @@ class CheckpointPosition(object):
             name="%s_checkpoint_read" % (serialized_tensor.name,))
       # Copy the value to the current device if necessary.
       value_tensors[serialized_tensor.name] = array_ops.identity(value)
-      return value_tensors
+    return value_tensors

   def gather_ops_or_named_saveables(self):
     """Looks up or creates SaveableObjects which don't have cached ops."""
@@ -735,11 +760,11 @@ class Trackable(object):
       # then assigning (when executing eagerly). This call returns None if
       # there is nothing to restore.
       checkpoint_initializer = self._preload_simple_restoration(
-          name=name, shape=shape)
+          name=name)
     else:
       checkpoint_initializer = None
     if (checkpoint_initializer is not None and
-        not (isinstance(initializer, CheckpointInitialValue) and
+        not (isinstance(initializer, CheckpointInitialValueCallable) and
              (initializer.restore_uid > checkpoint_initializer.restore_uid))):
       # If multiple Trackable objects are "creating" the same variable
       # via the magic of custom getters, the one with the highest restore UID
@@ -767,7 +792,7 @@ class Trackable(object):
     # fallback once all get_variable() return types are Trackable.
     return new_variable

-  def _preload_simple_restoration(self, name, shape):
+  def _preload_simple_restoration(self, name):
     """Return a dependency's value for restore-on-create.

     Note the restoration is not deleted; if for some reason preload is called
@@ -778,7 +803,6 @@ class Trackable(object):
     Args:
       name: The object-local name of the dependency holding the variable's
         value.
-      shape: The shape of the variable being loaded into.

     Returns:
       A callable for use as a variable's initializer/initial_value, or None if
@@ -801,8 +825,8 @@ class Trackable(object):
       checkpoint_position = max(
           deferred_dependencies_list,
           key=lambda restore: restore.checkpoint.restore_uid)
-      return CheckpointInitialValue(
-          checkpoint_position=checkpoint_position, shape=shape)
+      return CheckpointInitialValueCallable(
+          checkpoint_position=checkpoint_position)

   def _track_trackable(self, trackable, name, overwrite=False):
     """Declare a dependency on another `Trackable` object.

View File

@@ -19,6 +19,7 @@ from __future__ import print_function

 import abc
 import collections
+import functools
 import os
 import weakref
@@ -57,6 +58,7 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
@@ -427,10 +429,16 @@ def _default_getter(name,
     # Instantiate initializer if provided initializer is a type object.
     if isinstance(initializer, type(init_ops.Initializer)):
       initializer = initializer(dtype=dtype)
-
-    def initial_value():
-      return initializer(
-          shape_object.as_list(), dtype=dtype, partition_info=partition_info)
+    shape_list = None if shape is None else shape_object.as_list()
+    if "partition_info" in tf_inspect.getargspec(initializer).args:
+      initial_value = functools.partial(initializer,
+                                        shape_list,
+                                        dtype=dtype,
+                                        partition_info=partition_info)
+    else:
+      initial_value = functools.partial(initializer,
+                                        shape_list,
+                                        dtype=dtype)
     return variables.VariableV1(
         initial_value=initial_value,
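The added branch keeps partition_info out of the call when the initializer does not accept it. The same dispatch in isolation, using the public inspect.signature rather than the internal tf_inspect.getargspec:

import functools
import inspect

def bind_initial_value(initializer, shape_list, dtype, partition_info=None):
    # Pass partition_info only if the initializer's signature accepts it.
    if "partition_info" in inspect.signature(initializer).parameters:
        return functools.partial(initializer, shape_list, dtype=dtype,
                                 partition_info=partition_info)
    return functools.partial(initializer, shape_list, dtype=dtype)

def partitioned_init(shape, dtype=None, partition_info=None):
    return ("partitioned", tuple(shape))

def plain_init(shape, dtype=None):
    return ("plain", tuple(shape))

assert bind_initial_value(partitioned_init, [2, 3], "f32")() == \
    ("partitioned", (2, 3))
assert bind_initial_value(plain_init, [2, 3], "f32")() == ("plain", (2, 3))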