Assert global_step is scalar.
Change: 120276713
This commit is contained in:
parent
97880edb4e
commit
3198d7ef30
tensorflow/contrib/framework/python/ops
@ -34,6 +34,11 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def assert_global_step(global_step_tensor):
|
def assert_global_step(global_step_tensor):
|
||||||
|
"""Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
global_step_tensor: `Tensor` to test.
|
||||||
|
"""
|
||||||
if not (isinstance(global_step_tensor, variables.Variable) or
|
if not (isinstance(global_step_tensor, variables.Variable) or
|
||||||
isinstance(global_step_tensor, ops.Tensor)):
|
isinstance(global_step_tensor, ops.Tensor)):
|
||||||
raise TypeError('Existing "global_step" must be a Variable or Tensor.')
|
raise TypeError('Existing "global_step" must be a Variable or Tensor.')
|
||||||
@ -43,6 +48,11 @@ def assert_global_step(global_step_tensor):
|
|||||||
'Existing "global_step" does not have integer type: %s' %
|
'Existing "global_step" does not have integer type: %s' %
|
||||||
global_step_tensor.dtype)
|
global_step_tensor.dtype)
|
||||||
|
|
||||||
|
if global_step_tensor.get_shape().ndims != 0:
|
||||||
|
raise TypeError(
|
||||||
|
'Existing "global_step" is not scalar: %s' %
|
||||||
|
global_step_tensor.get_shape())
|
||||||
|
|
||||||
|
|
||||||
# TODO(ptucker): Change supervisor to use this when it's migrated to core.
|
# TODO(ptucker): Change supervisor to use this when it's migrated to core.
|
||||||
def get_global_step(graph=None):
|
def get_global_step(graph=None):
|
||||||
|
@ -57,6 +57,18 @@ class GlobalStepTest(tf.test.TestCase):
|
|||||||
TypeError, "does not have integer type",
|
TypeError, "does not have integer type",
|
||||||
tf.contrib.framework.get_global_step, g)
|
tf.contrib.framework.get_global_step, g)
|
||||||
|
|
||||||
|
def test_invalid_shape(self):
|
||||||
|
with tf.Graph().as_default() as g:
|
||||||
|
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
||||||
|
tf.Variable(
|
||||||
|
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
|
||||||
|
self.assertRaisesRegexp(
|
||||||
|
TypeError, "not scalar",
|
||||||
|
tf.contrib.framework.get_global_step)
|
||||||
|
self.assertRaisesRegexp(
|
||||||
|
TypeError, "not scalar",
|
||||||
|
tf.contrib.framework.get_global_step, g)
|
||||||
|
|
||||||
def test_create_global_step(self):
|
def test_create_global_step(self):
|
||||||
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
self.assertEquals(None, tf.contrib.framework.get_global_step())
|
||||||
with tf.Graph().as_default() as g:
|
with tf.Graph().as_default() as g:
|
||||||
|
Loading…
Reference in New Issue
Block a user