From 824fab41c003276a15c30b73d6deb71f65a6d842 Mon Sep 17 00:00:00 2001 From: Peter Buchlovsky <petebu@google.com> Date: Thu, 13 Jun 2019 10:55:31 -0700 Subject: [PATCH] Test both V1 and V2 variables. PiperOrigin-RevId: 253063488 --- .../python/distribute/strategy_test_lib.py | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index b8b6bd8dd3d..78e80704e15 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -427,18 +427,19 @@ class DistributionTestBase(test.TestCase): run_and_concatenate(strategy, i) def _test_trainable_variable(self, strategy): - with strategy.scope(): - v1 = variables.Variable(1.0) - self.assertEqual(True, v1.trainable) + for cls in [variables.VariableV1, variables.Variable]: + with strategy.scope(): + v1 = cls(1.0) + self.assertEqual(True, v1.trainable) - v2 = variables.Variable( - 1.0, synchronization=variables.VariableSynchronization.ON_READ) - self.assertEqual(False, v2.trainable) + v2 = cls( + 1.0, synchronization=variables.VariableSynchronization.ON_READ) + self.assertEqual(False, v2.trainable) - v3 = variables.Variable( - 1.0, synchronization=variables.VariableSynchronization.ON_READ, - trainable=True) - self.assertEqual(True, v3.trainable) + v3 = cls( + 1.0, synchronization=variables.VariableSynchronization.ON_READ, + trainable=True) + self.assertEqual(True, v3.trainable) class OneDeviceDistributionTestBase(test.TestCase):