diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 488bd2ebcdc..9f310bb2a65 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -15,9 +15,9 @@ """Standard functions for creating slots. -A slot is a `Variable` created with the same shape as a primary variable or -`Tensor`. A slot is always scoped in the namespace of the primary object and -typically has the same device and type. +A slot is a `Variable` whose shape matches the first m dimensions of a primary +variable or `Tensor`. A slot is always scoped in the namespace of the primary +object and typically has the same device and type. Slots are typically used as accumulators to track values associated with the primary object: @@ -84,11 +84,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): # remove "'linear//weight' + '/'" and ':0'. real_slot_name = slot.name[len(primary.op.name + "/"):-2] slice_info = primary._save_slice_info - slot._set_save_slice_info(variables.Variable.SaveSliceInfo( - slice_info.full_name + "/" + real_slot_name, - slice_info.full_shape[:], - slice_info.var_offset[:], - slice_info.var_shape[:])) + # The slot's shape may be only a leading prefix of the primary's shape. + # For example, with a primary of shape [10, 20, 30] the slot's shape may + # be None (ndims is None), [], [10], [10, 20] or [10, 20, 30]: + # - None or [10, 20, 30]: slice_info is the same as the primary's. + # - []: no slice_info is set on the slot. + # - [10] or [10, 20]: slice_info keeps only the first `ndims` entries. + n = slot.shape.ndims + if n is None or n > 0: + slot._set_save_slice_info( + variables.Variable.SaveSliceInfo( + slice_info.full_name + "/" + real_slot_name, + slice_info.full_shape[:n], slice_info.var_offset[:n], + slice_info.var_shape[:n])) # pylint: enable=protected-access return slot diff --git a/tensorflow/python/training/slot_creator_test.py b/tensorflow/python/training/slot_creator_test.py 
index ec2eec39324..2465afdae9c 100644 --- a/tensorflow/python/training/slot_creator_test.py +++ b/tensorflow/python/training/slot_creator_test.py @@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -134,6 +135,46 @@ class SlotCreatorTest(test.TestCase): slot = slot_creator.create_slot(v, v.initialized_value(), name="slot") self.assertEqual("scope/scope/var/slot", slot.op.name) + @test_util.run_deprecated_v1 + def testCreateSlotFromFirstMDimensionVariable(self): + with self.cached_session(): + s = variables.Variable([1.0, 2.5], name="var") + p_v = variable_scope.get_variable( + "var", + shape=[2, 2], + partitioner=partitioned_variables.fixed_size_partitioner(2)) + for i, v in enumerate(p_v): + slot = slot_creator.create_slot(v, s.initialized_value(), name="slot") + si = slot._save_slice_info + + variables.global_variables_initializer().run() + + self.assertEqual("var/part_%d/slot" % i, slot.op.name) + self.assertEqual([2], slot.get_shape().as_list()) + self.assertEqual(dtypes.float32, slot.dtype.base_dtype) + self.assertAllEqual([1.0, 2.5], slot.eval()) + self.assertAllEqual([2], si.full_shape) + self.assertAllEqual([i], si.var_offset) + self.assertAllEqual([1], si.var_shape) + + @test_util.run_deprecated_v1 + def testCreateSlotFromScalarVariable(self): + with self.cached_session(): + s = variables.Variable(1.0, name="var") + p_v = variable_scope.get_variable( + "var", + shape=[2, 2], + partitioner=partitioned_variables.fixed_size_partitioner(2)) + for i, v in enumerate(p_v): + slot = slot_creator.create_slot(v, s.initialized_value(), name="slot") + + variables.global_variables_initializer().run() + + 
self.assertEqual("var/part_%d/slot" % i, slot.op.name) + self.assertEqual([], slot.get_shape().as_list()) + self.assertEqual(dtypes.float32, slot.dtype.base_dtype) + self.assertAllEqual(1.0, slot.eval()) + if __name__ == "__main__": test.main()