Add a callable wrapper (CheckpointInitialValueCallable) around CheckpointInitialValue so that the variable restore can be delayed until after variable creator scopes have been called.
PiperOrigin-RevId: 329595038
Change-Id: I9983bec354514172573e37d50ed6895a4bafb8dc
commit 2c9ffb560c (parent 9119dd3fad)
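The change in a minimal, self-contained sketch (toy code, not the TensorFlow implementation; all names below are illustrative): the checkpoint hands variable creators a callable initial value instead of an already-materialized restored tensor, so the read from the checkpoint only happens once every variable creator scope has run and the final creator invokes the callable.

def delayed_restore():
  # Stand-in for reading the value out of a checkpoint.
  return [1.0, 2.0, 3.0]


def creator_scope(initial_value):
  # A creator scope can pass the callable through untouched; nothing has
  # been read from the checkpoint yet at this point.
  return initial_value


def create_variable(initial_value):
  # Mirrors the pattern in the strategy hunks below: invoke the callable
  # only at the very end, then convert the concrete value.
  if callable(initial_value):
    initial_value = initial_value()
  return list(initial_value)


print(create_variable(creator_scope(delayed_restore)))  # [1.0, 2.0, 3.0]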
@@ -51,7 +51,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
     restore_checkpoint = trackable_utils.Checkpoint()
     restore_checkpoint.restore(save_path)
     initial_value = restore_checkpoint._preload_simple_restoration(
-        "v", variable_shape)
+        "v")
     v = variables_lib.Variable(initial_value)
     # Check that the variable is now tagged as restored. `Checkpoint` then
     # knows it doesn't have to restore `v`'s value when it's assigned to an
@@ -45,6 +45,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.tracking import base
 from tensorflow.python.util.tf_export import tf_export


@@ -436,6 +437,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       initial_value = kwargs["initial_value"]
       if callable(initial_value):
         initial_value = initial_value()
+      if isinstance(initial_value, base.CheckpointInitialValue):
+        initial_value = initial_value.wrapped_value
       assert not callable(initial_value)
       initial_value = ops.convert_to_tensor(
           initial_value, dtype=kwargs.get("dtype", None))
@@ -2130,6 +2130,10 @@ class StrategyExtendedV2(object):
      checkpoint_restore_uid = kwargs[
          "initial_value"].checkpoint_position.restore_uid
      kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
+    elif isinstance(kwargs["initial_value"],
+                    trackable.CheckpointInitialValueCallable):
+      checkpoint_restore_uid = kwargs[
+          "initial_value"].checkpoint_position.restore_uid
    else:
      checkpoint_restore_uid = None

@@ -2139,6 +2143,9 @@ class StrategyExtendedV2(object):
      # pylint: disable=protected-access
      # Let the checkpointing infrastructure know that the variable was
      # already restored so it doesn't waste memory loading the value again.
+      # In this case of CheckpointInitialValueCallable this may already be
+      # done by the final variable creator, but it doesn't hurt to do it
+      # again.
      created._maybe_initialize_trackable()
      created._update_uid = checkpoint_restore_uid
      # pylint: enable=protected-access
@@ -217,17 +217,18 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

-    if isinstance(initial_value, trackable.CheckpointInitialValue):
-      self._maybe_initialize_trackable()
-      self._update_uid = initial_value.checkpoint_position.restore_uid
-      initial_value = initial_value.wrapped_value
-
    with ops.name_scope(name, "Variable", []
                        if init_from_fn else [initial_value]) as scope_name:
      with ops.name_scope("Initializer"):
-        initial_value = ops.convert_to_tensor(
-            initial_value() if init_from_fn else initial_value,
-            name="initial_value", dtype=dtype)
+        if init_from_fn:
+          initial_value = initial_value()
+        if isinstance(initial_value, trackable.CheckpointInitialValue):
+          self._maybe_initialize_trackable()
+          self._update_uid = initial_value.checkpoint_position.restore_uid
+          initial_value = initial_value.wrapped_value
+
+        initial_value = ops.convert_to_tensor(initial_value,
+                                              name="initial_value", dtype=dtype)
      assert initial_value is not None

      # Don't use `shape or initial_value.shape` since TensorShape has
@@ -142,8 +142,8 @@ class _DelegatingTrackableMixin(object):
    return self._trackable._add_variable_with_custom_getter(
        name, shape, dtype, initializer, getter, overwrite, **kwargs_for_getter)

-  def _preload_simple_restoration(self, name, shape):
-    return self._trackable._preload_simple_restoration(name, shape)
+  def _preload_simple_restoration(self, name):
+    return self._trackable._preload_simple_restoration(name)

  def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
    return self._trackable._track_trackable(trackable, name, overwrite)
@@ -1292,7 +1292,7 @@ class OptimizerV2(trackable.Trackable):
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
-      initializer = trackable.CheckpointInitialValue(
+      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
      slot_variable = self.add_slot(
          var=variable,
@@ -1686,11 +1686,6 @@ class ResourceVariable(BaseResourceVariable):
    if constraint is not None and not callable(constraint):
      raise ValueError("The `constraint` argument must be a callable.")

-    if isinstance(initial_value, trackable.CheckpointInitialValue):
-      self._maybe_initialize_trackable()
-      self._update_uid = initial_value.checkpoint_position.restore_uid
-      initial_value = initial_value.wrapped_value
-
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.init_scope():
@@ -1719,10 +1714,15 @@ class ResourceVariable(BaseResourceVariable):
              s=[compat.as_bytes("loc:@%s" % handle_name)]))
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), device_context_manager(None):
-              initial_value = ops.convert_to_tensor(
-                  initial_value() if init_from_fn else initial_value,
-                  name="initial_value",
-                  dtype=dtype)
+              if init_from_fn:
+                initial_value = initial_value()
+              if isinstance(initial_value, trackable.CheckpointInitialValue):
+                self._maybe_initialize_trackable()
+                self._update_uid = initial_value.checkpoint_position.restore_uid
+                initial_value = initial_value.wrapped_value
+              initial_value = ops.convert_to_tensor(initial_value,
+                                                    name="initial_value",
+                                                    dtype=dtype)
            if shape is not None:
              if not initial_value.shape.is_compatible_with(shape):
                raise ValueError(
@@ -270,6 +270,24 @@ def disable_resource_variables():
  _api_usage_gauge.get_cell().set(False)


+def _needs_no_arguments(python_callable):
+  """Returns true if the callable needs no arguments to call."""
+  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
+  # signature = inspect.signature(python_callable)
+  # return not [1 for param in signature.parameters.values()
+  #             if param.default == param.empty]
+  num_arguments = len(tf_inspect.getargspec(python_callable).args)
+  if not tf_inspect.isfunction(python_callable) and not isinstance(
+      python_callable, functools.partial):
+    # getargspec includes self for function objects (which aren't
+    # functools.partial). This has no default so we need to remove it.
+    # It is not even an argument so its odd that getargspec returns this.
+    # Note that this is fixed with inspect.signature in Python 3.
+    num_arguments -= 1
+  return num_arguments == len(
+      tf_inspect.getargspec(python_callable).defaults or [])
+
+
 class _VariableStore(object):
  """Variable store that carries a number of named Variables.

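A rough standalone equivalent of _needs_no_arguments, written with inspect.signature as the TODO in the helper suggests (tf_inspect is TensorFlow-internal, so this is only an approximation for illustration; the initializer name below is made up):

import functools
import inspect


def needs_no_arguments(python_callable):
  # True when every parameter has a default, i.e. the callable can be
  # invoked with no arguments.
  signature = inspect.signature(python_callable)
  return not any(p.default is p.empty
                 for p in signature.parameters.values())


def default_initializer(shape=None, dtype=None):
  return None


print(needs_no_arguments(default_initializer))                      # True
print(needs_no_arguments(lambda shape, dtype: None))                # False
print(needs_no_arguments(functools.partial(lambda s, d: None, 3)))  # False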
@@ -905,18 +923,17 @@ class _VariableStore(object):
      # Instantiate initializer if provided initializer is a type object.
      if tf_inspect.isclass(initializer):
        initializer = initializer()
-      if shape is not None and shape.is_fully_defined():
+      if shape.is_fully_defined():
        if "partition_info" in tf_inspect.getargspec(initializer).args:
-          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
-              shape.as_list(),
-              dtype=dtype,
-              partition_info=partition_info)
+          init_val = functools.partial(initializer,
+                                       shape.as_list(),
+                                       dtype=dtype,
+                                       partition_info=partition_info)
        else:
-          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
-              shape.as_list(), dtype=dtype)
+          init_val = functools.partial(initializer,
+                                       shape.as_list(), dtype=dtype)
        variable_dtype = dtype.base_dtype
-      elif len(tf_inspect.getargspec(initializer).args) == len(
-          tf_inspect.getargspec(initializer).defaults or []):
+      elif _needs_no_arguments(initializer):
        init_val = initializer
        variable_dtype = None
      else:
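The move from lambdas to functools.partial here (and in _default_getter further down) keeps the bound initializer introspectable: a partial exposes .func and .keywords, which extract_variable_info in the TPU code below relies on, whereas a lambda can only be called. A small illustration (the initializer name is made up):

import functools


def initializer(shape, dtype=None, partition_info=None):
  return [0.0] * shape[0]


init_val = functools.partial(initializer, [3], dtype="float32")

# A partial can be both called and inspected; a lambda could only be called.
print(init_val())                       # [0.0, 0.0, 0.0]
print(init_val.func is initializer)     # True
print(init_val.keywords.get("dtype"))   # float32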
@@ -1794,8 +1794,13 @@ class RefVariable(VariableV1, core.Tensor):
          # pylint: disable=protected-access
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
+              initial_value = initial_value()
+              if isinstance(initial_value, trackable.CheckpointInitialValue):
+                self._maybe_initialize_trackable()
+                self._update_uid = initial_value.checkpoint_position.restore_uid
+                initial_value = initial_value.wrapped_value
              self._initial_value = ops.convert_to_tensor(
-                  initial_value(), name="initial_value", dtype=dtype)
+                  initial_value, name="initial_value", dtype=dtype)
            if shape is None:
              shape = (
                  self._initial_value.get_shape()
@@ -774,19 +774,19 @@ class TPUEmbedding(tracking.AutoTrackable):

    def create_variables(table):
      """Create all variables."""
-      shape = (table.vocabulary_size, table.dim)
+      variable_shape = (table.vocabulary_size, table.dim)

      def getter(name, shape, dtype, initializer, trainable):
-        # TODO(bfontain): make CheckpointInitialValue a callable rather than
-        # something that inherits from tensor.
-        if not isinstance(initializer, base.CheckpointInitialValue):
-          initial_value = functools.partial(initializer, shape, dtype=dtype)
-        else:
-          initial_value = initializer
+        del shape
+        # _add_variable_with_custom_getter clears the shape sometimes, so we
+        # take the global shape from outside the getter.
+        initial_value = functools.partial(initializer, variable_shape,
+                                          dtype=dtype)
        return tf_variables.Variable(
            name=name,
            initial_value=initial_value,
+            shape=variable_shape,
            dtype=dtype,
            trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
@@ -796,7 +796,7 @@ class TPUEmbedding(tracking.AutoTrackable):
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
-          shape=shape,
+          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)
@@ -1490,9 +1490,6 @@ def extract_variable_info(kwargs):
    return (kwargs["name"], shape,
            kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]),
            kwargs["initial_value"].func)
-  elif isinstance(kwargs["initial_value"], base.CheckpointInitialValue):
-    return (kwargs["name"], kwargs["initial_value"].shape,
-            kwargs["initial_value"].dtype, kwargs["initial_value"])
  elif "shape" not in kwargs or kwargs["shape"] is None:
    raise ValueError(
        "Unable to extract initializer function and shape from {}. Please "
@@ -1529,26 +1526,21 @@ def make_sharded_variable_creator(hosts):
    partitions = ([rows // num_hosts + 1] * missing + [rows // num_hosts] *
                  (num_hosts - missing))
    variables = []
-    newkwargs = kwargs
-    newkwargs["dtype"] = dtype
+    kwargs["dtype"] = dtype
    # TODO(bfontain): Remove this check once we can pass position and shape of
-    # shards to CheckpointInitialValue.
-    if isinstance(initial_value, base.CheckpointInitialValue) and num_hosts > 1:
+    # shards to CheckpointInitialValueCallable.
+    if isinstance(initial_value,
+                  base.CheckpointInitialValueCallable) and num_hosts > 1:
      raise RuntimeError("Delayed restoration of variables not available when "
                         "there are multiple TPU hosts, please ensure that the "
                         "api object has been built before you restore.")

    for i, p in enumerate(partitions):
      with ops.device(hosts[i]):
-        newkwargs["shape"] = (p, cols)
-        newkwargs["name"] = "{}_{}".format(name, i)
-        if isinstance(initial_value, base.CheckpointInitialValue):
-          # TODO(bfontain): Patch CheckpointInitialValue to take in account the
-          # position and shape of this shard.
-          newkwargs["initial_value"] = initial_value
-        else:
-          newkwargs["initial_value"] = (
-              lambda: initial_value(newkwargs["shape"], dtype=dtype))
+        kwargs["shape"] = (p, cols)
+        kwargs["name"] = "{}_{}".format(name, i)
+        kwargs["initial_value"] = functools.partial(
+            initial_value, kwargs["shape"], dtype=dtype)
        variables.append(next_creator(*args, **kwargs))
    return TPUShardedVariable(variables, name=name)
  return sharded_variable_creator
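The per-shard row split and the partial-bound initial value can be seen in isolation in this toy sketch (the sizes and the zero initializer are made up; only the partitioning arithmetic and the functools.partial call mirror the code above):

import functools


def initial_value(shape, dtype=None):
  # Stand-in initializer: a nested list of zeros with the requested shape.
  return [[0.0] * shape[1] for _ in range(shape[0])]


rows, cols, num_hosts = 10, 4, 3
missing = rows % num_hosts
partitions = ([rows // num_hosts + 1] * missing +
              [rows // num_hosts] * (num_hosts - missing))

shards = []
for p in partitions:
  shard_init = functools.partial(initial_value, (p, cols), dtype="float32")
  # The real creator calls this later, on the device chosen for the shard.
  shards.append(shard_init())

print(partitions)                # [4, 3, 3]
print([len(s) for s in shards])  # [4, 3, 3]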
@@ -824,7 +824,7 @@ class Optimizer(
    with distribution_strategy.extended.colocate_vars_with(colocate_with):
      if eager:
        restored_initial_value = self._preload_simple_restoration(
-            name=name, shape=None)
+            name=name)
        if restored_initial_value is not None:
          initial_value = restored_initial_value
      v = variable_scope.variable(
@@ -1213,11 +1213,15 @@ class Optimizer(
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
-      initializer = trackable.CheckpointInitialValue(
+      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
-      slot_variable = self._get_or_make_slot(
+      # CheckpointInitialValueCallable will ignore the shape and dtype
+      # parameters but they must be passed.
+      slot_variable = self._get_or_make_slot_with_initializer(
          var=variable,
-          val=initializer,
+          initializer=initializer,
+          shape=variable.shape,
+          dtype=variable.dtype,
          slot_name=slot_name,
          op_name=self._name)
      # Slot variables are not owned by any one object (because we don't want to
@@ -54,6 +54,31 @@ TrackableReference = collections.namedtuple(
    ])


+class CheckpointInitialValueCallable(object):
+  """A callable object that returns a CheckpointInitialValue.
+
+  See CheckpointInitialValue for more information.
+  """
+
+  def __init__(self, checkpoint_position):
+    self._checkpoint_position = checkpoint_position
+
+  @property
+  def checkpoint_position(self):
+    return self._checkpoint_position
+
+  def __call__(self, shape=None, dtype=None):
+    # Note that the signature here is for compatibility with normal callable
+    # initializers which take shape and dtype. Although dtype isn't used, it
+    # will get passed in by a functool.partial_wrapper in places like
+    # base_layer_utils.py's make_variable.
+    return CheckpointInitialValue(self._checkpoint_position, shape)
+
+  @property
+  def restore_uid(self):
+    return self._checkpoint_position.restore_uid
+
+
 class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

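How the new wrapper is meant to be driven further down the stack, as a hedged sketch (the checkpoint position below is a stand-in, not the real CheckpointPosition; only the call pattern mirrors the class above): it is invoked like an ordinary initializer, so shape and dtype can be bound with functools.partial, the shape is forwarded to the restored value and dtype is accepted but ignored.

import functools


class FakeCheckpointPosition(object):
  """Stand-in for a real checkpoint position (illustrative only)."""
  restore_uid = 7

  def read(self, shape):
    return "restored tensor with shape %r" % (shape,)


class DelayedCheckpointInitializer(object):
  """Mirrors the call pattern of CheckpointInitialValueCallable."""

  def __init__(self, checkpoint_position):
    self._checkpoint_position = checkpoint_position

  def __call__(self, shape=None, dtype=None):
    # dtype is accepted only for signature compatibility with initializers.
    return self._checkpoint_position.read(shape)

  @property
  def restore_uid(self):
    return self._checkpoint_position.restore_uid


initializer = DelayedCheckpointInitializer(FakeCheckpointPosition())
initial_value = functools.partial(initializer, (8, 16), dtype="float32")
print(initial_value())          # restored tensor with shape (8, 16)
print(initializer.restore_uid)  # 7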
@@ -312,7 +337,7 @@ class CheckpointPosition(object):
              name="%s_checkpoint_read" % (serialized_tensor.name,))
        # Copy the value to the current device if necessary.
        value_tensors[serialized_tensor.name] = array_ops.identity(value)
-      return value_tensors
+    return value_tensors

  def gather_ops_or_named_saveables(self):
    """Looks up or creates SaveableObjects which don't have cached ops."""
@@ -735,11 +760,11 @@ class Trackable(object):
      # then assigning (when executing eagerly). This call returns None if
      # there is nothing to restore.
      checkpoint_initializer = self._preload_simple_restoration(
-          name=name, shape=shape)
+          name=name)
    else:
      checkpoint_initializer = None
    if (checkpoint_initializer is not None and
-        not (isinstance(initializer, CheckpointInitialValue) and
+        not (isinstance(initializer, CheckpointInitialValueCallable) and
             (initializer.restore_uid > checkpoint_initializer.restore_uid))):
      # If multiple Trackable objects are "creating" the same variable
      # via the magic of custom getters, the one with the highest restore UID
@@ -767,7 +792,7 @@ class Trackable(object):
    # fallback once all get_variable() return types are Trackable.
    return new_variable

-  def _preload_simple_restoration(self, name, shape):
+  def _preload_simple_restoration(self, name):
    """Return a dependency's value for restore-on-create.

    Note the restoration is not deleted; if for some reason preload is called
@@ -778,7 +803,6 @@ class Trackable(object):
    Args:
      name: The object-local name of the dependency holding the variable's
        value.
-      shape: The shape of the variable being loaded into.

    Returns:
      An callable for use as a variable's initializer/initial_value, or None if
|
||||
checkpoint_position = max(
|
||||
deferred_dependencies_list,
|
||||
key=lambda restore: restore.checkpoint.restore_uid)
|
||||
return CheckpointInitialValue(
|
||||
checkpoint_position=checkpoint_position, shape=shape)
|
||||
return CheckpointInitialValueCallable(
|
||||
checkpoint_position=checkpoint_position)
|
||||
|
||||
def _track_trackable(self, trackable, name, overwrite=False):
|
||||
"""Declare a dependency on another `Trackable` object.
|
||||
|
@@ -19,6 +19,7 @@ from __future__ import print_function

 import abc
 import collections
+import functools
 import os
 import weakref

@@ -57,6 +58,7 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export


@@ -427,10 +429,16 @@ def _default_getter(name,
  # Instantiate initializer if provided initializer is a type object.
  if isinstance(initializer, type(init_ops.Initializer)):
    initializer = initializer(dtype=dtype)
-
-  def initial_value():
-    return initializer(
-        shape_object.as_list(), dtype=dtype, partition_info=partition_info)
+  shape_list = None if shape is None else shape_object.as_list()
+  if "partition_info" in tf_inspect.getargspec(initializer).args:
+    initial_value = functools.partial(initializer,
+                                      shape_list,
+                                      dtype=dtype,
+                                      partition_info=partition_info)
+  else:
+    initial_value = functools.partial(initializer,
+                                      shape_list,
+                                      dtype=dtype)

  return variables.VariableV1(
      initial_value=initial_value,