diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index 488bd2ebcdc..9f310bb2a65 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -15,9 +15,9 @@
 
 """Standard functions for creating slots.
 
-A slot is a `Variable` created with the same shape as a primary variable or
-`Tensor`. A slot is always scoped in the namespace of the primary object and
-typically has the same device and type.
+A slot is a `Variable` created with the same first m-dimension as a primary
+variable or `Tensor`. A slot is always scoped in the namespace of the primary
+object and typically has the same device and type.
 
 Slots are typically used as accumulators to track values associated with
 the primary object:
@@ -84,11 +84,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
     # remove "'linear//weight' + '/'" and ':0'.
     real_slot_name = slot.name[len(primary.op.name + "/"):-2]
     slice_info = primary._save_slice_info
-    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
-        slice_info.full_name + "/" + real_slot_name,
-        slice_info.full_shape[:],
-        slice_info.var_offset[:],
-        slice_info.var_shape[:]))
+    # support slot's shape not same as primary's shape
+    # example: primary's shape = [10, 20, 30], slot's shape =
+    # None, [], [10], [10, 20] or [10, 20, 30] is allowed
+    # slot's shape = None or [10, 20, 30], set slot's slice_info same as primary
+    # slot's shape = [], don't set slot's slice_info
+    # slot's shape = [10] or [10, 20], set slot's slice_info according to ndims
+    n = slot.shape.ndims
+    if n is None or n > 0:
+      slot._set_save_slice_info(
+          variables.Variable.SaveSliceInfo(
+              slice_info.full_name + "/" + real_slot_name,
+              slice_info.full_shape[:n], slice_info.var_offset[:n],
+              slice_info.var_shape[:n]))
   # pylint: enable=protected-access
   return slot
 
diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py
index ec2eec39324..2465afdae9c 100644
--- a/tensorflow/python/training/slot_creator_test.py
+++ b/tensorflow/python/training/slot_creator_test.py
@@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
@@ -134,6 +135,46 @@ class SlotCreatorTest(test.TestCase):
         slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
         self.assertEqual("scope/scope/var/slot", slot.op.name)
 
+  @test_util.run_deprecated_v1
+  def testCreateSlotFromFirstMDimensionVariable(self):
+    with self.test_session():
+      s = variables.Variable([1.0, 2.5], name="var")
+      p_v = variable_scope.get_variable(
+          "var",
+          shape=[2, 2],
+          partitioner=partitioned_variables.fixed_size_partitioner(2))
+      for i, v in enumerate(p_v):
+        slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
+        si = slot._save_slice_info
+
+        variables.global_variables_initializer().run()
+
+        self.assertEqual("var/part_%d/slot" % i, slot.op.name)
+        self.assertEqual([2], slot.get_shape().as_list())
+        self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
+        self.assertAllEqual([1.0, 2.5], slot.eval())
+        self.assertAllEqual([2], si.full_shape)
+        self.assertAllEqual([i], si.var_offset)
+        self.assertAllEqual([1], si.var_shape)
+
+  @test_util.run_deprecated_v1
+  def testCreateSlotFromScalarVariable(self):
+    with self.test_session():
+      s = variables.Variable(1.0, name="var")
+      p_v = variable_scope.get_variable(
+          "var",
+          shape=[2, 2],
+          partitioner=partitioned_variables.fixed_size_partitioner(2))
+      for i, v in enumerate(p_v):
+        slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
+
+        variables.global_variables_initializer().run()
+
+        self.assertEqual("var/part_%d/slot" % i, slot.op.name)
+        self.assertEqual([], slot.get_shape().as_list())
+        self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
+        self.assertAllEqual(1.0, slot.eval())
+
 
 if __name__ == "__main__":
   test.main()