Test both V1 and V2 variables.

PiperOrigin-RevId: 253063488
This commit is contained in:
Peter Buchlovsky 2019-06-13 10:55:31 -07:00 committed by TensorFlower Gardener
parent c22bc79896
commit 824fab41c0

View File

@ -427,15 +427,16 @@ class DistributionTestBase(test.TestCase):
run_and_concatenate(strategy, i) run_and_concatenate(strategy, i)
def _test_trainable_variable(self, strategy): def _test_trainable_variable(self, strategy):
for cls in [variables.VariableV1, variables.Variable]:
with strategy.scope(): with strategy.scope():
v1 = variables.Variable(1.0) v1 = cls(1.0)
self.assertEqual(True, v1.trainable) self.assertEqual(True, v1.trainable)
v2 = variables.Variable( v2 = cls(
1.0, synchronization=variables.VariableSynchronization.ON_READ) 1.0, synchronization=variables.VariableSynchronization.ON_READ)
self.assertEqual(False, v2.trainable) self.assertEqual(False, v2.trainable)
v3 = variables.Variable( v3 = cls(
1.0, synchronization=variables.VariableSynchronization.ON_READ, 1.0, synchronization=variables.VariableSynchronization.ON_READ,
trainable=True) trainable=True)
self.assertEqual(True, v3.trainable) self.assertEqual(True, v3.trainable)