Assert global_step is scalar.
Change: 120276713
This commit is contained in:
parent
97880edb4e
commit
3198d7ef30
tensorflow/contrib/framework/python/ops
@ -34,6 +34,11 @@ __all__ = [
|
||||
|
||||
|
||||
def assert_global_step(global_step_tensor):
|
||||
"""Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.
|
||||
|
||||
Args:
|
||||
global_step_tensor: `Tensor` to test.
|
||||
"""
|
||||
if not (isinstance(global_step_tensor, variables.Variable) or
|
||||
isinstance(global_step_tensor, ops.Tensor)):
|
||||
raise TypeError('Existing "global_step" must be a Variable or Tensor.')
|
||||
@ -43,6 +48,11 @@ def assert_global_step(global_step_tensor):
|
||||
'Existing "global_step" does not have integer type: %s' %
|
||||
global_step_tensor.dtype)
|
||||
|
||||
if global_step_tensor.get_shape().ndims != 0:
|
||||
raise TypeError(
|
||||
'Existing "global_step" is not scalar: %s' %
|
||||
global_step_tensor.get_shape())
|
||||
|
||||
|
||||
# TODO(ptucker): Change supervisor to use this when it's migrated to core.
|
||||
def get_global_step(graph=None):
|
||||
|
@ -57,6 +57,18 @@ class GlobalStepTest(tf.test.TestCase):
|
||||
TypeError, "does not have integer type",
|
||||
tf.contrib.framework.get_global_step, g)
|
||||
|
||||
def test_invalid_shape(self):
|
||||
with tf.Graph().as_default() as g:
|
||||
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
||||
tf.Variable(
|
||||
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, "not scalar",
|
||||
tf.contrib.framework.get_global_step)
|
||||
self.assertRaisesRegexp(
|
||||
TypeError, "not scalar",
|
||||
tf.contrib.framework.get_global_step, g)
|
||||
|
||||
def test_create_global_step(self):
|
||||
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
||||
with tf.Graph().as_default() as g:
|
||||
|
Loading…
Reference in New Issue
Block a user