From 829a2365229580074cff806e725bca9f2a34788c Mon Sep 17 00:00:00 2001 From: Wei Ho Date: Wed, 24 Aug 2016 15:26:40 -0800 Subject: [PATCH] Adds _PartitionInfo to variable initializer signature Change: 131229727 --- RELEASE.md | 3 + .../framework/python/ops/variables_test.py | 3 +- .../layers/python/layers/initializers.py | 3 +- .../partitioned_variables_test.py | 2 +- .../kernel_tests/variable_scope_test.py | 58 +++++++ tensorflow/python/ops/array_ops.py | 2 +- tensorflow/python/ops/init_ops.py | 54 ++++--- tensorflow/python/ops/variable_scope.py | 153 +++++++++++++++++- .../saver_large_partitioned_variable_test.py | 3 +- 9 files changed, 249 insertions(+), 32 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index f1a8859c356..abb99a2ebb0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -10,6 +10,9 @@ * Int32 elements of list(type) arguments are no longer placed in host memory by default. If necessary, a list(type) argument to a kernel can be placed in host memory using a HostMemory annotation. +* uniform_unit_scaling_initializer() no longer takes a full_shape arg, instead + relying on the partition info passed to the initializer function when it's + called. # Release 0.10.0 diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 11bd7231798..d6e1d03a560 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -472,7 +472,8 @@ class ModelVariablesTest(tf.test.TestCase): def testInitializedVariableValue(self): with self.test_session() as sess: - a = tf.contrib.framework.model_variable('a', [5], initializer=tf.ones) + a = tf.contrib.framework.model_variable( + 'a', [5], initializer=tf.ones_initializer) sess.run(tf.initialize_all_variables()) self.assertAllEqual(a.eval(), [1]*5) diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index 1786b71dcf7..fef925ca7e3 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -105,7 +105,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, raise TypeError('Cannot create initializer for non-floating point type.') if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']: raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode) - def _initializer(shape, dtype=dtype): + + def _initializer(shape, dtype=dtype, partition_info=None): """Initializer function.""" if not dtype.is_floating: raise TypeError('Cannot create initializer for non-floating point type.') diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py index 77feacad195..efa7322f653 100644 --- a/tensorflow/python/kernel_tests/partitioned_variables_test.py +++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py @@ -219,7 +219,7 @@ class PartitionerCreatorsTest(tf.test.TestCase): expected_partitions=[4, 1, 1]) -def _IotaInitializer(shape, dtype=tf.float32): +def _IotaInitializer(shape, dtype=tf.float32, partition_info=None): assert dtype == tf.float32 if len(shape) == 1: return range(shape[0]) diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index ff9716438ac..bce249d8c5e 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -751,5 
+751,63 @@ class VariableScopeWithCustomGetterTest(tf.test.TestCase): np_vars, np_v = sess.run([true_vars, v]) self.assertAllClose(np_v, sum(np_vars)) + +class PartitionInfoTest(tf.test.TestCase): + + def testConstructorChecks(self): + # Invalid arg types. + with self.assertRaises(TypeError): + variable_scope._PartitionInfo(full_shape=None, var_offset=[0, 1]) + with self.assertRaises(TypeError): + variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=None) + with self.assertRaises(TypeError): + variable_scope._PartitionInfo(full_shape="foo", var_offset=[0, 1]) + with self.assertRaises(TypeError): + variable_scope._PartitionInfo(full_shape=[0, 1], var_offset="foo") + + # full_shape and var_offset must have same length. + with self.assertRaises(ValueError): + variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=[0]) + # Offset must always be less than shape. + with self.assertRaises(ValueError): + variable_scope._PartitionInfo(full_shape=[1, 1], var_offset=[0, 1]) + + def testSingleOffset(self): + partition_info = variable_scope._PartitionInfo( + full_shape=[9, 3], var_offset=[4, 0]) + self.assertEqual(4, partition_info.single_offset([1, 3])) + + # Tests when the variable isn't partitioned at all. + partition_info = variable_scope._PartitionInfo( + full_shape=[9, 3], var_offset=[0, 0]) + self.assertEqual(0, partition_info.single_offset([9, 3])) + + def testSingleSliceDim(self): + partition_info = variable_scope._PartitionInfo( + full_shape=[9, 3], var_offset=[4, 0]) + # Invalid shape. + with self.assertRaises(TypeError): + partition_info.single_slice_dim(None) + + # Rank of shape differs from full_shape. + with self.assertRaises(ValueError): + partition_info.single_slice_dim([1, 2, 3]) + + # Shape is too large given var_offset (4+6 > 9). + with self.assertRaises(ValueError): + partition_info.single_slice_dim([6, 3]) + + # Multiple possible slice dim from shape. + with self.assertRaises(ValueError): + partition_info.single_slice_dim([1, 1]) + + partition_info = variable_scope._PartitionInfo( + full_shape=[9, 3], var_offset=[0, 0]) + self.assertEqual(1, partition_info.single_slice_dim([9, 2])) + partition_info = variable_scope._PartitionInfo( + full_shape=[9, 3], var_offset=[4, 0]) + self.assertEqual(0, partition_info.single_slice_dim([2, 3])) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c556e8e213c..16267fd4818 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -254,7 +254,7 @@ def rank_internal(input, name=None, optimize=True): # DEPRECATED use init_ops.zeros_initializer # TODO(irving) Move it to init_ops.py -def zeros_initializer(shape, dtype=dtypes.float32): +def zeros_initializer(shape, dtype=dtypes.float32, partition_info=None): """An adaptor for zeros() to match the Initializer spec.""" return zeros(shape, dtype) diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py index 77a8ced8371..24699b868bc 100644 --- a/tensorflow/python/ops/init_ops.py +++ b/tensorflow/python/ops/init_ops.py @@ -13,7 +13,22 @@ # limitations under the License. # ============================================================================== -"""Operations often used for initializing tensors.""" +"""Operations often used for initializing tensors. 
+ +All variable initializers returned by functions in this file should have the +following signature: + +def _initializer(shape, dtype=dtypes.float32, partition_info=None): + Args: + shape: List of `int` representing the shape of the output `Tensor`. Some + initializers may also be able to accept a `Tensor`. + dtype: (Optional) Type of the output `Tensor`. + partition_info: (Optional) variable_scope._PartitionInfo object holding + additional information about how the variable is partitioned. May be + `None` if the variable is not partitioned. + Returns: + A `Tensor` of type `dtype` and `shape`. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -50,7 +65,7 @@ def _assert_float_dtype(dtype): zeros_initializer = array_ops.zeros_initializer -def ones_initializer(shape, dtype=dtypes.float32): +def ones_initializer(shape, dtype=dtypes.float32, partition_info=None): """An adaptor for ones() to match the Initializer spec.""" return array_ops.ones(shape, dtype) @@ -125,7 +140,7 @@ def constant_initializer(value=0, dtype=dtypes.float32): ValueError: Too many elements provided. Needed at most 6, but received 8 ``` """ - def _initializer(shape, dtype=dtype): + def _initializer(shape, dtype=dtype, partition_info=None): return constant_op.constant(value, dtype=dtype, shape=shape) return _initializer @@ -147,7 +162,7 @@ def random_uniform_initializer(minval=0, maxval=None, seed=None, Returns: An initializer that generates tensors with a uniform distribution. """ - def _initializer(shape, dtype=dtype): + def _initializer(shape, dtype=dtype, partition_info=None): return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed) return _initializer @@ -172,7 +187,8 @@ def random_normal_initializer(mean=0.0, stddev=1.0, seed=None, Raises: ValueError: if `dtype` is not a floating point type. """ - def _initializer(shape, dtype=_assert_float_dtype(dtype)): + def _initializer(shape, dtype=_assert_float_dtype(dtype), + partition_info=None): return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed) return _initializer @@ -203,13 +219,16 @@ def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None, Raises: ValueError: if `dtype` is not a floating point type. """ - def _initializer(shape, dtype=_assert_float_dtype(dtype)): + def _initializer(shape, dtype=_assert_float_dtype(dtype), + partition_info=None): return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed) + return _initializer -def uniform_unit_scaling_initializer(factor=1.0, seed=None, - dtype=dtypes.float32, full_shape=None): +def uniform_unit_scaling_initializer(factor=1.0, + seed=None, + dtype=dtypes.float32): """Returns an initializer that generates tensors without scaling variance. When initializing a deep network, it is in principle advantageous to keep @@ -228,21 +247,12 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None, and the calculation of constants. In section 2.3 there, the constants were numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15. - If the shape tuple `full_shape` is provided, the scale will be calculated from - this predefined shape. This is useful when a `Variable` is being partitioned - across several shards, and each shard has a smaller shape than the whole. - Since the shards are usually concatenated when used, the scale should be - based on the shape of the whole. - Args: factor: Float. A multiplicative factor by which the values will be scaled. seed: A Python integer. 
Used to create random seeds. See [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed) for behavior. dtype: The data type. Only floating point types are supported. - full_shape: Tuple or list of integers. The shape used for calculating - scale normalization (instead of the shape passed at creation time). - Useful when creating sharded variables via partitioning. Returns: An initializer that generates tensors with unit variance. @@ -250,8 +260,12 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None, Raises: ValueError: if `dtype` is not a floating point type. """ - def _initializer(shape, dtype=_assert_float_dtype(dtype)): - scale_shape = full_shape if full_shape is not None else shape + def _initializer(shape, dtype=_assert_float_dtype(dtype), + partition_info=None): + scale_shape = shape + if partition_info is not None: + scale_shape = partition_info.full_shape + input_size = 1.0 # Estimating input size is not possible to do perfectly, but we try. # The estimate, obtained by multiplying all dimensions but the last one, @@ -319,7 +333,7 @@ class _RandomWalkInitializer(object): self._nonlinearity = nonlinearity self._seed = seed - def __call__(self, shape, dtype=dtypes.float32): + def __call__(self, shape, dtype=dtypes.float32, partition_info=None): """Generate a tensor used to initialize a variable.""" return random_ops._random_walk(shape, self._nonlinearity, dtype, seed=self._seed) diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index c4c51aa443e..217556df3c7 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -39,6 +39,133 @@ __all__ = ["VariableScope", "get_variable_scope", "no_regularizer"] +class _PartitionInfo(object): + """Holds partition info used by initializer functions. + """ + + def __init__(self, full_shape, var_offset): + """Constructor. + + Args: + full_shape: Tuple or list of `int` indicating the full combined shape + of the partitioned variables. + var_offset: Tuple or list of `int` specifying offset of this partition + with respect to the full variable for each dimension. + + Raises: + TypeError: If `full_shape` or `var_offset` is not a sequence. + ValueError: If `full_shape` or `var_offset` differ in length. If + `var_offset` exceeds `full_shape` in any dimension. 
+ """ + if not isinstance(full_shape, collections_lib.Sequence) or isinstance( + full_shape, six.string_types): + raise TypeError( + "`full_shape` must be a sequence (like tuple or list) instead of " + + type(full_shape).__name__) + + if not isinstance(var_offset, collections_lib.Sequence) or isinstance( + var_offset, six.string_types): + raise TypeError( + "`var_offset` must be a sequence (like tuple or list) instead of " + + type(var_offset).__name__) + + if len(var_offset) != len(full_shape): + raise ValueError( + "Expected equal length, but `var_offset` is of length {} while " + "full_shape is of length {}.".format( + len(var_offset), len(full_shape))) + + for i in xrange(len(full_shape)): + offset = var_offset[i] + shape = full_shape[i] + if offset < 0 or offset >= shape: + raise ValueError( + "Expected 0 <= offset < shape but found offset={}, shape={} for " + "var_offset={}, full_shape={}".format(offset, shape, var_offset, + full_shape)) + + self._full_shape = full_shape + self._var_offset = var_offset + + @property + def full_shape(self): + return self._full_shape + + @property + def var_offset(self): + return self._var_offset + + def single_offset(self, shape): + """Returns the offset when the variable is partitioned in at most one dim. + + Args: + shape: Tuple or list of `int` indicating the shape of one specific + variable partition. + + Returns: + `int` representing the offset in the dimension along which the variable is + partitioned. Returns 0 if the variable is not being partitioned. + + Raises: + ValueError: Depending on self.single_slice_dim(). + """ + + single_slice_dim = self.single_slice_dim(shape) + # If this variable is not being partitioned at all, single_slice_dim() could + # return None. + if single_slice_dim is None: + return 0 + return self.var_offset[single_slice_dim] + + def single_slice_dim(self, shape): + """Returns the slice dim when the variable is partitioned only in one dim. + + Args: + shape: Tuple or list of `int` indicating the shape of one specific + variable partition. + + Returns: + `int` representing the dimension that the variable is partitioned in, or + `None` if the variable doesn't seem to be partitioned at all. + + Raises: + TypeError: If `shape` is not a sequence. + ValueError: If `shape` is not the same length as `self.full_shape`. If + the variable is partitioned in more than one dimension. 
+ """ + if not isinstance(shape, collections_lib.Sequence) or isinstance( + shape, six.string_types): + raise TypeError( + "`shape` must be a sequence (like tuple or list) instead of " + + type(shape).__name__) + + if len(shape) != len(self.full_shape): + raise ValueError( + "Expected equal length, but received shape={} of length {} while " + "self.full_shape={} is of length {}.".format(shape, len( + shape), self.full_shape, len(self.full_shape))) + + for i in xrange(len(shape)): + if self.var_offset[i] + shape[i] > self.full_shape[i]: + raise ValueError( + "With self.var_offset={}, a partition of shape={} would exceed " + "self.full_shape={} in dimension {}.".format( + self.var_offset, shape, self.full_shape, i)) + + slice_dim = None + for i in xrange(len(shape)): + if shape[i] == self.full_shape[i]: + continue + if slice_dim is not None: + raise ValueError( + "Cannot use single_slice_dim() with shape={} and " + "self.full_shape={} since slice dim could be either dimension {} " + "or {}.".format(shape, self.full_shape, i, slice_dim)) + slice_dim = i + + return slice_dim + + class _VariableStore(object): """Variable store that carries a number of named Variables. @@ -390,6 +517,8 @@ class _VariableStore(object): for i in xrange(num_slices): var_shape = slice_shape[:] var_offset = slice_offset[:] + partition_info = _PartitionInfo( + full_shape=shape.as_list(), var_offset=var_offset) if i < num_slices_with_excess: var_shape[slice_dim] += 1 slice_offset[slice_dim] += var_shape[slice_dim] @@ -397,8 +526,7 @@ class _VariableStore(object): var_full_name = "%s/part_%d" % (name, i) with ops.name_scope(var_full_name + "/PartitionedInitializer"): if initializer is None: - init = init_ops.uniform_unit_scaling_initializer( - full_shape=shape.as_list()) + init = init_ops.uniform_unit_scaling_initializer() init_shape = var_shape elif callable(initializer): init = initializer @@ -419,6 +547,7 @@ class _VariableStore(object): shape=init_shape, dtype=dtype, initializer=init, + partition_info=partition_info, regularizer=regularizer, reuse=reuse, trainable=trainable, @@ -443,10 +572,18 @@ class _VariableStore(object): self._partitioned_vars[name] = partitioned_var return partitioned_var - def _get_single_variable(self, name, shape=None, dtype=dtypes.float32, - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, - caching_device=None, validate_shape=True): + def _get_single_variable(self, + name, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + partition_info=None, + reuse=None, + trainable=True, + collections=None, + caching_device=None, + validate_shape=True): """Get or create a single Variable (e.g. a shard or entire variable). See the documentation of get_variable above (ignore partitioning components) @@ -458,6 +595,7 @@ class _VariableStore(object): dtype: see get_variable. initializer: see get_variable. regularizer: see get_variable. + partition_info: _PartitionInfo object. reuse: see get_variable. trainable: see get_variable. collections: see get_variable. @@ -523,7 +661,8 @@ class _VariableStore(object): init_val = initializer variable_dtype = None else: - init_val = lambda: initializer(shape.as_list(), dtype=dtype) + init_val = lambda: initializer( + shape.as_list(), dtype=dtype, partition_info=partition_info) variable_dtype = dtype.base_dtype # Create the variable. 
diff --git a/tensorflow/python/training/saver_large_partitioned_variable_test.py b/tensorflow/python/training/saver_large_partitioned_variable_test.py index bc071eb270e..4c0526cc42b 100644 --- a/tensorflow/python/training/saver_large_partitioned_variable_test.py +++ b/tensorflow/python/training/saver_large_partitioned_variable_test.py @@ -34,7 +34,8 @@ class SaverLargePartitionedVariableTest(tf.test.TestCase): with tf.device("/cpu:0"): # Create a partitioned variable which is larger than int32 size but # split into smaller sized variables. - init = lambda shape, dtype: tf.constant(True, dtype, shape) + init = lambda shape, dtype, partition_info: tf.constant( + True, dtype, shape) partitioned_var = tf.create_partitioned_variables( [1 << 31], [4], init, dtype=tf.bool, name=var_name) tf.initialize_all_variables().run()
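As a usage note for downstream code: a callable initializer handed to the partitioned creation paths now needs to accept the extra `partition_info` argument, either with a default of `None` (as in the updated `_IotaInitializer`) or explicitly (as in the saver test lambda above). Below is a minimal sketch under that assumption; the shard counts, the variable name, and the concatenated values shown in the comment are illustrative rather than taken from the patch.

    import tensorflow as tf

    def _offset_fill(shape, dtype=tf.float32, partition_info=None):
      # Fill each shard with its offset along the partitioned dimension so the
      # concatenated value records where each shard sits in the full variable.
      offset = 0 if partition_info is None else partition_info.single_offset(shape)
      return tf.constant(float(offset), dtype=dtype, shape=shape)

    # An [8] variable split into 4 shards of shape [2]; each initializer call
    # should receive full_shape=[8] plus that shard's var_offset.
    shards = tf.create_partitioned_variables([8], [4], _offset_fill, name="sharded")

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      print(sess.run(tf.concat(0, shards)))  # expected roughly: [0. 0. 2. 2. 4. 4. 6. 6.]

An initializer written this way keeps working when the variable is not partitioned at all, since `partition_info` then arrives as `None` and the shard shape is the full shape.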