diff --git a/tensorflow/python/distribute/checkpointing_test.py b/tensorflow/python/distribute/checkpointing_test.py
index a4be193284e..8b4af726643 100644
--- a/tensorflow/python/distribute/checkpointing_test.py
+++ b/tensorflow/python/distribute/checkpointing_test.py
@@ -51,7 +51,7 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase):
       restore_checkpoint = trackable_utils.Checkpoint()
       restore_checkpoint.restore(save_path)
       initial_value = restore_checkpoint._preload_simple_restoration(
-          "v", variable_shape)
+          "v")
       v = variables_lib.Variable(initial_value)
       # Check that the variable is now tagged as restored. `Checkpoint` then
       # knows it doesn't have to restore `v`'s value when it's assigned to an
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
index 49b6a93678c..5c57805ada4 100644
--- a/tensorflow/python/distribute/collective_all_reduce_strategy.py
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -45,6 +45,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.tracking import base
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -436,6 +437,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
       initial_value = kwargs["initial_value"]
       if callable(initial_value):
         initial_value = initial_value()
+      if isinstance(initial_value, base.CheckpointInitialValue):
+        initial_value = initial_value.wrapped_value
       assert not callable(initial_value)
       initial_value = ops.convert_to_tensor(
           initial_value, dtype=kwargs.get("dtype", None))
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index bd24ec0145d..e52534ef8ee 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -2130,6 +2130,10 @@ class StrategyExtendedV2(object):
         checkpoint_restore_uid = kwargs[
             "initial_value"].checkpoint_position.restore_uid
         kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
+      elif isinstance(kwargs["initial_value"],
+                      trackable.CheckpointInitialValueCallable):
+        checkpoint_restore_uid = kwargs[
+            "initial_value"].checkpoint_position.restore_uid
       else:
         checkpoint_restore_uid = None
 
@@ -2139,6 +2143,9 @@ class StrategyExtendedV2(object):
         # pylint: disable=protected-access
         # Let the checkpointing infrastructure know that the variable was
         # already restored so it doesn't waste memory loading the value again.
+        # In the case of CheckpointInitialValueCallable this may already be
+        # done by the final variable creator, but it doesn't hurt to do it
+        # again.
         created._maybe_initialize_trackable()
         created._update_uid = checkpoint_restore_uid
         # pylint: enable=protected-access
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 3199747de53..9d9101bc27b 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -217,17 +217,18 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
     if constraint is not None and not callable(constraint):
       raise ValueError("The `constraint` argument must be a callable.")
 
-    if isinstance(initial_value, trackable.CheckpointInitialValue):
-      self._maybe_initialize_trackable()
-      self._update_uid = initial_value.checkpoint_position.restore_uid
-      initial_value = initial_value.wrapped_value
-
     with ops.name_scope(name, "Variable",
                         [] if init_from_fn else [initial_value]) as scope_name:
       with ops.name_scope("Initializer"):
-        initial_value = ops.convert_to_tensor(
-            initial_value() if init_from_fn else initial_value,
-            name="initial_value", dtype=dtype)
+        if init_from_fn:
+          initial_value = initial_value()
+        if isinstance(initial_value, trackable.CheckpointInitialValue):
+          self._maybe_initialize_trackable()
+          self._update_uid = initial_value.checkpoint_position.restore_uid
+          initial_value = initial_value.wrapped_value
+
+        initial_value = ops.convert_to_tensor(initial_value,
+                                              name="initial_value", dtype=dtype)
       assert initial_value is not None
 
       # Don't use `shape or initial_value.shape` since TensorShape has
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
index 69b39e3f989..c309fafa4a4 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_optimizer.py
@@ -142,8 +142,8 @@ class _DelegatingTrackableMixin(object):
     return self._trackable._add_variable_with_custom_getter(
         name, shape, dtype, initializer, getter, overwrite,
        **kwargs_for_getter)
 
-  def _preload_simple_restoration(self, name, shape):
-    return self._trackable._preload_simple_restoration(name, shape)
+  def _preload_simple_restoration(self, name):
+    return self._trackable._preload_simple_restoration(name)
 
   def _track_trackable(self, trackable, name, overwrite=False):  # pylint: disable=redefined-outer-name
     return self._trackable._track_trackable(trackable, name, overwrite)
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index 1227e661101..918b0cb8692 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -1292,7 +1292,7 @@ class OptimizerV2(trackable.Trackable):
         # (aside from double initialization), and makes variable creator scopes
         # behave the same way they do when graph building.
         and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
-      initializer = trackable.CheckpointInitialValue(
+      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
       slot_variable = self.add_slot(
           var=variable,
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 5d4eeba2994..b12a2023d54 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1686,11 +1686,6 @@ class ResourceVariable(BaseResourceVariable):
     if constraint is not None and not callable(constraint):
       raise ValueError("The `constraint` argument must be a callable.")
 
-    if isinstance(initial_value, trackable.CheckpointInitialValue):
-      self._maybe_initialize_trackable()
-      self._update_uid = initial_value.checkpoint_position.restore_uid
-      initial_value = initial_value.wrapped_value
-
     if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
       collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
     with ops.init_scope():
@@ -1719,10 +1714,15 @@ class ResourceVariable(BaseResourceVariable):
                         s=[compat.as_bytes("loc:@%s" % handle_name)]))
             with ops.get_default_graph()._attr_scope({"_class": attr}):
               with ops.name_scope("Initializer"), device_context_manager(None):
-                initial_value = ops.convert_to_tensor(
-                    initial_value() if init_from_fn else initial_value,
-                    name="initial_value",
-                    dtype=dtype)
+                if init_from_fn:
+                  initial_value = initial_value()
+                if isinstance(initial_value, trackable.CheckpointInitialValue):
+                  self._maybe_initialize_trackable()
+                  self._update_uid = initial_value.checkpoint_position.restore_uid
+                  initial_value = initial_value.wrapped_value
+                initial_value = ops.convert_to_tensor(initial_value,
+                                                      name="initial_value",
+                                                      dtype=dtype)
             if shape is not None:
               if not initial_value.shape.is_compatible_with(shape):
                 raise ValueError(
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 6e0e83f8564..c1804f770b1 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -270,6 +270,24 @@ def disable_resource_variables():
   _api_usage_gauge.get_cell().set(False)
 
 
+def _needs_no_arguments(python_callable):
+  """Returns true if the callable needs no arguments to call."""
+  # TODO(bfontain): Switch to inspect.signature when we are python 3 only.
+  # signature = inspect.signature(python_callable)
+  # return not [1 for param in signature.parameters.values()
+  #             if param.default == param.empty]
+  num_arguments = len(tf_inspect.getargspec(python_callable).args)
+  if not tf_inspect.isfunction(python_callable) and not isinstance(
+      python_callable, functools.partial):
+    # getargspec includes self for function objects (which aren't
+    # functools.partial). This has no default so we need to remove it.
+    # It is not even an argument, so it's odd that getargspec returns this.
+    # Note that this is fixed with inspect.signature in Python 3.
+    num_arguments -= 1
+  return num_arguments == len(
+      tf_inspect.getargspec(python_callable).defaults or [])
+
+
 class _VariableStore(object):
   """Variable store that carries a number of named Variables.
 
@@ -905,18 +923,17 @@ class _VariableStore(object):
       # Instantiate initializer if provided initializer is a type object.
       if tf_inspect.isclass(initializer):
         initializer = initializer()
-      if shape is not None and shape.is_fully_defined():
+      if shape.is_fully_defined():
         if "partition_info" in tf_inspect.getargspec(initializer).args:
-          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
-              shape.as_list(),
-              dtype=dtype,
-              partition_info=partition_info)
+          init_val = functools.partial(initializer,
+                                       shape.as_list(),
+                                       dtype=dtype,
+                                       partition_info=partition_info)
         else:
-          init_val = lambda: initializer(  # pylint: disable=g-long-lambda
-              shape.as_list(), dtype=dtype)
+          init_val = functools.partial(initializer,
+                                       shape.as_list(), dtype=dtype)
         variable_dtype = dtype.base_dtype
-      elif len(tf_inspect.getargspec(initializer).args) == len(
-          tf_inspect.getargspec(initializer).defaults or []):
+      elif _needs_no_arguments(initializer):
         init_val = initializer
         variable_dtype = None
       else:
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 4d6ca923af3..4e79ec97ff9 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1794,8 +1794,13 @@ class RefVariable(VariableV1, core.Tensor):
           # pylint: disable=protected-access
           with ops.get_default_graph()._attr_scope({"_class": attr}):
             with ops.name_scope("Initializer"), ops.device(None):
+              initial_value = initial_value()
+              if isinstance(initial_value, trackable.CheckpointInitialValue):
+                self._maybe_initialize_trackable()
+                self._update_uid = initial_value.checkpoint_position.restore_uid
+                initial_value = initial_value.wrapped_value
              self._initial_value = ops.convert_to_tensor(
-                  initial_value(), name="initial_value", dtype=dtype)
+                  initial_value, name="initial_value", dtype=dtype)
            if shape is None:
              shape = (
                  self._initial_value.get_shape()
diff --git a/tensorflow/python/tpu/tpu_embedding_v2.py b/tensorflow/python/tpu/tpu_embedding_v2.py
index 74f04bdd945..61bc443c637 100644
--- a/tensorflow/python/tpu/tpu_embedding_v2.py
+++ b/tensorflow/python/tpu/tpu_embedding_v2.py
@@ -774,19 +774,19 @@ class TPUEmbedding(tracking.AutoTrackable):
     def create_variables(table):
       """Create all variables."""
-      shape = (table.vocabulary_size, table.dim)
+      variable_shape = (table.vocabulary_size, table.dim)
 
       def getter(name, shape, dtype, initializer, trainable):
-        # TODO(bfontain): make CheckpointInitialValue a callable rather than
-        # something that inherits from tensor.
-        if not isinstance(initializer, base.CheckpointInitialValue):
-          initial_value = functools.partial(initializer, shape, dtype=dtype)
-        else:
-          initial_value = initializer
-
+        del shape
+        # _add_variable_with_custom_getter clears the shape sometimes, so we
+        # take the global shape from outside the getter.
+        initial_value = functools.partial(initializer, variable_shape,
+                                          dtype=dtype)
         return tf_variables.Variable(
             name=name,
             initial_value=initial_value,
+            shape=variable_shape,
+            dtype=dtype,
             trainable=trainable)
 
     def variable_creator(name, initializer, trainable=True):
@@ -796,7 +796,7 @@ class TPUEmbedding(tracking.AutoTrackable):
       return self._add_variable_with_custom_getter(
           name=name,
          initializer=initializer,
-          shape=shape,
+          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)
@@ -1490,9 +1490,6 @@ def extract_variable_info(kwargs):
     return (kwargs["name"], shape,
             kwargs["initial_value"].keywords.get("dtype", kwargs["dtype"]),
             kwargs["initial_value"].func)
-  elif isinstance(kwargs["initial_value"], base.CheckpointInitialValue):
-    return (kwargs["name"], kwargs["initial_value"].shape,
-            kwargs["initial_value"].dtype, kwargs["initial_value"])
   elif "shape" not in kwargs or kwargs["shape"] is None:
     raise ValueError(
         "Unable to extract initializer function and shape from {}. Please "
@@ -1529,26 +1526,21 @@ def make_sharded_variable_creator(hosts):
     partitions = ([rows // num_hosts + 1] * missing +
                   [rows // num_hosts] * (num_hosts - missing))
     variables = []
-    newkwargs = kwargs
-    newkwargs["dtype"] = dtype
+    kwargs["dtype"] = dtype
     # TODO(bfontain): Remove this check once we can pass position and shape of
-    # shards to CheckpointInitialValue.
-    if isinstance(initial_value, base.CheckpointInitialValue) and num_hosts > 1:
+    # shards to CheckpointInitialValueCallable.
+    if isinstance(initial_value,
+                  base.CheckpointInitialValueCallable) and num_hosts > 1:
       raise RuntimeError("Delayed restoration of variables not available when "
                          "there are multiple TPU hosts, please ensure that the "
                          "api object has been built before you restore.")
     for i, p in enumerate(partitions):
       with ops.device(hosts[i]):
-        newkwargs["shape"] = (p, cols)
-        newkwargs["name"] = "{}_{}".format(name, i)
-        if isinstance(initial_value, base.CheckpointInitialValue):
-          # TODO(bfontain): Patch CheckpointInitialValue to take in account the
-          # position and shape of this shard.
-          newkwargs["initial_value"] = initial_value
-        else:
-          newkwargs["initial_value"] = (
-              lambda: initial_value(newkwargs["shape"], dtype=dtype))
+        kwargs["shape"] = (p, cols)
+        kwargs["name"] = "{}_{}".format(name, i)
+        kwargs["initial_value"] = functools.partial(
+            initial_value, kwargs["shape"], dtype=dtype)
         variables.append(next_creator(*args, **kwargs))
     return TPUShardedVariable(variables, name=name)
   return sharded_variable_creator
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1fe8a8c729b..0b98efe44a4 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -824,7 +824,7 @@ class Optimizer(
     with distribution_strategy.extended.colocate_vars_with(colocate_with):
       if eager:
         restored_initial_value = self._preload_simple_restoration(
-            name=name, shape=None)
+            name=name)
        if restored_initial_value is not None:
          initial_value = restored_initial_value
      v = variable_scope.variable(
@@ -1213,11 +1213,15 @@ class Optimizer(
        # (aside from double initialization), and makes variable creator scopes
        # behave the same way they do when graph building.
        and not ops.get_default_graph()._variable_creator_stack):  # pylint: disable=protected-access
-      initializer = trackable.CheckpointInitialValue(
+      initializer = trackable.CheckpointInitialValueCallable(
          checkpoint_position=slot_variable_position)
-      slot_variable = self._get_or_make_slot(
+      # CheckpointInitialValueCallable will ignore the shape and dtype
+      # parameters but they must be passed.
+      slot_variable = self._get_or_make_slot_with_initializer(
          var=variable,
-          val=initializer,
+          initializer=initializer,
+          shape=variable.shape,
+          dtype=variable.dtype,
          slot_name=slot_name,
          op_name=self._name)
      # Slot variables are not owned by any one object (because we don't want to
diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py
index a8b0410dc77..3dff0097bbc 100644
--- a/tensorflow/python/training/tracking/base.py
+++ b/tensorflow/python/training/tracking/base.py
@@ -54,6 +54,31 @@ TrackableReference = collections.namedtuple(
     ])
 
 
+class CheckpointInitialValueCallable(object):
+  """A callable object that returns a CheckpointInitialValue.
+
+  See CheckpointInitialValue for more information.
+  """
+
+  def __init__(self, checkpoint_position):
+    self._checkpoint_position = checkpoint_position
+
+  @property
+  def checkpoint_position(self):
+    return self._checkpoint_position
+
+  def __call__(self, shape=None, dtype=None):
+    # Note that the signature here is for compatibility with normal callable
+    # initializers which take shape and dtype. Although dtype isn't used, it
+    # will get passed in by a functools.partial wrapper in places like
+    # base_layer_utils.py's make_variable.
+    return CheckpointInitialValue(self._checkpoint_position, shape)
+
+  @property
+  def restore_uid(self):
+    return self._checkpoint_position.restore_uid
+
+
 class CheckpointInitialValue(ops.Tensor):
   """Tensor wrapper for managing update UIDs in `Variables`.
 
@@ -312,7 +337,7 @@ class CheckpointPosition(object):
               name="%s_checkpoint_read" % (serialized_tensor.name,))
         # Copy the value to the current device if necessary.
         value_tensors[serialized_tensor.name] = array_ops.identity(value)
-        return value_tensors
+    return value_tensors
 
   def gather_ops_or_named_saveables(self):
     """Looks up or creates SaveableObjects which don't have cached ops."""
@@ -735,11 +760,11 @@ class Trackable(object):
       # then assigning (when executing eagerly). This call returns None if
      # there is nothing to restore.
      checkpoint_initializer = self._preload_simple_restoration(
-          name=name, shape=shape)
+          name=name)
    else:
      checkpoint_initializer = None
    if (checkpoint_initializer is not None and
-        not (isinstance(initializer, CheckpointInitialValue) and
+        not (isinstance(initializer, CheckpointInitialValueCallable) and
             (initializer.restore_uid > checkpoint_initializer.restore_uid))):
      # If multiple Trackable objects are "creating" the same variable
      # via the magic of custom getters, the one with the highest restore UID
@@ -767,7 +792,7 @@ class Trackable(object):
       # fallback once all get_variable() return types are Trackable.
       return new_variable
 
-  def _preload_simple_restoration(self, name, shape):
+  def _preload_simple_restoration(self, name):
     """Return a dependency's value for restore-on-create.
 
     Note the restoration is not deleted; if for some reason preload is called
@@ -778,7 +803,6 @@ class Trackable(object):
     Args:
       name: The object-local name of the dependency holding the variable's
        value.
-      shape: The shape of the variable being loaded into.
 
     Returns:
       A callable for use as a variable's initializer/initial_value, or None if
@@ -801,8 +825,8 @@ class Trackable(object):
     checkpoint_position = max(
         deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
-    return CheckpointInitialValue(
-        checkpoint_position=checkpoint_position, shape=shape)
+    return CheckpointInitialValueCallable(
+        checkpoint_position=checkpoint_position)
 
   def _track_trackable(self, trackable, name, overwrite=False):
     """Declare a dependency on another `Trackable` object.
diff --git a/tensorflow/python/training/tracking/util.py b/tensorflow/python/training/tracking/util.py
index 95c8f9d2b60..d6fdfbc04ee 100644
--- a/tensorflow/python/training/tracking/util.py
+++ b/tensorflow/python/training/tracking/util.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 
 import abc
 import collections
+import functools
 import os
 import weakref
 
@@ -57,6 +58,7 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import object_identity
 from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -427,10 +429,16 @@ def _default_getter(name,
   # Instantiate initializer if provided initializer is a type object.
   if isinstance(initializer, type(init_ops.Initializer)):
     initializer = initializer(dtype=dtype)
-
-  def initial_value():
-    return initializer(
-        shape_object.as_list(), dtype=dtype, partition_info=partition_info)
+  shape_list = None if shape is None else shape_object.as_list()
+  if "partition_info" in tf_inspect.getargspec(initializer).args:
+    initial_value = functools.partial(initializer,
+                                      shape_list,
+                                      dtype=dtype,
+                                      partition_info=partition_info)
+  else:
+    initial_value = functools.partial(initializer,
+                                      shape_list,
+                                      dtype=dtype)
 
   return variables.VariableV1(
       initial_value=initial_value,
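
For reference, a condensed eager-mode sketch of the restore-on-create flow that the updated checkpointing_test.py hunk above exercises. It is illustrative only: the checkpoint path is hypothetical, and `_preload_simple_restoration` is a private API, used here the same way the test uses it.

    import tensorflow as tf

    # Save a checkpoint with a single dependency named "v".
    save_ckpt = tf.train.Checkpoint(v=tf.Variable([1., 2., 3.]))
    save_path = save_ckpt.save("/tmp/ckpt_example")  # hypothetical path

    # Restore into an empty Checkpoint; "v" becomes a deferred dependency.
    restore_ckpt = tf.train.Checkpoint()
    restore_ckpt.restore(save_path)

    # After this change no shape is passed here: the call returns a
    # CheckpointInitialValueCallable, and whichever variable creator runs
    # later supplies shape/dtype (or nothing) when it invokes the callable.
    initial_value = restore_ckpt._preload_simple_restoration("v")

    # The variable is tagged as already restored at creation time, so the
    # checkpointing machinery does not reload its value afterwards.
    v = tf.Variable(initial_value)

Because the shape is no longer baked in at preload time, variable creators that call `initial_value()` themselves (tf.function initializer lifting, distribution strategies, `RefVariable`) can still detect the checkpoint position, record the restore UID, and unwrap the underlying `CheckpointInitialValue`, which is what the def_function.py, resource_variable_ops.py, and variables.py hunks above do.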