Test both V1 and V2 variables.

PiperOrigin-RevId: 253063488
This commit is contained in:
Peter Buchlovsky 2019-06-13 10:55:31 -07:00 committed by TensorFlower Gardener
parent c22bc79896
commit 824fab41c0

View File

@ -427,18 +427,19 @@ class DistributionTestBase(test.TestCase):
run_and_concatenate(strategy, i) run_and_concatenate(strategy, i)
def _test_trainable_variable(self, strategy): def _test_trainable_variable(self, strategy):
with strategy.scope(): for cls in [variables.VariableV1, variables.Variable]:
v1 = variables.Variable(1.0) with strategy.scope():
self.assertEqual(True, v1.trainable) v1 = cls(1.0)
self.assertEqual(True, v1.trainable)
v2 = variables.Variable( v2 = cls(
1.0, synchronization=variables.VariableSynchronization.ON_READ) 1.0, synchronization=variables.VariableSynchronization.ON_READ)
self.assertEqual(False, v2.trainable) self.assertEqual(False, v2.trainable)
v3 = variables.Variable( v3 = cls(
1.0, synchronization=variables.VariableSynchronization.ON_READ, 1.0, synchronization=variables.VariableSynchronization.ON_READ,
trainable=True) trainable=True)
self.assertEqual(True, v3.trainable) self.assertEqual(True, v3.trainable)
class OneDeviceDistributionTestBase(test.TestCase): class OneDeviceDistributionTestBase(test.TestCase):