Test both V1 and V2 variables.
PiperOrigin-RevId: 253063488
This commit is contained in:
parent
c22bc79896
commit
824fab41c0
@ -427,18 +427,19 @@ class DistributionTestBase(test.TestCase):
|
|||||||
run_and_concatenate(strategy, i)
|
run_and_concatenate(strategy, i)
|
||||||
|
|
||||||
def _test_trainable_variable(self, strategy):
|
def _test_trainable_variable(self, strategy):
|
||||||
with strategy.scope():
|
for cls in [variables.VariableV1, variables.Variable]:
|
||||||
v1 = variables.Variable(1.0)
|
with strategy.scope():
|
||||||
self.assertEqual(True, v1.trainable)
|
v1 = cls(1.0)
|
||||||
|
self.assertEqual(True, v1.trainable)
|
||||||
|
|
||||||
v2 = variables.Variable(
|
v2 = cls(
|
||||||
1.0, synchronization=variables.VariableSynchronization.ON_READ)
|
1.0, synchronization=variables.VariableSynchronization.ON_READ)
|
||||||
self.assertEqual(False, v2.trainable)
|
self.assertEqual(False, v2.trainable)
|
||||||
|
|
||||||
v3 = variables.Variable(
|
v3 = cls(
|
||||||
1.0, synchronization=variables.VariableSynchronization.ON_READ,
|
1.0, synchronization=variables.VariableSynchronization.ON_READ,
|
||||||
trainable=True)
|
trainable=True)
|
||||||
self.assertEqual(True, v3.trainable)
|
self.assertEqual(True, v3.trainable)
|
||||||
|
|
||||||
|
|
||||||
class OneDeviceDistributionTestBase(test.TestCase):
|
class OneDeviceDistributionTestBase(test.TestCase):
|
||||||
|
Loading…
Reference in New Issue
Block a user