Fix: allow a slot and its primary to have different shapes

This commit is contained in:
candy.dc 2019-06-21 14:26:45 +08:00
parent a59ad83d06
commit 2a72e99930
2 changed files with 52 additions and 8 deletions

View File

@ -15,9 +15,9 @@
"""Standard functions for creating slots.
A slot is a `Variable` created with the same shape as a primary variable or
`Tensor`. A slot is always scoped in the namespace of the primary object and
typically has the same device and type.
A slot is a `Variable` whose shape matches the first m dimensions of a primary
variable or `Tensor`. A slot is always scoped in the namespace of the primary
object and typically has the same device and type.
Slots are typically used as accumulators to track values associated with
the primary object:
@ -85,11 +85,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
# remove "'linear//weight' + '/'" and ':0'.
real_slot_name = slot.name[len(primary.op.name + "/"):-2]
slice_info = primary._save_slice_info
slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
slice_info.full_name + "/" + real_slot_name,
slice_info.full_shape[:],
slice_info.var_offset[:],
slice_info.var_shape[:]))
# Support slots whose shape differs from the primary's shape.
# Example: for a primary of shape [10, 20, 30], the slot's shape may be
# None, [], [10], [10, 20] or [10, 20, 30]:
#   - None or [10, 20, 30]: the slot's slice_info is the same as the primary's.
#   - []: the slot's slice_info is not set.
#   - [10] or [10, 20]: the slot's slice_info is truncated to the slot's ndims.
n = slot.shape.ndims
if n is None or n > 0:
slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
slice_info.full_name + "/" + real_slot_name,
slice_info.full_shape[:n],
slice_info.var_offset[:n],
slice_info.var_shape[:n]))
# pylint: enable=protected-access
return slot

View File

@ -23,6 +23,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@ -134,6 +136,40 @@ class SlotCreatorTest(test.TestCase):
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
self.assertEqual("scope/scope/var/slot", slot.op.name)
def testCreateSlotFromFirstMDimensionVariable(self):
  """Checks slots built from a rank-1 value on parts of a partitioned variable.

  A [2, 2] variable is created with `fixed_size_partitioner(2)`. For every
  part, a slot is created from a 2-element vector, and the test verifies the
  slot's op name, static shape, dtype and value, plus the slot's save-slice
  info: `full_shape == [2]`, `var_offset == [i]` and `var_shape == [1]`, where
  `i` is the index of the part.
  """
  with self.test_session():
    slot_init = variables.Variable([1.0, 2.5], name="var")
    split_in_two = partitioned_variables.fixed_size_partitioner(2)
    partitions = variable_scope.get_variable(
        "var", shape=[2, 2], partitioner=split_in_two)

    for index, part in enumerate(partitions):
      slot = slot_creator.create_slot(
          part, slot_init.initialized_value(), name="slot")
      slice_info = slot._save_slice_info  # pylint: disable=protected-access

      # The slot variable must be initialized before it can be evaluated.
      variables.global_variables_initializer().run()

      # Naming, static shape, dtype and value of the slot itself.
      self.assertEqual("var/part_%d/slot" % index, slot.op.name)
      self.assertEqual([2], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual([1.0, 2.5], slot.eval())

      # Slice info recorded on the slot for this part.
      self.assertAllEqual([2], slice_info.full_shape)
      self.assertAllEqual([index], slice_info.var_offset)
      self.assertAllEqual([1], slice_info.var_shape)
def testCreateSlotFromScalarVariable(self):
  """Checks slots built from a scalar value on parts of a partitioned variable.

  A [2, 2] variable is created with `fixed_size_partitioner(2)`. For every
  part, a slot is created from a scalar, and the test verifies the slot's op
  name, its empty static shape, its dtype and its value.
  """
  with self.test_session():
    scalar_init = variables.Variable(1.0, name="var")
    split_in_two = partitioned_variables.fixed_size_partitioner(2)
    partitions = variable_scope.get_variable(
        "var", shape=[2, 2], partitioner=split_in_two)

    for index, part in enumerate(partitions):
      slot = slot_creator.create_slot(
          part, scalar_init.initialized_value(), name="slot")

      # The slot variable must be initialized before it can be evaluated.
      variables.global_variables_initializer().run()

      self.assertEqual("var/part_%d/slot" % index, slot.op.name)
      self.assertEqual([], slot.get_shape().as_list())
      self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
      self.assertAllEqual(1.0, slot.eval())
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
  test.main()