From cff5492c76c2548339096eab1c36aee9fddc8d28 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Thu, 9 Jul 2020 16:10:48 -0700 Subject: [PATCH] Update v1 only slot_creator_test with graph scope. PiperOrigin-RevId: 320494529 Change-Id: If3b188c80a4e9f874d14b8e20eaf793955ef0fa0 --- .../python/training/slot_creator_test.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py index 88192811c8b..a35d1af3084 100644 --- a/tensorflow/python/training/slot_creator_test.py +++ b/tensorflow/python/training/slot_creator_test.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops @@ -33,9 +32,9 @@ from tensorflow.python.training import slot_creator class SlotCreatorTest(test.TestCase): - @test_util.run_v1_only("b/120545219") def testCreateSlotFromVariable(self): - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): v = variables.Variable([1.0, 2.5], name="var") slot = slot_creator.create_slot(v, v.initialized_value(), name="slot") @@ -46,9 +45,9 @@ class SlotCreatorTest(test.TestCase): self.assertEqual(dtypes.float32, slot.dtype.base_dtype) self.assertAllEqual([1.0, 2.5], self.evaluate(slot)) - @test_util.run_deprecated_v1 def testCreateSlotFromTensor(self): - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): v = constant_op.constant([1.0, 2.5], name="const") slot = slot_creator.create_slot(v, v * 2, name="slot") @@ -59,9 +58,9 @@ class SlotCreatorTest(test.TestCase): self.assertEqual(dtypes.float32, slot.dtype.base_dtype) self.assertAllEqual([2.0, 5.0], self.evaluate(slot)) - @test_util.run_deprecated_v1 def testCreateZerosSlotFromVariable(self): - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): v = variables.Variable([1.0, 2.5], name="var") with ops.control_dependencies(None): slot = slot_creator.create_zeros_slot( @@ -74,9 +73,9 @@ class SlotCreatorTest(test.TestCase): self.assertEqual(dtypes.float64, slot.dtype.base_dtype) self.assertAllEqual([0.0, 0.0], self.evaluate(slot)) - @test_util.run_v1_only("b/120545219") def testCreateZerosSlotFromDynamicShapedVariable(self): - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): dyn_shape = constant_op.constant([2], dtype=dtypes.int32) dyn_shape = array_ops.placeholder_with_default(dyn_shape, shape=[None]) @@ -96,9 +95,9 @@ class SlotCreatorTest(test.TestCase): self.assertEqual(dtypes.float64, slot.dtype.base_dtype) self.assertAllEqual([0.0, 0.0], self.evaluate(slot)) - @test_util.run_deprecated_v1 def testCreateZerosSlotFromTensor(self): - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): v = constant_op.constant([1.0, 2.5], name="const") with ops.control_dependencies(None): slot = slot_creator.create_zeros_slot(v, name="slot") @@ -110,9 +109,9 @@ class SlotCreatorTest(test.TestCase): self.assertEqual(dtypes.float32, slot.dtype.base_dtype) self.assertAllEqual([0.0, 0.0], self.evaluate(slot)) - @test_util.run_deprecated_v1 def testCreateZerosSlotFromDynamicShapedTensor(self): - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): v = random_ops.random_uniform([2], dtype=dtypes.float64) v = array_ops.placeholder_with_default(v, shape=[None], name="const") with ops.control_dependencies(None): @@ -126,18 +125,18 @@ class SlotCreatorTest(test.TestCase): self.assertEqual(dtypes.float64, slot.dtype.base_dtype) self.assertAllEqual([0.0, 0.0], self.evaluate(slot)) - @test_util.run_v1_only("b/120545219") def testCreateSlotFromVariableRespectsScope(self): # See discussion on #2740. - with self.cached_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.cached_session(): with variable_scope.variable_scope("scope"): v = variables.Variable([1.0, 2.5], name="var") slot = slot_creator.create_slot(v, v.initialized_value(), name="slot") self.assertEqual("scope/scope/var/slot", slot.op.name) - @test_util.run_deprecated_v1 def testCreateSlotFromFirstMDimensionVariable(self): - with self.test_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.test_session(): s = variables.Variable([1.0, 2.5], name="var") p_v = variable_scope.get_variable( "var", @@ -157,9 +156,9 @@ class SlotCreatorTest(test.TestCase): self.assertAllEqual([i], si.var_offset) self.assertAllEqual([1], si.var_shape) - @test_util.run_deprecated_v1 def testCreateSlotFromScalarVariable(self): - with self.test_session(): + # slot_creator is used only in optimizer V1. + with ops.Graph().as_default(), self.test_session(): s = variables.Variable(1.0, name="var") p_v = variable_scope.get_variable( "var",