Copy ResourceVariable aggregation and synchronization when deepcopying a
variable PiperOrigin-RevId: 355411533 Change-Id: I65296ba0315cda40374778870b2611cc58fea76e
This commit is contained in:
parent
8130665856
commit
0d7fe54134
@ -134,13 +134,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
|
||||
def testEagerDeepCopy(self):
|
||||
with context.eager_mode():
|
||||
init_value = np.ones((4, 4, 4))
|
||||
variable = resource_variable_ops.ResourceVariable(init_value,
|
||||
name="init")
|
||||
variable = resource_variable_ops.ResourceVariable(
|
||||
init_value,
|
||||
name="init",
|
||||
synchronization=variables.VariableSynchronization.ON_READ,
|
||||
aggregation=variables.VariableAggregation.SUM)
|
||||
|
||||
copied_variable = copy.deepcopy(variable)
|
||||
self.assertEqual(variable.name, copied_variable.name)
|
||||
self.assertEqual(variable.shape, copied_variable.shape)
|
||||
self.assertEqual(variable.device, copied_variable.device)
|
||||
self.assertEqual(variable.synchronization,
|
||||
copied_variable.synchronization)
|
||||
self.assertEqual(variable.aggregation, copied_variable.aggregation)
|
||||
|
||||
# The copied variable should have the same value as the original.
|
||||
self.assertAllEqual(variable.numpy(), copied_variable.numpy())
|
||||
|
@ -502,7 +502,9 @@ class BaseResourceVariable(variables.VariableV1, core.Tensor):
|
||||
constraint=self._constraint,
|
||||
dtype=self._dtype,
|
||||
name=self._shared_name,
|
||||
distribute_strategy=self._distribute_strategy)
|
||||
distribute_strategy=self._distribute_strategy,
|
||||
synchronization=self.synchronization,
|
||||
aggregation=self.aggregation)
|
||||
memo[self._unique_id] = copied_variable
|
||||
return copied_variable
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user