From 3198d7ef3063ac6fdbbf9ca4fdd3e059f6100ed0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 Apr 2016 14:24:16 -0800 Subject: [PATCH] Assert global_step is scalar. Change: 120276713 --- tensorflow/contrib/framework/python/ops/variables.py | 10 ++++++++++ .../contrib/framework/python/ops/variables_test.py | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index d05b2e9eca6..e1a7ce46ce6 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -34,6 +34,11 @@ __all__ = [ def assert_global_step(global_step_tensor): + """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`. + + Args: + global_step_tensor: `Tensor` to test. + """ if not (isinstance(global_step_tensor, variables.Variable) or isinstance(global_step_tensor, ops.Tensor)): raise TypeError('Existing "global_step" must be a Variable or Tensor.') @@ -43,6 +48,11 @@ def assert_global_step(global_step_tensor): 'Existing "global_step" does not have integer type: %s' % global_step_tensor.dtype) + if global_step_tensor.get_shape().ndims != 0: + raise TypeError( + 'Existing "global_step" is not scalar: %s' % + global_step_tensor.get_shape()) + # TODO(ptucker): Change supervisor to use this when it's migrated to core. def get_global_step(graph=None): diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 75e1cce1a49..af2e36c9ec4 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -57,6 +57,18 @@ class GlobalStepTest(tf.test.TestCase): TypeError, "does not have integer type", tf.contrib.framework.get_global_step, g) + def test_invalid_shape(self): + with tf.Graph().as_default() as g: + self.assertEquals(None, tf.contrib.framework.get_global_step()) + tf.Variable( + [0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP) + self.assertRaisesRegexp( + TypeError, "not scalar", + tf.contrib.framework.get_global_step) + self.assertRaisesRegexp( + TypeError, "not scalar", + tf.contrib.framework.get_global_step, g) + def test_create_global_step(self): self.assertEquals(None, tf.contrib.framework.get_global_step()) with tf.Graph().as_default() as g: