Distribution-strategy-agnostic fix: unwrap `initial_value` when it is a
`CheckpointInitialValue` at the time a variable is created, to avoid
dereferencing a `Tensor` that has no `name` attribute.

PiperOrigin-RevId: 261231177
This commit is contained in:
parent
41884bf3d3
commit
0b1e5f1854
@ -40,7 +40,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import collective_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -336,11 +335,6 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
|
||||
|
||||
if self._num_workers > 1:
|
||||
if self._is_chief:
|
||||
# Unwrap `initial_value` if it is a `CheckpointInitialValue`.
|
||||
# TODO(b/138130844): Revert the following check once
|
||||
# `CheckpointInitialValue` class is removed.
|
||||
if isinstance(initial_value, trackable.CheckpointInitialValue):
|
||||
initial_value = initial_value.wrapped_value
|
||||
bcast_send = collective_ops.broadcast_send(
|
||||
initial_value, initial_value.shape, initial_value.dtype,
|
||||
group_size, group_key, collective_instance_key)
|
||||
|
@ -123,6 +123,7 @@ from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.losses import loss_reduction
|
||||
from tensorflow.python.ops.losses import losses_impl
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -1300,9 +1301,18 @@ class StrategyExtendedV2(object):
|
||||
def _scope(self, strategy):
|
||||
"""Implementation of tf.distribute.Strategy.scope()."""
|
||||
def creator_with_resource_vars(*args, **kwargs):
|
||||
"""Variable creator to use in `_CurrentDistributionContext`."""
|
||||
_require_strategy_scope_extended(self)
|
||||
kwargs["use_resource"] = True
|
||||
kwargs["distribute_strategy"] = strategy
|
||||
|
||||
# Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
|
||||
# dereferencing a `Tensor` that is without a `name`.
|
||||
# TODO(b/138130844): Revisit the following check once
|
||||
# `CheckpointInitialValue` class is removed.
|
||||
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
|
||||
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
|
||||
|
||||
return self._create_variable(*args, **kwargs)
|
||||
|
||||
def distributed_getter(getter, *args, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user