Install _distributed_container only at variable creation
With a following change we're going to create new DistributedVariable as return value of assign*() and scatter*(). Installing _distributed_container multiple times will be messy. PiperOrigin-RevId: 331945344 Change-Id: Ida1b15cabed2591e9181eee6e2afe42899aa7049
This commit is contained in:
parent
cc0621d70f
commit
fc3c78312f
@ -115,10 +115,10 @@ def regroup(values, wrap_class=values_lib.PerReplica, always_wrap=False):
|
||||
# pylint: disable=protected-access
|
||||
assert not isinstance(v0, values_lib.MirroredVariable), (
|
||||
"ids = %s, values = %s" % ([id(v) for v in values], values))
|
||||
distributed_container = v0._distributed_container
|
||||
distributed_container = v0._distributed_container()
|
||||
assert distributed_container is not None
|
||||
for v in values[1:]:
|
||||
assert distributed_container is v._distributed_container
|
||||
assert distributed_container is v._distributed_container()
|
||||
return distributed_container
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@ -209,7 +209,7 @@ def value_container(val):
|
||||
# DistributedVariable has _distributed_container defined
|
||||
# but we don't want to return it.
|
||||
not isinstance(val, values_lib.DistributedVariable)):
|
||||
container = val._distributed_container # pylint: disable=protected-access
|
||||
container = val._distributed_container() # pylint: disable=protected-access
|
||||
if container is not None:
|
||||
return container
|
||||
return val
|
||||
@ -318,13 +318,6 @@ def create_mirrored_variable(strategy, real_mirrored_creator, class_mapping,
|
||||
else:
|
||||
var_cls = class_mapping.get(synchronization)
|
||||
result = var_cls(strategy, value_list, aggregation)
|
||||
# Install the created DistributedVariable as _distributed_container property
|
||||
# of the underlying variables, to make it easy to map back to the container.
|
||||
for v in result.values:
|
||||
# Hold a strong reference to avoid the container from being GC-ed. After
|
||||
# v = v.assign(), the user code may no longer holds references to the
|
||||
# original container, since v.assign() returns a new DistributedVariable.
|
||||
v._distributed_container = result # pylint: disable=protected-access
|
||||
|
||||
# Add the wrapped variable to the requested collections.
|
||||
# The handling of eager mode and the global step matches
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import weakref
|
||||
|
||||
from tensorflow.python.distribute import device_util
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
@ -446,6 +447,10 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable,
|
||||
self._aggregation = aggregation
|
||||
super(DistributedVariable, self).__init__(values)
|
||||
self._common_name = self._primary.name.split(":")[0]
|
||||
# Use a weakref to make it easy to map from the contained values
|
||||
# to the container without introducing a reference cycle.
|
||||
for v in values:
|
||||
v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
|
||||
|
||||
# Packed variable is used to reduce the overhead of function execution.
|
||||
# For a DistributedVariable, only one variable handle is captured into a
|
||||
|
@ -1345,7 +1345,7 @@ def _var_key(var):
|
||||
# pylint: disable=protected-access
|
||||
# Get the distributed variable if it exists.
|
||||
if hasattr(var, "_distributed_container"):
|
||||
var = var._distributed_container
|
||||
var = var._distributed_container()
|
||||
if var._in_graph_mode:
|
||||
return var._shared_name
|
||||
return var._unique_id
|
||||
|
@ -759,7 +759,7 @@ class Optimizer(
|
||||
if hasattr(var, "_distributed_container"):
|
||||
# NOTE: If this isn't patched, then there is no `handle` in
|
||||
# `_resource_apply_dense`.
|
||||
distributed_container = var._distributed_container
|
||||
distributed_container = var._distributed_container()
|
||||
assert distributed_container is not None
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
key = distributed_container._unique_id
|
||||
|
Loading…
Reference in New Issue
Block a user