Fix GlobalStepTests to specify the collection

The name is meaningless in v2, even in V1 users should have been
specifying the collections parameter for the variable.

PiperOrigin-RevId: 321867276
Change-Id: I899ee8779c780be2bcc26d997ca5d3edc5eddbe6
This commit is contained in:
Gaurav Jain 2020-07-17 15:27:48 -07:00 committed by TensorFlower Gardener
parent a44a11793c
commit a44821de91

View File

@ -20,14 +20,12 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util from tensorflow.python.training import training_util
@test_util.run_v1_only('b/120545219')
class GlobalStepTest(test.TestCase): class GlobalStepTest(test.TestCase):
def _assert_global_step(self, global_step, expected_dtype=dtypes.int64): def _assert_global_step(self, global_step, expected_dtype=dtypes.int64):
@ -38,11 +36,12 @@ class GlobalStepTest(test.TestCase):
def test_invalid_dtype(self): def test_invalid_dtype(self):
with ops.Graph().as_default() as g: with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step()) self.assertIsNone(training_util.get_global_step())
variables.Variable( variables.VariableV1(
0.0, 0.0,
trainable=False, trainable=False,
dtype=dtypes.float32, dtype=dtypes.float32,
name=ops.GraphKeys.GLOBAL_STEP) name=ops.GraphKeys.GLOBAL_STEP,
collections=[ops.GraphKeys.GLOBAL_STEP])
self.assertRaisesRegex(TypeError, 'does not have integer type', self.assertRaisesRegex(TypeError, 'does not have integer type',
training_util.get_global_step) training_util.get_global_step)
self.assertRaisesRegex(TypeError, 'does not have integer type', self.assertRaisesRegex(TypeError, 'does not have integer type',
@ -55,7 +54,8 @@ class GlobalStepTest(test.TestCase):
[0], [0],
trainable=False, trainable=False,
dtype=dtypes.int32, dtype=dtypes.int32,
name=ops.GraphKeys.GLOBAL_STEP) name=ops.GraphKeys.GLOBAL_STEP,
collections=[ops.GraphKeys.GLOBAL_STEP])
self.assertRaisesRegex(TypeError, 'not scalar', self.assertRaisesRegex(TypeError, 'not scalar',
training_util.get_global_step) training_util.get_global_step)
self.assertRaisesRegex(TypeError, 'not scalar', self.assertRaisesRegex(TypeError, 'not scalar',
@ -79,7 +79,8 @@ class GlobalStepTest(test.TestCase):
0, 0,
trainable=False, trainable=False,
dtype=dtypes.int32, dtype=dtypes.int32,
name=ops.GraphKeys.GLOBAL_STEP) name=ops.GraphKeys.GLOBAL_STEP,
collections=[ops.GraphKeys.GLOBAL_STEP])
self._assert_global_step( self._assert_global_step(
training_util.get_global_step(), expected_dtype=dtypes.int32) training_util.get_global_step(), expected_dtype=dtypes.int32)
self._assert_global_step( self._assert_global_step(
@ -92,7 +93,6 @@ class GlobalStepTest(test.TestCase):
self._assert_global_step(training_util.get_or_create_global_step(g)) self._assert_global_step(training_util.get_or_create_global_step(g))
@test_util.run_v1_only('b/120545219')
class GlobalStepReadTest(test.TestCase): class GlobalStepReadTest(test.TestCase):
def test_global_step_read_is_none_if_there_is_no_global_step(self): def test_global_step_read_is_none_if_there_is_no_global_step(self):