Assert global_step is scalar.

Change: 120276713
This commit is contained in:
A. Unique TensorFlower 2016-04-19 14:24:16 -08:00 committed by TensorFlower Gardener
parent 97880edb4e
commit 3198d7ef30
2 changed files with 22 additions and 0 deletions
tensorflow/contrib/framework/python/ops

View File

@ -34,6 +34,11 @@ __all__ = [
def assert_global_step(global_step_tensor):
"""Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.
Args:
global_step_tensor: `Tensor` to test.
"""
if not (isinstance(global_step_tensor, variables.Variable) or
isinstance(global_step_tensor, ops.Tensor)):
raise TypeError('Existing "global_step" must be a Variable or Tensor.')
@ -43,6 +48,11 @@ def assert_global_step(global_step_tensor):
'Existing "global_step" does not have integer type: %s' %
global_step_tensor.dtype)
if global_step_tensor.get_shape().ndims != 0:
raise TypeError(
'Existing "global_step" is not scalar: %s' %
global_step_tensor.get_shape())
# TODO(ptucker): Change supervisor to use this when it's migrated to core.
def get_global_step(graph=None):

View File

@ -57,6 +57,18 @@ class GlobalStepTest(tf.test.TestCase):
TypeError, "does not have integer type",
tf.contrib.framework.get_global_step, g)
def test_invalid_shape(self):
with tf.Graph().as_default() as g:
self.assertEquals(None, tf.contrib.framework.get_global_step())
tf.Variable(
[0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
self.assertRaisesRegexp(
TypeError, "not scalar",
tf.contrib.framework.get_global_step)
self.assertRaisesRegexp(
TypeError, "not scalar",
tf.contrib.framework.get_global_step, g)
def test_create_global_step(self):
self.assertEquals(None, tf.contrib.framework.get_global_step())
with tf.Graph().as_default() as g: