From 824fab41c003276a15c30b73d6deb71f65a6d842 Mon Sep 17 00:00:00 2001
From: Peter Buchlovsky <petebu@google.com>
Date: Thu, 13 Jun 2019 10:55:31 -0700
Subject: [PATCH] Test both V1 and V2 variables.

PiperOrigin-RevId: 253063488
---
 .../python/distribute/strategy_test_lib.py    | 21 ++++++++++---------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py
index b8b6bd8dd3d..78e80704e15 100644
--- a/tensorflow/python/distribute/strategy_test_lib.py
+++ b/tensorflow/python/distribute/strategy_test_lib.py
@@ -427,18 +427,19 @@ class DistributionTestBase(test.TestCase):
         run_and_concatenate(strategy, i)
 
   def _test_trainable_variable(self, strategy):
-    with strategy.scope():
-      v1 = variables.Variable(1.0)
-      self.assertEqual(True, v1.trainable)
+    for cls in [variables.VariableV1, variables.Variable]:
+      with strategy.scope():
+        v1 = cls(1.0)
+        self.assertEqual(True, v1.trainable)
 
-      v2 = variables.Variable(
-          1.0, synchronization=variables.VariableSynchronization.ON_READ)
-      self.assertEqual(False, v2.trainable)
+        v2 = cls(
+            1.0, synchronization=variables.VariableSynchronization.ON_READ)
+        self.assertEqual(False, v2.trainable)
 
-      v3 = variables.Variable(
-          1.0, synchronization=variables.VariableSynchronization.ON_READ,
-          trainable=True)
-      self.assertEqual(True, v3.trainable)
+        v3 = cls(
+            1.0, synchronization=variables.VariableSynchronization.ON_READ,
+            trainable=True)
+        self.assertEqual(True, v3.trainable)
 
 
 class OneDeviceDistributionTestBase(test.TestCase):