Fix GlobalStepTests to specify the collection
The name is meaningless in v2, even in V1 users should have been specifying the collections parameter for the variable. PiperOrigin-RevId: 321867276 Change-Id: I899ee8779c780be2bcc26d997ca5d3edc5eddbe6
This commit is contained in:
parent
a44a11793c
commit
a44821de91
@ -20,14 +20,12 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
class GlobalStepTest(test.TestCase):
|
||||
|
||||
def _assert_global_step(self, global_step, expected_dtype=dtypes.int64):
|
||||
@ -38,11 +36,12 @@ class GlobalStepTest(test.TestCase):
|
||||
def test_invalid_dtype(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
self.assertIsNone(training_util.get_global_step())
|
||||
variables.Variable(
|
||||
variables.VariableV1(
|
||||
0.0,
|
||||
trainable=False,
|
||||
dtype=dtypes.float32,
|
||||
name=ops.GraphKeys.GLOBAL_STEP)
|
||||
name=ops.GraphKeys.GLOBAL_STEP,
|
||||
collections=[ops.GraphKeys.GLOBAL_STEP])
|
||||
self.assertRaisesRegex(TypeError, 'does not have integer type',
|
||||
training_util.get_global_step)
|
||||
self.assertRaisesRegex(TypeError, 'does not have integer type',
|
||||
@ -55,7 +54,8 @@ class GlobalStepTest(test.TestCase):
|
||||
[0],
|
||||
trainable=False,
|
||||
dtype=dtypes.int32,
|
||||
name=ops.GraphKeys.GLOBAL_STEP)
|
||||
name=ops.GraphKeys.GLOBAL_STEP,
|
||||
collections=[ops.GraphKeys.GLOBAL_STEP])
|
||||
self.assertRaisesRegex(TypeError, 'not scalar',
|
||||
training_util.get_global_step)
|
||||
self.assertRaisesRegex(TypeError, 'not scalar',
|
||||
@ -79,7 +79,8 @@ class GlobalStepTest(test.TestCase):
|
||||
0,
|
||||
trainable=False,
|
||||
dtype=dtypes.int32,
|
||||
name=ops.GraphKeys.GLOBAL_STEP)
|
||||
name=ops.GraphKeys.GLOBAL_STEP,
|
||||
collections=[ops.GraphKeys.GLOBAL_STEP])
|
||||
self._assert_global_step(
|
||||
training_util.get_global_step(), expected_dtype=dtypes.int32)
|
||||
self._assert_global_step(
|
||||
@ -92,7 +93,6 @@ class GlobalStepTest(test.TestCase):
|
||||
self._assert_global_step(training_util.get_or_create_global_step(g))
|
||||
|
||||
|
||||
@test_util.run_v1_only('b/120545219')
|
||||
class GlobalStepReadTest(test.TestCase):
|
||||
|
||||
def test_global_step_read_is_none_if_there_is_no_global_step(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user