Install _distributed_container only at variable creation
With a following change we're going to create new DistributedVariable as return value of assign*() and scatter*(). Installing _distributed_container multiple times will be messy. PiperOrigin-RevId: 307753201 Change-Id: I3c87abc301ea32b0169034324a108d6967229889
This commit is contained in:
parent
4137c32842
commit
79abfee5c3
@ -433,10 +433,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable):
|
||||
self._aggregation = aggregation
|
||||
super(DistributedVariable, self).__init__(values)
|
||||
self._common_name = self._primary.name.split(":")[0]
|
||||
# Use a weakref to make it easy to map from the contained values
|
||||
# to the container without introducing a reference cycle.
|
||||
for v in values:
|
||||
v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
|
||||
# tf.keras keeps track of variables initialized using this attribute. When
|
||||
# tf.keras gets the default session, it initializes all uninitialized vars.
|
||||
# We need to make _keras_initialized a member of DistributedVariable because
|
||||
@ -774,6 +770,13 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
||||
value_list = real_mirrored_creator(**kwargs)
|
||||
var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
|
||||
result = var_cls(strategy, value_list, aggregation)
|
||||
# Install the created DistributedVariable as _distributed_container property
|
||||
# of the underlying variables, to make it easy to map back to the container.
|
||||
for v in result.values:
|
||||
# Hold a strong reference to avoid the container from being GC-ed. After
|
||||
# v = v.assign(), the user code may no longer holds references to the
|
||||
# original container, since v.assign() returns a new DistributedVariable.
|
||||
v._distributed_container = result # pylint: disable=protected-access
|
||||
|
||||
# Add the wrapped variable to the requested collections.
|
||||
# The handling of eager mode and the global step matches
|
||||
@ -1240,10 +1243,10 @@ def regroup(values, wrap_class=PerReplica, always_wrap=False):
|
||||
# pylint: disable=protected-access
|
||||
assert not isinstance(v0, MirroredVariable), (
|
||||
"ids = %s, values = %s" % ([id(v) for v in values], values))
|
||||
distributed_container = v0._distributed_container()
|
||||
distributed_container = v0._distributed_container
|
||||
assert distributed_container is not None
|
||||
for v in values[1:]:
|
||||
assert distributed_container is v._distributed_container()
|
||||
assert distributed_container is v._distributed_container
|
||||
return distributed_container
|
||||
# pylint: enable=protected-access
|
||||
|
||||
@ -1331,7 +1334,7 @@ def value_container(val):
|
||||
# DistributedVariable has _distributed_container defined
|
||||
# but we don't want to return it.
|
||||
not isinstance(val, DistributedVariable)):
|
||||
container = val._distributed_container() # pylint: disable=protected-access
|
||||
container = val._distributed_container # pylint: disable=protected-access
|
||||
if container is not None:
|
||||
return container
|
||||
return val
|
||||
|
@ -365,7 +365,7 @@ def _make_mirrored():
|
||||
return mirrored
|
||||
|
||||
|
||||
class RegroupAndSelectDeviceTest(test.TestCase):
|
||||
class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _is_per_replica(self, result, expected, klass=values.PerReplica):
|
||||
self.assertIsInstance(result, klass)
|
||||
@ -448,12 +448,20 @@ class RegroupAndSelectDeviceTest(test.TestCase):
|
||||
self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
|
||||
self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
|
||||
|
||||
def testMirroredContainer(self):
|
||||
if context.num_gpus() < 1 and context.executing_eagerly():
|
||||
self.skipTest("A GPU is not available for this test in eager mode.")
|
||||
mirrored = _make_mirrored()
|
||||
result = values.regroup(mirrored.values)
|
||||
self.assertIs(mirrored, result)
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
],
|
||||
mode=["graph", "eager"],
|
||||
))
|
||||
def testMirroredContainer(self, distribution):
|
||||
with distribution.scope():
|
||||
v = variable_scope.variable(
|
||||
1., aggregation=variable_scope.VariableAggregation.SUM)
|
||||
self.assertTrue(values.is_distributed_variable(v))
|
||||
self.assertTrue(values.is_distributed_variable(values.regroup(v.values)))
|
||||
|
||||
def testSameId(self):
|
||||
foo = object()
|
||||
@ -479,18 +487,7 @@ class RegroupAndSelectDeviceTest(test.TestCase):
|
||||
result = values.regroup((_nested_value("1"),))
|
||||
# On one device regroup() and select_replica() are basically identity.
|
||||
self.assertEqual(_nested_value("1"), result)
|
||||
self.assertEqual(_nested_value("1"),
|
||||
values.select_replica(0, result))
|
||||
|
||||
# The one exception has to do with MirroredVariables.
|
||||
d = "/device:CPU:0"
|
||||
with ops.device(d):
|
||||
v = variable_scope.get_variable(
|
||||
name="v", initializer=1., use_resource=True)
|
||||
mirrored = values.MirroredVariable(None, (v,),
|
||||
variable_scope.VariableAggregation.SUM)
|
||||
result = values.regroup((v,))
|
||||
self.assertIs(mirrored, result)
|
||||
self.assertEqual(_nested_value("1"), values.select_replica(0, result))
|
||||
|
||||
def testNamedTuple(self):
|
||||
|
||||
|
@ -1274,7 +1274,7 @@ def _var_key(var):
|
||||
# pylint: disable=protected-access
|
||||
# Get the distributed variable if it exists.
|
||||
if hasattr(var, "_distributed_container"):
|
||||
var = var._distributed_container()
|
||||
var = var._distributed_container
|
||||
if var._in_graph_mode:
|
||||
return var._shared_name
|
||||
return var._unique_id
|
||||
|
@ -759,7 +759,7 @@ class Optimizer(
|
||||
if hasattr(var, "_distributed_container"):
|
||||
# NOTE: If this isn't patched, then there is no `handle` in
|
||||
# `_resource_apply_dense`.
|
||||
distributed_container = var._distributed_container()
|
||||
distributed_container = var._distributed_container
|
||||
assert distributed_container is not None
|
||||
if ops.executing_eagerly_outside_functions():
|
||||
key = distributed_container._unique_id
|
||||
|
Loading…
x
Reference in New Issue
Block a user