diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index d05b2e9eca6..e1a7ce46ce6 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -34,6 +34,11 @@ __all__ = [
 
 
 def assert_global_step(global_step_tensor):
+  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.
+
+  Args:
+    global_step_tensor: `Tensor` to test.
+  """
   if not (isinstance(global_step_tensor, variables.Variable) or
           isinstance(global_step_tensor, ops.Tensor)):
     raise TypeError('Existing "global_step" must be a Variable or Tensor.')
@@ -43,6 +48,11 @@ def assert_global_step(global_step_tensor):
         'Existing "global_step" does not have integer type: %s' %
         global_step_tensor.dtype)
 
+  if global_step_tensor.get_shape().ndims != 0:
+    raise TypeError(
+        'Existing "global_step" is not scalar: %s' %
+        global_step_tensor.get_shape())
+
 
 # TODO(ptucker): Change supervisor to use this when it's migrated to core.
 def get_global_step(graph=None):
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 75e1cce1a49..af2e36c9ec4 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -57,6 +57,18 @@ class GlobalStepTest(tf.test.TestCase):
         TypeError, "does not have integer type",
         tf.contrib.framework.get_global_step, g)
 
+  def test_invalid_shape(self):
+    with tf.Graph().as_default() as g:
+      self.assertEquals(None, tf.contrib.framework.get_global_step())
+      tf.Variable(
+          [0], trainable=False, dtype=tf.int32, name=tf.GraphKeys.GLOBAL_STEP)
+      self.assertRaisesRegexp(
+          TypeError, "not scalar",
+          tf.contrib.framework.get_global_step)
+    self.assertRaisesRegexp(
+        TypeError, "not scalar",
+        tf.contrib.framework.get_global_step, g)
+
   def test_create_global_step(self):
     self.assertEquals(None, tf.contrib.framework.get_global_step())
     with tf.Graph().as_default() as g: