Test both V1 and V2 variables.
PiperOrigin-RevId: 253063488
This commit is contained in:
parent
c22bc79896
commit
824fab41c0
@ -427,18 +427,19 @@ class DistributionTestBase(test.TestCase):
|
||||
run_and_concatenate(strategy, i)
|
||||
|
||||
def _test_trainable_variable(self, strategy):
|
||||
with strategy.scope():
|
||||
v1 = variables.Variable(1.0)
|
||||
self.assertEqual(True, v1.trainable)
|
||||
for cls in [variables.VariableV1, variables.Variable]:
|
||||
with strategy.scope():
|
||||
v1 = cls(1.0)
|
||||
self.assertEqual(True, v1.trainable)
|
||||
|
||||
v2 = variables.Variable(
|
||||
1.0, synchronization=variables.VariableSynchronization.ON_READ)
|
||||
self.assertEqual(False, v2.trainable)
|
||||
v2 = cls(
|
||||
1.0, synchronization=variables.VariableSynchronization.ON_READ)
|
||||
self.assertEqual(False, v2.trainable)
|
||||
|
||||
v3 = variables.Variable(
|
||||
1.0, synchronization=variables.VariableSynchronization.ON_READ,
|
||||
trainable=True)
|
||||
self.assertEqual(True, v3.trainable)
|
||||
v3 = cls(
|
||||
1.0, synchronization=variables.VariableSynchronization.ON_READ,
|
||||
trainable=True)
|
||||
self.assertEqual(True, v3.trainable)
|
||||
|
||||
|
||||
class OneDeviceDistributionTestBase(test.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user