Merge pull request #30018 from candyzone:master

PiperOrigin-RevId: 267492477
This commit is contained in:
TensorFlower Gardener 2019-09-05 17:47:46 -07:00
commit 8beef3a4c1
2 changed files with 57 additions and 8 deletions

View File

@ -15,9 +15,9 @@
"""Standard functions for creating slots.
A slot is a `Variable` created with the same shape as a primary variable or
`Tensor`. A slot is always scoped in the namespace of the primary object and
typically has the same device and type.
A slot is a `Variable` created with the same first m-dimension as a primary
variable or `Tensor`. A slot is always scoped in the namespace of the primary
object and typically has the same device and type.
Slots are typically used as accumulators to track values associated with
the primary object:
@ -84,11 +84,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
# remove "'linear//weight' + '/'" and ':0'.
real_slot_name = slot.name[len(primary.op.name + "/"):-2]
slice_info = primary._save_slice_info
slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
slice_info.full_name + "/" + real_slot_name,
slice_info.full_shape[:],
slice_info.var_offset[:],
slice_info.var_shape[:]))
# support slot's shape not same as primary's shape
# example: primary's shape = [10, 20, 30], slot's shape =
# None, [], [10], [10, 20] or [10, 20, 30] is allowed
# slot's shape = None or [10, 20, 30], set slot's slice_info same as primary
# slot's shape = [], don't set slot's slice_info
# slot's shape = [10] or [10, 20], set slot's slice_info according to ndims
n = slot.shape.ndims
if n is None or n > 0:
slot._set_save_slice_info(
variables.Variable.SaveSliceInfo(
slice_info.full_name + "/" + real_slot_name,
slice_info.full_shape[:n], slice_info.var_offset[:n],
slice_info.var_shape[:n]))
# pylint: enable=protected-access
return slot

View File

@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@ -134,6 +135,46 @@ class SlotCreatorTest(test.TestCase):
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
self.assertEqual("scope/scope/var/slot", slot.op.name)
@test_util.run_deprecated_v1
def testCreateSlotFromFirstMDimensionVariable(self):
with self.test_session():
s = variables.Variable([1.0, 2.5], name="var")
p_v = variable_scope.get_variable(
"var",
shape=[2, 2],
partitioner=partitioned_variables.fixed_size_partitioner(2))
for i, v in enumerate(p_v):
slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
si = slot._save_slice_info
variables.global_variables_initializer().run()
self.assertEqual("var/part_%d/slot" % i, slot.op.name)
self.assertEqual([2], slot.get_shape().as_list())
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
self.assertAllEqual([1.0, 2.5], slot.eval())
self.assertAllEqual([2], si.full_shape)
self.assertAllEqual([i], si.var_offset)
self.assertAllEqual([1], si.var_shape)
@test_util.run_deprecated_v1
def testCreateSlotFromScalarVariable(self):
with self.test_session():
s = variables.Variable(1.0, name="var")
p_v = variable_scope.get_variable(
"var",
shape=[2, 2],
partitioner=partitioned_variables.fixed_size_partitioner(2))
for i, v in enumerate(p_v):
slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
variables.global_variables_initializer().run()
self.assertEqual("var/part_%d/slot" % i, slot.op.name)
self.assertEqual([], slot.get_shape().as_list())
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
self.assertAllEqual(1.0, slot.eval())
if __name__ == "__main__":
test.main()