Install _distributed_container only at variable creation

In a following change, assign*() and scatter*() will return a new
DistributedVariable. Installing _distributed_container multiple times would be
messy, so install it only once, at variable creation.
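
The new back-reference is a weakref.ref rather than a direct reference, so
call sites must dereference it with a call. A minimal sketch of the pattern;
Container and Value are illustrative stand-ins, not the real TensorFlow
classes:

import weakref

class Value(object):
  pass

class Container(object):
  def __init__(self, values):
    self._values = values  # strong refs: the container keeps its values alive
    for v in values:
      # Weak back-reference: a value does not keep its container alive,
      # so no container <-> value reference cycle is formed.
      v._distributed_container = weakref.ref(self)

values = [Value(), Value()]
container = Container(values)
# Dereference by calling the weakref; it yields the container, or None
# once the container has been garbage-collected.
assert values[0]._distributed_container() is container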

PiperOrigin-RevId: 331945344
Change-Id: Ida1b15cabed2591e9181eee6e2afe42899aa7049
Ruoxin Sang 2020-09-16 00:12:39 -07:00 committed by TensorFlower Gardener
parent cc0621d70f
commit fc3c78312f
4 changed files with 10 additions and 12 deletions

tensorflow/python/distribute/distribute_utils.py

@@ -115,10 +115,10 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
     # pylint: disable=protected-access
     assert not isinstance(v0, values_lib.MirroredVariable), (
         "ids = %s, values = %s" % ([id(v) for v in values], values))
-    distributed_container = v0._distributed_container
+    distributed_container = v0._distributed_container()
     assert distributed_container is not None
     for v in values[1:]:
-      assert distributed_container is v._distributed_container
+      assert distributed_container is v._distributed_container()
     return distributed_container
     # pylint: enable=protected-access
@@ -209,7 +209,7 @@ def value_container(val):
       # DistributedVariable has _distributed_container defined
       # but we don't want to return it.
       not isinstance(val, values_lib.DistributedVariable)):
-    container = val._distributed_container  # pylint: disable=protected-access
+    container = val._distributed_container()  # pylint: disable=protected-access
     if container is not None:
       return container
   return val
@@ -318,13 +318,6 @@ def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
   else:
     var_cls = class_mapping.get(synchronization)
   result = var_cls(strategy, value_list, aggregation)
-  # Install the created DistributedVariable as _distributed_container property
-  # of the underlying variables, to make it easy to map back to the container.
-  for v in result.values:
-    # Hold a strong reference to keep the container from being GC-ed. After
-    # v = v.assign(), the user code may no longer hold references to the
-    # original container, since v.assign() returns a new DistributedVariable.
-    v._distributed_container = result  # pylint: disable=protected-access
 
   # Add the wrapped variable to the requested collections.
   # The handling of eager mode and the global step matches
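
The deleted comment documents what the strong reference guarded against: once
assign*() returns a new DistributedVariable, user code may drop its last
strong reference to the original container, and a weak back-reference then
dereferences to None. A hypothetical sketch of that failure mode, again with
stand-in classes rather than TensorFlow's:

import weakref

class Value(object):
  pass

class Container(object):
  def __init__(self, values):
    self._values = values
    for v in values:
      v._distributed_container = weakref.ref(self)

v = Value()
Container([v])  # no strong reference to the container is kept anywhere
# Under CPython refcounting the container is reclaimed immediately, so the
# weak back-reference now dereferences to None; this is why the call sites
# above assert that the dereferenced container is not None.
assert v._distributed_container() is None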

tensorflow/python/distribute/values.py

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
+import weakref
 
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_lib
@@ -446,6 +447,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
     self._aggregation = aggregation
     super(DistributedVariable, self).__init__(values)
     self._common_name = self._primary.name.split(":")[0]
+    # Use a weakref to make it easy to map from the contained values
+    # to the container without introducing a reference cycle.
+    for v in values:
+      v._distributed_container = weakref.ref(self)  # pylint: disable=protected-access
 
     # Packed variable is used to reduce the overhead of function execution.
     # For a DistributedVariable, only one variable handle is captured into a
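
The new comment gives the reason for weakref.ref: a strong back-reference
would complete a reference cycle with the container's strong references to
its values. A small sketch contrasting the two (Node and make_pair are
illustrative, not TensorFlow code):

import gc
import weakref

gc.disable()  # make collection timing deterministic for the demo

class Node(object):
  pass

def make_pair(weak):
  container, value = Node(), Node()
  container.values = [value]
  value.owner = weakref.ref(container) if weak else container
  return weakref.ref(container)  # probe used only to observe collection

probe = make_pair(weak=False)
# A strong back-reference forms a cycle; plain refcounting cannot free it.
assert probe() is not None
gc.collect()
assert probe() is None  # only the cycle collector reclaimed the pair

probe = make_pair(weak=True)
# A weak back-reference forms no cycle; refcounting frees the pair at once.
assert probe() is None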

tensorflow/python/keras/optimizer_v2/optimizer_v2.py

@@ -1345,7 +1345,7 @@ def _var_key(var):
   # pylint: disable=protected-access
   # Get the distributed variable if it exists.
   if hasattr(var, "_distributed_container"):
-    var = var._distributed_container
+    var = var._distributed_container()
   if var._in_graph_mode:
     return var._shared_name
   return var._unique_id
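
Here _var_key dereferences back to the container so that every component of
one distributed variable shares a single slot key. A sketch of that keying
idea under the same assumptions, with hypothetical Var and var_key stand-ins:

import weakref

class Var(object):
  def __init__(self, unique_id):
    self._unique_id = unique_id

def var_key(var):
  # Same shape as the diff above: if a weak back-reference is installed,
  # key by the dereferenced container instead of the component variable.
  if hasattr(var, "_distributed_container"):
    var = var._distributed_container()
  return var._unique_id

container = Var("container/0")
component = Var("replica_0/component")
component._distributed_container = weakref.ref(container)
# Both the component and the container map to the same slot key.
assert var_key(component) == var_key(container) == "container/0"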

tensorflow/python/training/optimizer.py

@@ -759,7 +759,7 @@ class Optimizer(
     if hasattr(var, "_distributed_container"):
       # NOTE: If this isn't patched, then there is no `handle` in
       # `_resource_apply_dense`.
-      distributed_container = var._distributed_container
+      distributed_container = var._distributed_container()
       assert distributed_container is not None
       if ops.executing_eagerly_outside_functions():
         key = distributed_container._unique_id