diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index c43d28b0226..e35f95a0331 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -40,7 +40,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import collective_ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util.tf_export import tf_export @@ -336,11 +335,6 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): if self._num_workers > 1: if self._is_chief: - # Unwrap `initial_value` if it is a `CheckpointInitialValue`. - # TODO(b/138130844): Revert the following check once - # `CheckpointInitialValue` class is removed. - if isinstance(initial_value, trackable.CheckpointInitialValue): - initial_value = initial_value.wrapped_value bcast_send = collective_ops.broadcast_send( initial_value, initial_value.shape, initial_value.dtype, group_size, group_key, collective_instance_key) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index ff53634005e..fd9bed1c592 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -123,6 +123,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import loss_reduction from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging +from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -1300,9 +1301,18 @@ class StrategyExtendedV2(object): def _scope(self, strategy): """Implementation of tf.distribute.Strategy.scope().""" def creator_with_resource_vars(*args, **kwargs): + """Variable creator to use in `_CurrentDistributionContext`.""" _require_strategy_scope_extended(self) kwargs["use_resource"] = True kwargs["distribute_strategy"] = strategy + + # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid + # dereferencing a `Tensor` that is without a `name`. + # TODO(b/138130844): Revisit the following check once + # `CheckpointInitialValue` class is removed. + if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): + kwargs["initial_value"] = kwargs["initial_value"].wrapped_value + return self._create_variable(*args, **kwargs) def distributed_getter(getter, *args, **kwargs):