[TF DistStrat] Add support for deepcopy on AggregatingVariable (PS)
Tests passing on a multi-GPU system: [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignReturnValueIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testAssignSignature_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testCheckpointing_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testDeepCopy_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testExtendsVariable_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testIsTensorLike_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_eager_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testSelectReplica_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationMEAN_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationONLYFIRSTREPLICA_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE [ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONREAD [ OK ] DistributedVariableTest.testTraceback_test_aggregation_VariableAggregationSUM_distribution_CentralStorage2GPUs_mode_graph_synchronization_VariableSynchronizationONWRITE PiperOrigin-RevId: 327440841 Change-Id: I86b33681b5ad187f5d3f5e8a0d6d374edfafc8a6
This commit is contained in:
parent
e5f12a0ff5
commit
654b45cd56
@ -1136,6 +1136,7 @@ distribute_py_test(
|
||||
":distribute_utils",
|
||||
":packed_distributed_variable",
|
||||
":parameter_server_strategy",
|
||||
":ps_values",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
":tpu_strategy",
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import weakref
|
||||
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
@ -43,6 +44,36 @@ class AggregatingVariable(variables_lib.Variable, core.Tensor):
|
||||
v._aggregating_container = weakref.ref(self) # pylint: disable=protected-access
|
||||
self._aggregation = aggregation
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
"""Perform a deepcopy of the `AggregatingVariable`.
|
||||
|
||||
Unlike the deepcopy of a regular tf.Variable, this keeps the original
|
||||
strategy and devices of the `AggregatingVariable`. To avoid confusion
|
||||
with the behavior of deepcopy on a regular `Variable` (which does
|
||||
copy into new devices), we only allow a deepcopy of a `AggregatingVariable`
|
||||
within its originating strategy scope.
|
||||
|
||||
Args:
|
||||
memo: The memoization object for `deepcopy`.
|
||||
|
||||
Returns:
|
||||
A deep copy of the current `AggregatingVariable`.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If trying to deepcopy into a different strategy.
|
||||
"""
|
||||
with ds_context.enter_or_assert_strategy(self._distribute_strategy):
|
||||
v = copy.deepcopy(self._v, memo)
|
||||
|
||||
copied_variable = type(self)(
|
||||
strategy=self._distribute_strategy,
|
||||
v=v,
|
||||
aggregation=self._aggregation)
|
||||
|
||||
memo[id(self)] = copied_variable
|
||||
|
||||
return copied_variable
|
||||
|
||||
def get(self):
|
||||
return self._v
|
||||
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribute_utils
|
||||
from tensorflow.python.distribute import packed_distributed_variable as packed
|
||||
from tensorflow.python.distribute import parameter_server_strategy
|
||||
from tensorflow.python.distribute import ps_values
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util as ds_test_util
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
@ -549,12 +550,17 @@ class DistributedVariableTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertIsInstance(v2, type(v1))
|
||||
self.assertEqual(v1.aggregation, v2.aggregation)
|
||||
self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
|
||||
self.assertEqual(v1._policy, v2._policy) # pylint: disable=protected-access
|
||||
self.assertEqual(len(v1.values), len(v2.values))
|
||||
for (v1v, v2v) in zip(v1.values, v2.values):
|
||||
self.assertEqual(v1v.device, v2v.device)
|
||||
self.assertNotEqual(id(v1v), id(v2v))
|
||||
self.assertAllEqual(self.evaluate(v1.values), self.evaluate(v2.values))
|
||||
if isinstance(v1, ps_values.AggregatingVariable):
|
||||
self.assertIsInstance(v2.get(), type(v1.get()))
|
||||
self.assertNotEqual(id(v1.get()), id(v2.get()))
|
||||
else:
|
||||
self.assertEqual(v1._policy, v2._policy) # pylint: disable=protected-access
|
||||
self.assertEqual(len(v1.values), len(v2.values))
|
||||
for (v1v, v2v) in zip(v1.values, v2.values):
|
||||
self.assertEqual(v1v.device, v2v.device)
|
||||
self.assertNotEqual(id(v1v), id(v2v))
|
||||
self.assertAllEqual(self.evaluate(v1.values),
|
||||
self.evaluate(v2.values))
|
||||
|
||||
self.evaluate(variables_lib.global_variables_initializer())
|
||||
if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
|
||||
|
Loading…
Reference in New Issue
Block a user