Copy ResourceVariable aggregation and synchronization when deepcopying a

variable

PiperOrigin-RevId: 355411533
Change-Id: I65296ba0315cda40374778870b2611cc58fea76e
This commit is contained in:
Ran Chen 2021-02-03 09:19:24 -08:00 committed by TensorFlower Gardener
parent 8130665856
commit 0d7fe54134
2 changed files with 11 additions and 3 deletions

View File

@ -134,13 +134,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
def testEagerDeepCopy(self):
with context.eager_mode():
init_value = np.ones((4, 4, 4))
variable = resource_variable_ops.ResourceVariable(init_value,
name="init")
variable = resource_variable_ops.ResourceVariable(
init_value,
name="init",
synchronization=variables.VariableSynchronization.ON_READ,
aggregation=variables.VariableAggregation.SUM)
copied_variable = copy.deepcopy(variable)
self.assertEqual(variable.name, copied_variable.name)
self.assertEqual(variable.shape, copied_variable.shape)
self.assertEqual(variable.device, copied_variable.device)
self.assertEqual(variable.synchronization,
copied_variable.synchronization)
self.assertEqual(variable.aggregation, copied_variable.aggregation)
# The copied variable should have the same value as the original.
self.assertAllEqual(variable.numpy(), copied_variable.numpy())

View File

@ -502,7 +502,9 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
constraint=self._constraint,
dtype=self._dtype,
name=self._shared_name,
distribute_strategy=self._distribute_strategy)
distribute_strategy=self._distribute_strategy,
synchronization=self.synchronization,
aggregation=self.aggregation)
memo[self._unique_id] = copied_variable
return copied_variable