Distribution-strategy-agnostic fix: unwrap `initial_value` when it is a `CheckpointInitialValue` at variable-creation time, to avoid dereferencing a Tensor that has no `name` attribute.

PiperOrigin-RevId: 261231177
This commit is contained in:
Rick Chao 2019-08-01 17:23:17 -07:00 committed by TensorFlower Gardener
parent 41884bf3d3
commit 0b1e5f1854
2 changed files with 10 additions and 6 deletions

View File

@ -40,7 +40,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import tf_export
@ -336,11 +335,6 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
if self._num_workers > 1:
if self._is_chief:
# Unwrap `initial_value` if it is a `CheckpointInitialValue`.
# TODO(b/138130844): Revert the following check once
# `CheckpointInitialValue` class is removed.
if isinstance(initial_value, trackable.CheckpointInitialValue):
initial_value = initial_value.wrapped_value
bcast_send = collective_ops.broadcast_send(
initial_value, initial_value.shape, initial_value.dtype,
group_size, group_key, collective_instance_key)

View File

@ -123,6 +123,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import loss_reduction
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@ -1300,9 +1301,18 @@ class StrategyExtendedV2(object):
def _scope(self, strategy):
"""Implementation of tf.distribute.Strategy.scope()."""
def creator_with_resource_vars(*args, **kwargs):
"""Variable creator to use in `_CurrentDistributionContext`.

Sets `use_resource=True`, stores the enclosing `strategy` under the
`distribute_strategy` keyword, unwraps `initial_value` when it is a
`CheckpointInitialValue`, and then delegates to `self._create_variable`.

Args:
*args: Positional arguments forwarded unchanged to
`self._create_variable`.
**kwargs: Keyword arguments for variable creation. `initial_value` is
read with a direct key lookup, so it is expected to be present.

Returns:
Whatever `self._create_variable` returns.
"""
# Called before any kwargs are modified. NOTE(review): presumably rejects
# calls made outside this strategy's scope -- confirm against the helper.
_require_strategy_scope_extended(self)
# These two assignments overwrite any value the caller supplied.
kwargs["use_resource"] = True
kwargs["distribute_strategy"] = strategy
# Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid
# dereferencing a `Tensor` that is without a `name`.
# TODO(b/138130844): Revisit the following check once
# `CheckpointInitialValue` class is removed.
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
return self._create_variable(*args, **kwargs)
def distributed_getter(getter, *args, **kwargs):