diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index a9442f99bd6..fda258578aa 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -433,10 +433,6 @@ class DistributedVariable(DistributedDelegate, variables_lib.Variable): self._aggregation = aggregation super(DistributedVariable, self).__init__(values) self._common_name = self._primary.name.split(":")[0] - # Use a weakref to make it easy to map from the contained values - # to the container without introducing a reference cycle. - for v in values: - v._distributed_container = weakref.ref(self) # pylint: disable=protected-access # tf.keras keeps track of variables initialized using this attribute. When # tf.keras gets the default session, it initializes all uninitialized vars. # We need to make _keras_initialized a member of DistributedVariable because @@ -774,6 +770,13 @@ def create_mirrored_variable( # pylint: disable=missing-docstring value_list = real_mirrored_creator(**kwargs) var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls result = var_cls(strategy, value_list, aggregation) + # Install the created DistributedVariable as _distributed_container property + # of the underlying variables, to make it easy to map back to the container. + for v in result.values: + # Hold a strong reference to prevent the container from being GC-ed. After + # v = v.assign(), the user code may no longer hold references to the + # original container, since v.assign() returns a new DistributedVariable. + v._distributed_container = result # pylint: disable=protected-access # Add the wrapped variable to the requested collections. 
# The handling of eager mode and the global step matches @@ -1240,10 +1243,10 @@ def regroup(values, wrap_class=PerReplica, always_wrap=False): # pylint: disable=protected-access assert not isinstance(v0, MirroredVariable), ( "ids = %s, values = %s" % ([id(v) for v in values], values)) - distributed_container = v0._distributed_container() + distributed_container = v0._distributed_container assert distributed_container is not None for v in values[1:]: - assert distributed_container is v._distributed_container() + assert distributed_container is v._distributed_container return distributed_container # pylint: enable=protected-access @@ -1331,7 +1334,7 @@ def value_container(val): # DistributedVariable has _distributed_container defined # but we don't want to return it. not isinstance(val, DistributedVariable)): - container = val._distributed_container() # pylint: disable=protected-access + container = val._distributed_container # pylint: disable=protected-access if container is not None: return container return val diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py index f0901296d58..36fc4654a1a 100644 --- a/tensorflow/python/distribute/values_test.py +++ b/tensorflow/python/distribute/values_test.py @@ -365,7 +365,7 @@ def _make_mirrored(): return mirrored -class RegroupAndSelectDeviceTest(test.TestCase): +class RegroupAndSelectDeviceTest(test.TestCase, parameterized.TestCase): def _is_per_replica(self, result, expected, klass=values.PerReplica): self.assertIsInstance(result, klass) @@ -448,12 +448,20 @@ class RegroupAndSelectDeviceTest(test.TestCase): self._is_per_replica(result[0], ("1", "3"), values.PerReplica) self._is_per_replica(result[1], ("2", "4"), values.PerReplica) - def testMirroredContainer(self): - if context.num_gpus() < 1 and context.executing_eagerly(): - self.skipTest("A GPU is not available for this test in eager mode.") - mirrored = _make_mirrored() - result = values.regroup(mirrored.values) - 
self.assertIs(mirrored, result) + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.mirrored_strategy_with_one_cpu, + ], + mode=["graph", "eager"], + )) + def testMirroredContainer(self, distribution): + with distribution.scope(): + v = variable_scope.variable( + 1., aggregation=variable_scope.VariableAggregation.SUM) + self.assertTrue(values.is_distributed_variable(v)) + self.assertTrue(values.is_distributed_variable(values.regroup(v.values))) def testSameId(self): foo = object() @@ -479,18 +487,7 @@ class RegroupAndSelectDeviceTest(test.TestCase): result = values.regroup((_nested_value("1"),)) # On one device regroup() and select_replica() are basically identity. self.assertEqual(_nested_value("1"), result) - self.assertEqual(_nested_value("1"), - values.select_replica(0, result)) - - # The one exception has to do with MirroredVariables. - d = "/device:CPU:0" - with ops.device(d): - v = variable_scope.get_variable( - name="v", initializer=1., use_resource=True) - mirrored = values.MirroredVariable(None, (v,), - variable_scope.VariableAggregation.SUM) - result = values.regroup((v,)) - self.assertIs(mirrored, result) + self.assertEqual(_nested_value("1"), values.select_replica(0, result)) def testNamedTuple(self): diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py index 1c1db4c6bb2..4cf07033d92 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py @@ -1274,7 +1274,7 @@ def _var_key(var): # pylint: disable=protected-access # Get the distributed variable if it exists. 
if hasattr(var, "_distributed_container"): - var = var._distributed_container() + var = var._distributed_container if var._in_graph_mode: return var._shared_name return var._unique_id diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index f1a31d01dd4..9732ea04f26 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -759,7 +759,7 @@ class Optimizer( if hasattr(var, "_distributed_container"): # NOTE: If this isn't patched, then there is no `handle` in # `_resource_apply_dense`. - distributed_container = var._distributed_container() + distributed_container = var._distributed_container assert distributed_container is not None if ops.executing_eagerly_outside_functions(): key = distributed_container._unique_id