Adds _PartitionInfo to variable initializer signature
Change: 131229727
parent 18e2e13ff2
commit 829a236522
@@ -10,6 +10,9 @@
 * Int32 elements of list(type) arguments are no longer placed in host memory by
   default. If necessary, a list(type) argument to a kernel can be placed in host
   memory using a HostMemory annotation.
+* uniform_unit_scaling_initializer() no longer takes a full_shape arg, instead
+  relying on the partition info passed to the initializer function when it's
+  called.
 
 # Release 0.10.0
 
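The release note above describes the new contract: instead of a full_shape argument on the initializer factory, the scaling decision is made from the partition info handed to the initializer when it is called. A minimal sketch of a user-defined initializer written against that contract (the name my_scaling_initializer and the uniform-scaling math are illustrative, not part of this commit):

import math

import tensorflow as tf


def my_scaling_initializer(factor=1.0, dtype=tf.float32):
  """Hypothetical factory returning an initializer with the new signature."""

  def _initializer(shape, dtype=dtype, partition_info=None):
    # If this call is for one shard of a partitioned variable, scale by the
    # combined shape of all shards rather than by this shard's shape.
    scale_shape = shape if partition_info is None else partition_info.full_shape
    input_size = 1.0
    for dim in scale_shape[:-1]:
      input_size *= float(dim)
    max_val = math.sqrt(3.0 / input_size) * factor
    return tf.random_uniform(shape, -max_val, max_val, dtype)

  return _initializer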
@@ -472,7 +472,8 @@ class ModelVariablesTest(tf.test.TestCase):
 
   def testInitializedVariableValue(self):
     with self.test_session() as sess:
-      a = tf.contrib.framework.model_variable('a', [5], initializer=tf.ones)
+      a = tf.contrib.framework.model_variable(
+          'a', [5], initializer=tf.ones_initializer)
       sess.run(tf.initialize_all_variables())
       self.assertAllEqual(a.eval(), [1]*5)
 
@@ -105,7 +105,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
     raise TypeError('Cannot create initializer for non-floating point type.')
   if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
     raise TypeError('Unknow mode %s [FAN_IN, FAN_OUT, FAN_AVG]', mode)
-  def _initializer(shape, dtype=dtype):
+
+  def _initializer(shape, dtype=dtype, partition_info=None):
     """Initializer function."""
     if not dtype.is_floating:
       raise TypeError('Cannot create initializer for non-floating point type.')
@@ -219,7 +219,7 @@ class PartitionerCreatorsTest(tf.test.TestCase):
         expected_partitions=[4, 1, 1])
 
 
-def _IotaInitializer(shape, dtype=tf.float32):
+def _IotaInitializer(shape, dtype=tf.float32, partition_info=None):
   assert dtype == tf.float32
   if len(shape) == 1:
     return range(shape[0])
@@ -751,5 +751,63 @@ class VariableScopeWithCustomGetterTest(tf.test.TestCase):
       np_vars, np_v = sess.run([true_vars, v])
       self.assertAllClose(np_v, sum(np_vars))
 
+
+class PartitionInfoTest(tf.test.TestCase):
+
+  def testConstructorChecks(self):
+    # Invalid arg types.
+    with self.assertRaises(TypeError):
+      variable_scope._PartitionInfo(full_shape=None, var_offset=[0, 1])
+    with self.assertRaises(TypeError):
+      variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=None)
+    with self.assertRaises(TypeError):
+      variable_scope._PartitionInfo(full_shape="foo", var_offset=[0, 1])
+    with self.assertRaises(TypeError):
+      variable_scope._PartitionInfo(full_shape=[0, 1], var_offset="foo")
+
+    # full_shape and var_offset must have same length.
+    with self.assertRaises(ValueError):
+      variable_scope._PartitionInfo(full_shape=[0, 1], var_offset=[0])
+    # Offset must always be less than shape.
+    with self.assertRaises(ValueError):
+      variable_scope._PartitionInfo(full_shape=[1, 1], var_offset=[0, 1])
+
+  def testSingleOffset(self):
+    partition_info = variable_scope._PartitionInfo(
+        full_shape=[9, 3], var_offset=[4, 0])
+    self.assertEqual(4, partition_info.single_offset([1, 3]))
+
+    # Tests when the variable isn't partitioned at all.
+    partition_info = variable_scope._PartitionInfo(
+        full_shape=[9, 3], var_offset=[0, 0])
+    self.assertEqual(0, partition_info.single_offset([9, 3]))
+
+  def testSingleSliceDim(self):
+    partition_info = variable_scope._PartitionInfo(
+        full_shape=[9, 3], var_offset=[4, 0])
+    # Invalid shape.
+    with self.assertRaises(TypeError):
+      partition_info.single_slice_dim(None)
+
+    # Rank of shape differs from full_shape.
+    with self.assertRaises(ValueError):
+      partition_info.single_slice_dim([1, 2, 3])
+
+    # Shape is too large given var_offset (4+6 > 9).
+    with self.assertRaises(ValueError):
+      partition_info.single_slice_dim([6, 3])
+
+    # Multiple possible slice dim from shape.
+    with self.assertRaises(ValueError):
+      partition_info.single_slice_dim([1, 1])
+
+    partition_info = variable_scope._PartitionInfo(
+        full_shape=[9, 3], var_offset=[0, 0])
+    self.assertEqual(1, partition_info.single_slice_dim([9, 2]))
+    partition_info = variable_scope._PartitionInfo(
+        full_shape=[9, 3], var_offset=[4, 0])
+    self.assertEqual(0, partition_info.single_slice_dim([2, 3]))
+
+
 if __name__ == "__main__":
   tf.test.main()
 
@@ -254,7 +254,7 @@ def rank_internal(input, name=None, optimize=True):
 
 # DEPRECATED use init_ops.zeros_initializer
 # TODO(irving) Move it to init_ops.py
-def zeros_initializer(shape, dtype=dtypes.float32):
+def zeros_initializer(shape, dtype=dtypes.float32, partition_info=None):
   """An adaptor for zeros() to match the Initializer spec."""
   return zeros(shape, dtype)
 
@@ -13,7 +13,22 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Operations often used for initializing tensors."""
+"""Operations often used for initializing tensors.
+
+All variable initializers returned by functions in this file should have the
+following signature:
+
+def _initializer(shape, dtype=dtypes.float32, partition_info=None):
+  Args:
+    shape: List of `int` representing the shape of the output `Tensor`. Some
+      initializers may also be able to accept a `Tensor`.
+    dtype: (Optional) Type of the output `Tensor`.
+    partition_info: (Optional) variable_scope._PartitionInfo object holding
+      additional information about how the variable is partitioned. May be
+      `None` if the variable is not partitioned.
+  Returns:
+    A `Tensor` of type `dtype` and `shape`.
+"""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
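The docstring added above fixes the signature every initializer in this file must accept. A rough sketch of an initializer honouring that contract, and of one way its partition_info argument gets populated (debug_initializer is illustrative; the call assumes tf.create_partitioned_variables forwards partition information to the initializer, as the saver test change at the end of this diff suggests):

import tensorflow as tf


def debug_initializer(shape, dtype=tf.float32, partition_info=None):
  """Illustrative initializer that just reports how it was called."""
  if partition_info is None:
    print("unpartitioned, shape=%s" % (shape,))
  else:
    print("slice shape=%s at offset=%s of full shape=%s" %
          (shape, partition_info.var_offset, partition_info.full_shape))
  return tf.zeros(shape, dtype)


# Split an [8, 4] variable into 4 slices along dimension 0; each slice should
# trigger one initializer call describing that slice.
slices = tf.create_partitioned_variables(
    shape=[8, 4], slicing=[4, 1], initializer=debug_initializer)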
@@ -50,7 +65,7 @@ def _assert_float_dtype(dtype):
 zeros_initializer = array_ops.zeros_initializer
 
 
-def ones_initializer(shape, dtype=dtypes.float32):
+def ones_initializer(shape, dtype=dtypes.float32, partition_info=None):
   """An adaptor for ones() to match the Initializer spec."""
   return array_ops.ones(shape, dtype)
 
@@ -125,7 +140,7 @@ def constant_initializer(value=0, dtype=dtypes.float32):
     ValueError: Too many elements provided. Needed at most 6, but received 8
   ```
   """
-  def _initializer(shape, dtype=dtype):
+  def _initializer(shape, dtype=dtype, partition_info=None):
     return constant_op.constant(value, dtype=dtype, shape=shape)
   return _initializer
 
@@ -147,7 +162,7 @@ def random_uniform_initializer(minval=0, maxval=None, seed=None,
   Returns:
     An initializer that generates tensors with a uniform distribution.
   """
-  def _initializer(shape, dtype=dtype):
+  def _initializer(shape, dtype=dtype, partition_info=None):
     return random_ops.random_uniform(shape, minval, maxval, dtype, seed=seed)
   return _initializer
 
@@ -172,7 +187,8 @@ def random_normal_initializer(mean=0.0, stddev=1.0, seed=None,
   Raises:
     ValueError: if `dtype` is not a floating point type.
   """
-  def _initializer(shape, dtype=_assert_float_dtype(dtype)):
+  def _initializer(shape, dtype=_assert_float_dtype(dtype),
+                   partition_info=None):
     return random_ops.random_normal(shape, mean, stddev, dtype, seed=seed)
   return _initializer
 
@@ -203,13 +219,16 @@ def truncated_normal_initializer(mean=0.0, stddev=1.0, seed=None,
   Raises:
     ValueError: if `dtype` is not a floating point type.
   """
-  def _initializer(shape, dtype=_assert_float_dtype(dtype)):
+  def _initializer(shape, dtype=_assert_float_dtype(dtype),
+                   partition_info=None):
     return random_ops.truncated_normal(shape, mean, stddev, dtype, seed=seed)
+
   return _initializer
 
 
-def uniform_unit_scaling_initializer(factor=1.0, seed=None,
-                                     dtype=dtypes.float32, full_shape=None):
+def uniform_unit_scaling_initializer(factor=1.0,
+                                     seed=None,
+                                     dtype=dtypes.float32):
   """Returns an initializer that generates tensors without scaling variance.
 
   When initializing a deep network, it is in principle advantageous to keep
@@ -228,21 +247,12 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None,
   and the calculation of constants. In section 2.3 there, the constants were
   numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.
 
-  If the shape tuple `full_shape` is provided, the scale will be calculated from
-  this predefined shape. This is useful when a `Variable` is being partitioned
-  across several shards, and each shard has a smaller shape than the whole.
-  Since the shards are usually concatenated when used, the scale should be
-  based on the shape of the whole.
-
   Args:
     factor: Float. A multiplicative factor by which the values will be scaled.
     seed: A Python integer. Used to create random seeds. See
       [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
       for behavior.
     dtype: The data type. Only floating point types are supported.
-    full_shape: Tuple or list of integers. The shape used for calculating
-      scale normalization (instead of the shape passed at creation time).
-      Useful when creating sharded variables via partitioning.
 
   Returns:
     An initializer that generates tensors with unit variance.
@@ -250,8 +260,12 @@ def uniform_unit_scaling_initializer(factor=1.0, seed=None,
   Raises:
     ValueError: if `dtype` is not a floating point type.
   """
-  def _initializer(shape, dtype=_assert_float_dtype(dtype)):
-    scale_shape = full_shape if full_shape is not None else shape
+  def _initializer(shape, dtype=_assert_float_dtype(dtype),
+                   partition_info=None):
+    scale_shape = shape
+    if partition_info is not None:
+      scale_shape = partition_info.full_shape
+
     input_size = 1.0
     # Estimating input size is not possible to do perfectly, but we try.
     # The estimate, obtained by multiplying all dimensions but the last one,
@@ -319,7 +333,7 @@ class _RandomWalkInitializer(object):
     self._nonlinearity = nonlinearity
     self._seed = seed
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, partition_info=None):
     """Generate a tensor used to initialize a variable."""
    return random_ops._random_walk(shape, self._nonlinearity, dtype,
                                    seed=self._seed)
@@ -39,6 +39,133 @@ __all__ = ["VariableScope", "get_variable_scope",
            "no_regularizer"]
 
 
+class _PartitionInfo(object):
+  """Holds partition info used by initializer functions.
+  """
+
+  def __init__(self, full_shape, var_offset):
+    """Constructor.
+
+    Args:
+      full_shape: Tuple or list of `int` indicating the full combined shape
+        of the partitioned variables.
+      var_offset: Tuple or list of `int` specifying offset of this partition
+        with respect to the full variable for each dimension.
+
+    Raises:
+      TypeError: If `full_shape` or `var_offset` is not a sequence.
+      ValueError: If `full_shape` or `var_offset` differ in length. If
+        `var_offset` exceeds `full_shape` in any dimension.
+    """
+    if not isinstance(full_shape, collections_lib.Sequence) or isinstance(
+        full_shape, six.string_types):
+      raise TypeError(
+          "`full_shape` must be a sequence (like tuple or list) instead of " +
+          type(full_shape).__name__)
+
+    if not isinstance(var_offset, collections_lib.Sequence) or isinstance(
+        var_offset, six.string_types):
+      raise TypeError(
+          "`var_offset` must be a sequence (like tuple or list) instead of " +
+          type(var_offset).__name__)
+
+    if len(var_offset) != len(full_shape):
+      raise ValueError(
+          "Expected equal length, but `var_offset` is of length {} while "
+          "full_shape is of length {}.".format(
+              len(var_offset), len(full_shape)))
+
+    for i in xrange(len(full_shape)):
+      offset = var_offset[i]
+      shape = full_shape[i]
+      if offset < 0 or offset >= shape:
+        raise ValueError(
+            "Expected 0 <= offset < shape but found offset={}, shape={} for "
+            "var_offset={}, full_shape={}".format(offset, shape, var_offset,
+                                                  full_shape))
+
+    self._full_shape = full_shape
+    self._var_offset = var_offset
+
+  @property
+  def full_shape(self):
+    return self._full_shape
+
+  @property
+  def var_offset(self):
+    return self._var_offset
+
+  def single_offset(self, shape):
+    """Returns the offset when the variable is partitioned in at most one dim.
+
+    Args:
+      shape: Tuple or list of `int` indicating the shape of one specific
+        variable partition.
+
+    Returns:
+      `int` representing the offset in the dimension along which the variable is
+      partitioned. Returns 0 if the variable is not being partitioned.
+
+    Raises:
+      ValueError: Depending on self.single_slice_dim().
+    """
+
+    single_slice_dim = self.single_slice_dim(shape)
+    # If this variable is not being partitioned at all, single_slice_dim() could
+    # return None.
+    if single_slice_dim is None:
+      return 0
+    return self.var_offset[single_slice_dim]
+
+  def single_slice_dim(self, shape):
+    """Returns the slice dim when the variable is partitioned only in one dim.
+
+    Args:
+      shape: Tuple or list of `int` indicating the shape of one specific
+        variable partition.
+
+    Returns:
+      `int` representing the dimension that the variable is partitioned in, or
+      `None` if the variable doesn't seem to be partitioned at all.
+
+    Raises:
+      TypeError: If `shape` is not a sequence.
+      ValueError: If `shape` is not the same length as `self.full_shape`. If
+        the variable is partitioned in more than one dimension.
+    """
+    if not isinstance(shape, collections_lib.Sequence) or isinstance(
+        shape, six.string_types):
+      raise TypeError(
+          "`shape` must be a sequence (like tuple or list) instead of " +
+          type(shape).__name__)
+
+    if len(shape) != len(self.full_shape):
+      raise ValueError(
+          "Expected equal length, but received shape={} of length {} while "
+          "self.full_shape={} is of length {}.".format(shape, len(
+              shape), self.full_shape, len(self.full_shape)))
+
+    for i in xrange(len(shape)):
+      if self.var_offset[i] + shape[i] > self.full_shape[i]:
+        raise ValueError(
+            "With self.var_offset={}, a partition of shape={} would exceed "
+            "self.full_shape={} in dimension {}.".format(
+                self.var_offset, shape, self.full_shape, i))
+
+    slice_dim = None
+    for i in xrange(len(shape)):
+      if shape[i] == self.full_shape[i]:
+        continue
+      if slice_dim is not None:
+        raise ValueError(
+            "Cannot use single_slice_dim() with shape={} and "
+            "self.full_shape={} since slice dim could be either dimension {} "
+            "or {}.".format(shape, self.full_shape, i, slice_dim))
+      slice_dim = i
+
+    return slice_dim
+
+
 class _VariableStore(object):
   """Variable store that carries a number of named Variables.
 
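For quick reference, the PartitionInfoTest cases earlier in this diff pin down what the two helpers on this class return; the same behaviour as a small sketch (import path as used by those tests):

from tensorflow.python.ops import variable_scope

# A slice starting at row 4 of a [9, 3] variable, partitioned along dim 0.
info = variable_scope._PartitionInfo(full_shape=[9, 3], var_offset=[4, 0])
print(info.single_slice_dim([2, 3]))  # 0: only dim 0 is narrower than full_shape
print(info.single_offset([2, 3]))     # 4: offset along that single sliced dim

# A slice that covers the whole variable is treated as unpartitioned.
whole = variable_scope._PartitionInfo(full_shape=[9, 3], var_offset=[0, 0])
print(whole.single_slice_dim([9, 3]))  # None
print(whole.single_offset([9, 3]))     # 0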
@@ -390,6 +517,8 @@ class _VariableStore(object):
     for i in xrange(num_slices):
       var_shape = slice_shape[:]
       var_offset = slice_offset[:]
+      partition_info = _PartitionInfo(
+          full_shape=shape.as_list(), var_offset=var_offset)
       if i < num_slices_with_excess:
         var_shape[slice_dim] += 1
       slice_offset[slice_dim] += var_shape[slice_dim]
@@ -397,8 +526,7 @@ class _VariableStore(object):
       var_full_name = "%s/part_%d" % (name, i)
       with ops.name_scope(var_full_name + "/PartitionedInitializer"):
         if initializer is None:
-          init = init_ops.uniform_unit_scaling_initializer(
-              full_shape=shape.as_list())
+          init = init_ops.uniform_unit_scaling_initializer()
           init_shape = var_shape
         elif callable(initializer):
           init = initializer
@@ -419,6 +547,7 @@ class _VariableStore(object):
             shape=init_shape,
             dtype=dtype,
             initializer=init,
+            partition_info=partition_info,
             regularizer=regularizer,
             reuse=reuse,
             trainable=trainable,
@@ -443,10 +572,18 @@ class _VariableStore(object):
       self._partitioned_vars[name] = partitioned_var
       return partitioned_var
 
-  def _get_single_variable(self, name, shape=None, dtype=dtypes.float32,
-                           initializer=None, regularizer=None, reuse=None,
-                           trainable=True, collections=None,
-                           caching_device=None, validate_shape=True):
+  def _get_single_variable(self,
+                           name,
+                           shape=None,
+                           dtype=dtypes.float32,
+                           initializer=None,
+                           regularizer=None,
+                           partition_info=None,
+                           reuse=None,
+                           trainable=True,
+                           collections=None,
+                           caching_device=None,
+                           validate_shape=True):
     """Get or create a single Variable (e.g. a shard or entire variable).
 
     See the documentation of get_variable above (ignore partitioning components)
@@ -458,6 +595,7 @@ class _VariableStore(object):
       dtype: see get_variable.
       initializer: see get_variable.
       regularizer: see get_variable.
+      partition_info: _PartitionInfo object.
       reuse: see get_variable.
       trainable: see get_variable.
       collections: see get_variable.
@@ -523,7 +661,8 @@ class _VariableStore(object):
         init_val = initializer
         variable_dtype = None
       else:
-        init_val = lambda: initializer(shape.as_list(), dtype=dtype)
+        init_val = lambda: initializer(
+            shape.as_list(), dtype=dtype, partition_info=partition_info)
         variable_dtype = dtype.base_dtype
 
     # Create the variable.
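Because the store now always forwards partition_info to a callable initializer (it is simply None when the variable is not partitioned), a callable that only accepts (shape, dtype) would fail at this call site; a sketch of the required adjustment (variable name is illustrative):

import tensorflow as tf

# Before this change a two-argument callable sufficed:
#   init = lambda shape, dtype: tf.constant(0.5, dtype, shape)
# After it, the callable must also accept partition_info:
init = lambda shape, dtype, partition_info=None: tf.constant(0.5, dtype, shape)
v = tf.get_variable("v", shape=[3, 2], initializer=init)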
@@ -34,7 +34,8 @@ class SaverLargePartitionedVariableTest(tf.test.TestCase):
       with tf.device("/cpu:0"):
         # Create a partitioned variable which is larger than int32 size but
         # split into smaller sized variables.
-        init = lambda shape, dtype: tf.constant(True, dtype, shape)
+        init = lambda shape, dtype, partition_info: tf.constant(
+            True, dtype, shape)
         partitioned_var = tf.create_partitioned_variables(
             [1 << 31], [4], init, dtype=tf.bool, name=var_name)
         tf.initialize_all_variables().run()