Merge pull request #30018 from candyzone:master
PiperOrigin-RevId: 267492477
This commit is contained in:
commit
8beef3a4c1
@ -15,9 +15,9 @@
|
||||
|
||||
"""Standard functions for creating slots.
|
||||
|
||||
A slot is a `Variable` created with the same shape as a primary variable or
|
||||
`Tensor`. A slot is always scoped in the namespace of the primary object and
|
||||
typically has the same device and type.
|
||||
A slot is a `Variable` created with the same first m-dimension as a primary
|
||||
variable or `Tensor`. A slot is always scoped in the namespace of the primary
|
||||
object and typically has the same device and type.
|
||||
|
||||
Slots are typically used as accumulators to track values associated with
|
||||
the primary object:
|
||||
@ -84,11 +84,19 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
|
||||
# remove "'linear//weight' + '/'" and ':0'.
|
||||
real_slot_name = slot.name[len(primary.op.name + "/"):-2]
|
||||
slice_info = primary._save_slice_info
|
||||
slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
|
||||
slice_info.full_name + "/" + real_slot_name,
|
||||
slice_info.full_shape[:],
|
||||
slice_info.var_offset[:],
|
||||
slice_info.var_shape[:]))
|
||||
# support slot's shape not same as primary's shape
|
||||
# example: primary's shape = [10, 20, 30], slot's shape =
|
||||
# None, [], [10], [10, 20] or [10, 20, 30] is allowed
|
||||
# slot's shape = None or [10, 20, 30], set slot's slice_info same as primary
|
||||
# slot's shape = [], don't set slot's slice_info
|
||||
# slot's shape = [10] or [10, 20], set slot's slice_info according to ndims
|
||||
n = slot.shape.ndims
|
||||
if n is None or n > 0:
|
||||
slot._set_save_slice_info(
|
||||
variables.Variable.SaveSliceInfo(
|
||||
slice_info.full_name + "/" + real_slot_name,
|
||||
slice_info.full_shape[:n], slice_info.var_offset[:n],
|
||||
slice_info.var_shape[:n]))
|
||||
# pylint: enable=protected-access
|
||||
return slot
|
||||
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import partitioned_variables
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
@ -134,6 +135,46 @@ class SlotCreatorTest(test.TestCase):
|
||||
slot = slot_creator.create_slot(v, v.initialized_value(), name="slot")
|
||||
self.assertEqual("scope/scope/var/slot", slot.op.name)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCreateSlotFromFirstMDimensionVariable(self):
|
||||
with self.test_session():
|
||||
s = variables.Variable([1.0, 2.5], name="var")
|
||||
p_v = variable_scope.get_variable(
|
||||
"var",
|
||||
shape=[2, 2],
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
for i, v in enumerate(p_v):
|
||||
slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
|
||||
si = slot._save_slice_info
|
||||
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
self.assertEqual("var/part_%d/slot" % i, slot.op.name)
|
||||
self.assertEqual([2], slot.get_shape().as_list())
|
||||
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
|
||||
self.assertAllEqual([1.0, 2.5], slot.eval())
|
||||
self.assertAllEqual([2], si.full_shape)
|
||||
self.assertAllEqual([i], si.var_offset)
|
||||
self.assertAllEqual([1], si.var_shape)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testCreateSlotFromScalarVariable(self):
|
||||
with self.test_session():
|
||||
s = variables.Variable(1.0, name="var")
|
||||
p_v = variable_scope.get_variable(
|
||||
"var",
|
||||
shape=[2, 2],
|
||||
partitioner=partitioned_variables.fixed_size_partitioner(2))
|
||||
for i, v in enumerate(p_v):
|
||||
slot = slot_creator.create_slot(v, s.initialized_value(), name="slot")
|
||||
|
||||
variables.global_variables_initializer().run()
|
||||
|
||||
self.assertEqual("var/part_%d/slot" % i, slot.op.name)
|
||||
self.assertEqual([], slot.get_shape().as_list())
|
||||
self.assertEqual(dtypes.float32, slot.dtype.base_dtype)
|
||||
self.assertAllEqual(1.0, slot.eval())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user