Fix GlobalStepTests to specify the collection

The name is meaningless in v2, even in V1 users should have been
specifying the collections parameter for the variable.

PiperOrigin-RevId: 321867276
Change-Id: I899ee8779c780be2bcc26d997ca5d3edc5eddbe6
This commit is contained in:
Gaurav Jain 2020-07-17 15:27:48 -07:00 committed by TensorFlower Gardener
parent a44a11793c
commit a44821de91

View File

@ -20,14 +20,12 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util
@test_util.run_v1_only('b/120545219')
class GlobalStepTest(test.TestCase):
def _assert_global_step(self, global_step, expected_dtype=dtypes.int64):
@ -38,11 +36,12 @@ class GlobalStepTest(test.TestCase):
def test_invalid_dtype(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
variables.Variable(
variables.VariableV1(
0.0,
trainable=False,
dtype=dtypes.float32,
name=ops.GraphKeys.GLOBAL_STEP)
name=ops.GraphKeys.GLOBAL_STEP,
collections=[ops.GraphKeys.GLOBAL_STEP])
self.assertRaisesRegex(TypeError, 'does not have integer type',
training_util.get_global_step)
self.assertRaisesRegex(TypeError, 'does not have integer type',
@ -55,7 +54,8 @@ class GlobalStepTest(test.TestCase):
[0],
trainable=False,
dtype=dtypes.int32,
name=ops.GraphKeys.GLOBAL_STEP)
name=ops.GraphKeys.GLOBAL_STEP,
collections=[ops.GraphKeys.GLOBAL_STEP])
self.assertRaisesRegex(TypeError, 'not scalar',
training_util.get_global_step)
self.assertRaisesRegex(TypeError, 'not scalar',
@ -79,7 +79,8 @@ class GlobalStepTest(test.TestCase):
0,
trainable=False,
dtype=dtypes.int32,
name=ops.GraphKeys.GLOBAL_STEP)
name=ops.GraphKeys.GLOBAL_STEP,
collections=[ops.GraphKeys.GLOBAL_STEP])
self._assert_global_step(
training_util.get_global_step(), expected_dtype=dtypes.int32)
self._assert_global_step(
@ -92,7 +93,6 @@ class GlobalStepTest(test.TestCase):
self._assert_global_step(training_util.get_or_create_global_step(g))
@test_util.run_v1_only('b/120545219')
class GlobalStepReadTest(test.TestCase):
def test_global_step_read_is_none_if_there_is_no_global_step(self):