diff --git a/tensorflow/python/training/training_util_test.py b/tensorflow/python/training/training_util_test.py index 5049d6e00a0..cf5942287a1 100644 --- a/tensorflow/python/training/training_util_test.py +++ b/tensorflow/python/training/training_util_test.py @@ -20,14 +20,12 @@ from __future__ import print_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import monitored_session from tensorflow.python.training import training_util -@test_util.run_v1_only('b/120545219') class GlobalStepTest(test.TestCase): def _assert_global_step(self, global_step, expected_dtype=dtypes.int64): @@ -38,11 +36,12 @@ class GlobalStepTest(test.TestCase): def test_invalid_dtype(self): with ops.Graph().as_default() as g: self.assertIsNone(training_util.get_global_step()) - variables.Variable( + variables.VariableV1( 0.0, trainable=False, dtype=dtypes.float32, - name=ops.GraphKeys.GLOBAL_STEP) + name=ops.GraphKeys.GLOBAL_STEP, + collections=[ops.GraphKeys.GLOBAL_STEP]) self.assertRaisesRegex(TypeError, 'does not have integer type', training_util.get_global_step) self.assertRaisesRegex(TypeError, 'does not have integer type', @@ -55,7 +54,8 @@ class GlobalStepTest(test.TestCase): [0], trainable=False, dtype=dtypes.int32, - name=ops.GraphKeys.GLOBAL_STEP) + name=ops.GraphKeys.GLOBAL_STEP, + collections=[ops.GraphKeys.GLOBAL_STEP]) self.assertRaisesRegex(TypeError, 'not scalar', training_util.get_global_step) self.assertRaisesRegex(TypeError, 'not scalar', @@ -79,7 +79,8 @@ class GlobalStepTest(test.TestCase): 0, trainable=False, dtype=dtypes.int32, - name=ops.GraphKeys.GLOBAL_STEP) + name=ops.GraphKeys.GLOBAL_STEP, + collections=[ops.GraphKeys.GLOBAL_STEP]) self._assert_global_step( training_util.get_global_step(), expected_dtype=dtypes.int32) self._assert_global_step( @@ -92,7 +93,6 @@ class GlobalStepTest(test.TestCase): self._assert_global_step(training_util.get_or_create_global_step(g)) -@test_util.run_v1_only('b/120545219') class GlobalStepReadTest(test.TestCase): def test_global_step_read_is_none_if_there_is_no_global_step(self):