Modify some v2 initializers to be able to return a value that corresponds to a partition of the entire value. This is useful for efficiently initializing sharded variables where only a shard of the initial value is necessary at a time.

PiperOrigin-RevId: 338371904
Change-Id: Ib4320d73cbaec30f5a61793debe7755026175781
This commit is contained in:
Chenkai Kuang 2020-10-21 17:16:27 -07:00 committed by TensorFlower Gardener
parent d2c7a16c2d
commit 239fe406d3
4 changed files with 249 additions and 73 deletions

View File

@ -34,7 +34,7 @@ class Initializer(object):
signature: signature:
```python ```python
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
# returns a tensor of shape `shape` and dtype `dtype` # returns a tensor of shape `shape` and dtype `dtype`
# containing values drawn from a distribution of your choice. # containing values drawn from a distribution of your choice.
``` ```
@ -54,7 +54,7 @@ class Initializer(object):
self.mean = mean self.mean = mean
self.stddev = stddev self.stddev = stddev
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
return tf.random.normal( return tf.random.normal(
shape, mean=self.mean, stddev=self.stddev, dtype=dtype) shape, mean=self.mean, stddev=self.stddev, dtype=dtype)
@ -68,12 +68,13 @@ class Initializer(object):
works fine. works fine.
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. dtype: Optional dtype of the tensor.
**kwargs: Additional keyword arguments.
""" """
raise NotImplementedError raise NotImplementedError
@ -124,7 +125,7 @@ class Zeros(init_ops_v2.Zeros, Initializer):
>>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
@ -133,8 +134,9 @@ class Zeros(init_ops_v2.Zeros, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise which default to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`). (via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
""" """
return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype)) return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[]) @keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[])
@ -154,7 +156,7 @@ class Ones(init_ops_v2.Ones, Initializer):
>>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer) >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
@ -163,8 +165,9 @@ class Ones(init_ops_v2.Ones, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise which default to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`). (via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
""" """
return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype)) return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Constant', @keras_export('keras.initializers.Constant',
@ -196,7 +199,7 @@ class Constant(Initializer):
def __init__(self, value=0): def __init__(self, value=0):
self.value = value self.value = value
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to `self.value`. """Returns a tensor object initialized to `self.value`.
Args: Args:
@ -205,7 +208,9 @@ class Constant(Initializer):
`tf.keras.backend.floatx()` is used, `tf.keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise which default to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`). (via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
""" """
del kwargs
return constant_op.constant( return constant_op.constant(
self.value, dtype=_get_dtype(dtype), shape=shape) self.value, dtype=_get_dtype(dtype), shape=shape)
@ -241,7 +246,7 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
always produce the same random tensor for a given shape and dtype. always produce the same random tensor for a given shape and dtype.
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
@ -251,8 +256,10 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
`tf.keras.backend.floatx()` is used, `tf.keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise which default to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`). (via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
""" """
return super(RandomUniform, self).__call__(shape, dtype=_get_dtype(dtype)) return super(RandomUniform, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.RandomNormal', @keras_export('keras.initializers.RandomNormal',
@ -283,17 +290,19 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer):
always produce the same random tensor for a given shape and dtype. always produce the same random tensor for a given shape and dtype.
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to random normal values. """Returns a tensor object initialized to random normal values.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used, which
which default to `float32` unless you configured it otherwise default to `float32` unless you configured it otherwise (via
(via `tf.keras.backend.set_floatx(float_dtype)`) `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
""" """
return super(RandomNormal, self).__call__(shape, dtype=_get_dtype(dtype)) return super(RandomNormal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.TruncatedNormal', @keras_export('keras.initializers.TruncatedNormal',
@ -329,17 +338,19 @@ class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer):
always produce the same random tensor for a given shape and dtype. always produce the same random tensor for a given shape and dtype.
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to random normal values (truncated). """Returns a tensor object initialized to random normal values (truncated).
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used, which
which default to `float32` unless you configured it otherwise default to `float32` unless you configured it otherwise (via
(via `tf.keras.backend.set_floatx(float_dtype)`) `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
""" """
return super(TruncatedNormal, self).__call__(shape, dtype=_get_dtype(dtype)) return super(TruncatedNormal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.VarianceScaling', @keras_export('keras.initializers.VarianceScaling',
@ -384,17 +395,19 @@ class VarianceScaling(init_ops_v2.VarianceScaling, Initializer):
always produce the same random tensor for a given shape and dtype. always produce the same random tensor for a given shape and dtype.
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used, which
which default to `float32` unless you configured it otherwise default to `float32` unless you configured it otherwise (via
(via `tf.keras.backend.set_floatx(float_dtype)`) `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
""" """
return super(VarianceScaling, self).__call__(shape, dtype=_get_dtype(dtype)) return super(VarianceScaling, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Orthogonal', @keras_export('keras.initializers.Orthogonal',
@ -436,7 +449,7 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
([pdf](https://arxiv.org/pdf/1312.6120.pdf)) ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to an orthogonal matrix. """Returns a tensor object initialized to an orthogonal matrix.
Args: Args:
@ -445,8 +458,10 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise which default to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`) (via `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
""" """
return super(Orthogonal, self).__call__(shape, dtype=_get_dtype(dtype)) return super(Orthogonal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Identity', @keras_export('keras.initializers.Identity',
@ -473,7 +488,7 @@ class Identity(init_ops_v2.Identity, Initializer):
gain: Multiplicative factor to apply to the identity matrix. gain: Multiplicative factor to apply to the identity matrix.
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to a 2D identity matrix. """Returns a tensor object initialized to a 2D identity matrix.
Args: Args:
@ -482,8 +497,10 @@ class Identity(init_ops_v2.Identity, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used, supported. If not specified, `tf.keras.backend.floatx()` is used,
which default to `float32` unless you configured it otherwise which default to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`) (via `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
""" """
return super(Identity, self).__call__(shape, dtype=_get_dtype(dtype)) return super(Identity, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.GlorotUniform', @keras_export('keras.initializers.GlorotUniform',

View File

@ -253,6 +253,34 @@ class KerasInitializersTest(test.TestCase):
initializer = initializers.deserialize(external_serialized_json) initializer = initializers.deserialize(external_serialized_json)
self.assertEqual(initializer.distribution, 'truncated_normal') self.assertEqual(initializer.distribution, 'truncated_normal')
def test_partition(self):
with self.cached_session():
partition_enabled_initializers = [
initializers.ZerosV2(),
initializers.OnesV2(),
initializers.RandomUniformV2(),
initializers.RandomNormalV2(),
initializers.TruncatedNormalV2(),
initializers.LecunUniformV2(),
initializers.GlorotUniformV2(),
initializers.HeUniformV2()
]
for initializer in partition_enabled_initializers:
got = initializer(
shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
self.assertEqual(got.shape, (2, 2))
partition_forbidden_initializers = [
initializers.OrthogonalV2(),
initializers.IdentityV2()
]
for initializer in partition_forbidden_initializers:
with self.assertRaisesRegex(
ValueError,
"initializer doesn't support partition-related arguments"):
initializer(
shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
if __name__ == '__main__': if __name__ == '__main__':
test.main() test.main()

View File

@ -12,19 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Operations often used for initializing tensors. """Initializers for TF 2."""
All variable initializers returned by functions in this file should have the
following signature:
def _initializer(shape, dtype=dtypes.float32):
Args:
shape: List of `int` representing the shape of the output `Tensor`. Some
initializers may also be able to accept a `Tensor`.
dtype: (Optional) Type of the output `Tensor`.
Returns:
A `Tensor` of type `dtype` and `shape`.
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -44,18 +32,40 @@ from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.init_ops import _compute_fans from tensorflow.python.ops.init_ops import _compute_fans
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
_PARTITION_SHAPE = "partition_shape"
_PARTITION_OFFSET = "partition_offset"
class Initializer(object): class Initializer(object):
"""Initializer base class: all initializers inherit from this class. """Initializer base class: all initializers inherit from this class.
Initializers should implement a `__call__` method with the following
signature:
```python
def __call__(self, shape, dtype=None, **kwargs):
# returns a tensor of shape `shape` and dtype `dtype`
# containing values drawn from a distribution of your choice.
```
""" """
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. If not provided will return tensor dtype: Optional dtype of the tensor. If not provided will return tensor
of `tf.float32`. of `tf.float32`.
**kwargs: Additional keyword arguments. Accepted values:
`partition_shape` and `partition_offset`. Used when creating a single
partition in a partitioned variable. `partition_shape` is the shape
of the partition (i.e. the shape of the returned tensor) and
`partition_offset` is a tuple of `int` specifying the offset of this
partition w.r.t each axis. For example, a tensor of shape `(30, 100)`
can be partitioned into two partitions: `p0` of shape `(10, 100)` and
`p1` of shape `(20, 100)`; if the initializer is called with
`partition_shape=(20, 100)` and `partition_offset=(10, 0)`, it should
return the value for `p1`.
""" """
raise NotImplementedError raise NotImplementedError
@ -89,6 +99,14 @@ class Initializer(object):
config.pop("dtype", None) config.pop("dtype", None)
return cls(**config) return cls(**config)
def _validate_kwargs(self, kwargs, support_partition=True):
for kwarg in kwargs:
if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]:
raise TypeError("Unknown keyword arguments: %s" % kwarg)
elif not support_partition:
raise ValueError("%s initializer doesn't support partition-related"
" arguments" % self.__class__.__name__)
@tf_export("zeros_initializer", v1=[]) @tf_export("zeros_initializer", v1=[])
class Zeros(Initializer): class Zeros(Initializer):
@ -115,20 +133,24 @@ class Zeros(Initializer):
(<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ... (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
""" """
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValuesError: If the dtype is not numeric or boolean. ValuesError: If the dtype is not numeric or boolean.
""" """
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype) dtype = dtypes.as_dtype(dtype)
if not dtype.is_numpy_compatible or dtype == dtypes.string: if not dtype.is_numpy_compatible or dtype == dtypes.string:
raise ValueError("Expected numeric or boolean dtype, got %s." % dtype) raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return array_ops.zeros(shape, dtype) return array_ops.zeros(shape, dtype)
@ -157,20 +179,24 @@ class Ones(Initializer):
(<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ... (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
""" """
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValuesError: If the dtype is not numeric or boolean. ValuesError: If the dtype is not numeric or boolean.
""" """
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype) dtype = dtypes.as_dtype(dtype)
if not dtype.is_numpy_compatible or dtype == dtypes.string: if not dtype.is_numpy_compatible or dtype == dtypes.string:
raise ValueError("Expected numeric or boolean dtype, got %s." % dtype) raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return array_ops.ones(shape, dtype) return array_ops.ones(shape, dtype)
@ -245,22 +271,23 @@ class Constant(Initializer):
"tuple of values, or numpy.ndarray)." % type(value)) "tuple of values, or numpy.ndarray)." % type(value))
self.value = value self.value = value
def __call__(self, shape, dtype=None): def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. If not provided the dtype of the dtype: Optional dtype of the tensor. If not provided the dtype of the
tensor created will be the type of the inital value. tensor created will be the type of the inital value.
**kwargs: Additional keyword arguments.
Raises: Raises:
TypeError: If the initializer cannot create a tensor of the requested TypeError: If the initializer cannot create a tensor of the requested
dtype. dtype.
""" """
self._validate_kwargs(kwargs, support_partition=False)
if dtype is not None: if dtype is not None:
dtype = dtypes.as_dtype(dtype) dtype = dtypes.as_dtype(dtype)
return constant_op.constant( return constant_op.constant(self.value, dtype=dtype, shape=shape)
self.value, dtype=dtype, shape=shape)
def get_config(self): def get_config(self):
return {"value": self.value} return {"value": self.value}
@ -305,20 +332,24 @@ class RandomUniform(Initializer):
self.seed = seed self.seed = seed
self._random_generator = _RandomGenerator(seed) self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point and integer dtype: Optional dtype of the tensor. Only floating point and integer
types are supported. types are supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValueError: If the dtype is not numeric. ValueError: If the dtype is not numeric.
""" """
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype) dtype = dtypes.as_dtype(dtype)
if not dtype.is_floating and not dtype.is_integer: if not dtype.is_floating and not dtype.is_integer:
raise ValueError("Expected float or integer dtype, got %s." % dtype) raise ValueError("Expected float or integer dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.random_uniform(shape, self.minval, return self._random_generator.random_uniform(shape, self.minval,
self.maxval, dtype) self.maxval, dtype)
@ -369,18 +400,22 @@ class RandomNormal(Initializer):
self.seed = seed self.seed = seed
self._random_generator = _RandomGenerator(seed) self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValueError: If the dtype is not floating point ValueError: If the dtype is not floating point
""" """
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype) dtype = _assert_float_dtype(dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.random_normal(shape, self.mean, self.stddev, return self._random_generator.random_normal(shape, self.mean, self.stddev,
dtype) dtype)
@ -434,18 +469,22 @@ class TruncatedNormal(Initializer):
self.seed = seed self.seed = seed
self._random_generator = _RandomGenerator(seed) self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValueError: If the dtype is not floating point ValueError: If the dtype is not floating point
""" """
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype) dtype = _assert_float_dtype(dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.truncated_normal(shape, self.mean, return self._random_generator.truncated_normal(shape, self.mean,
self.stddev, dtype) self.stddev, dtype)
@ -525,24 +564,24 @@ class VarianceScaling(Initializer):
self.seed = seed self.seed = seed
self._random_generator = _RandomGenerator(seed) self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValueError: If the dtype is not floating point ValueError: If the dtype is not floating point
""" """
partition_info = None # Keeps logic so can be readded later if necessary self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype) dtype = _assert_float_dtype(dtype)
scale = self.scale scale = self.scale
scale_shape = shape fan_in, fan_out = _compute_fans(shape)
if partition_info is not None: if _PARTITION_SHAPE in kwargs:
scale_shape = partition_info.full_shape shape = kwargs[_PARTITION_SHAPE]
fan_in, fan_out = _compute_fans(scale_shape)
if self.mode == "fan_in": if self.mode == "fan_in":
scale /= max(1., fan_in) scale /= max(1., fan_in)
elif self.mode == "fan_out": elif self.mode == "fan_out":
@ -616,18 +655,20 @@ class Orthogonal(Initializer):
self.seed = seed self.seed = seed
self._random_generator = _RandomGenerator(seed) self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValueError: If the dtype is not floating point or the input shape is not ValueError: If the dtype is not floating point or the input shape is not
valid. valid.
""" """
self._validate_kwargs(kwargs, support_partition=False)
dtype = _assert_float_dtype(dtype) dtype = _assert_float_dtype(dtype)
# Check the shape # Check the shape
if len(shape) < 2: if len(shape) < 2:
@ -686,28 +727,25 @@ class Identity(Initializer):
def __init__(self, gain=1.0): def __init__(self, gain=1.0):
self.gain = gain self.gain = gain
def __call__(self, shape, dtype=dtypes.float32): def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer. """Returns a tensor object initialized as specified by the initializer.
Args: Args:
shape: Shape of the tensor. shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are dtype: Optional dtype of the tensor. Only floating point types are
supported. supported.
**kwargs: Additional keyword arguments.
Raises: Raises:
ValueError: If the dtype is not floating point ValueError: If the dtype is not floating point
ValueError: If the requested shape does not have exactly two axes. ValueError: If the requested shape does not have exactly two axes.
""" """
partition_info = None # Keeps logic so can be readded later if necessary self._validate_kwargs(kwargs, support_partition=False)
dtype = _assert_float_dtype(dtype) dtype = _assert_float_dtype(dtype)
full_shape = shape if partition_info is None else partition_info.full_shape if len(shape) != 2:
if len(full_shape) != 2:
raise ValueError( raise ValueError(
"Identity matrix initializer can only be used for 2D matrices.") "Identity matrix initializer can only be used for 2D matrices.")
initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype) initializer = linalg_ops_impl.eye(*shape, dtype=dtype)
if partition_info is not None:
initializer = array_ops.slice(initializer, partition_info.var_offset,
shape)
return self.gain * initializer return self.gain * initializer
def get_config(self): def get_config(self):

View File

@ -78,6 +78,21 @@ class InitializersTest(test.TestCase):
if target_min is not None: if target_min is not None:
self.assertGreater(lim, abs(output.min() - target_min)) self.assertGreater(lim, abs(output.min() - target_min))
def _partition_test(self, init):
full_shape = (4, 2)
partition_shape = (2, 2)
partition_offset = (0, 0)
full_value = self.evaluate(init(full_shape, dtype=dtypes.float32))
got = self.evaluate(
init(
full_shape,
dtype=dtypes.float32,
partition_shape=partition_shape,
partition_offset=partition_offset))
self.assertEqual(got.shape, partition_shape)
self.assertAllClose(
got, array_ops.slice(full_value, partition_offset, partition_shape))
class ConstantInitializersTest(InitializersTest): class ConstantInitializersTest(InitializersTest):
@ -86,11 +101,28 @@ class ConstantInitializersTest(InitializersTest):
self._range_test(init_ops_v2.Zeros(), shape=(4, 5), self._range_test(init_ops_v2.Zeros(), shape=(4, 5),
target_mean=0., target_max=0.) target_mean=0., target_max=0.)
@test_util.run_in_graph_and_eager_modes
def testZerosPartition(self):
init = init_ops_v2.Zeros()
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testZerosInvalidKwargs(self):
init = init_ops_v2.Zeros()
with self.assertRaisesWithLiteralMatch(TypeError,
r"Unknown keyword arguments: dtpye"):
init((2, 2), dtpye=dtypes.float32)
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testOnes(self): def testOnes(self):
self._range_test(init_ops_v2.Ones(), shape=(4, 5), self._range_test(init_ops_v2.Ones(), shape=(4, 5),
target_mean=1., target_max=1.) target_mean=1., target_max=1.)
@test_util.run_in_graph_and_eager_modes
def testOnesPartition(self):
init = init_ops_v2.Ones()
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testConstantInt(self): def testConstantInt(self):
self._range_test( self._range_test(
@ -100,6 +132,14 @@ class ConstantInitializersTest(InitializersTest):
target_max=2, target_max=2,
target_min=2) target_min=2)
@test_util.run_in_graph_and_eager_modes
def testConstantPartition(self):
init = init_ops_v2.Constant([1, 2, 3, 4])
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Constant initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
@test_util.run_in_graph_and_eager_modes @test_util.run_in_graph_and_eager_modes
def testConstantTuple(self): def testConstantTuple(self):
init = init_ops_v2.constant_initializer((10, 20, 30)) init = init_ops_v2.constant_initializer((10, 20, 30))
@ -188,6 +228,11 @@ class RandomUniformInitializerTest(InitializersTest):
init = init_ops_v2.RandomUniform(0.0, 1.0) init = init_ops_v2.RandomUniform(0.0, 1.0)
self._duplicated_test(init) self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
init = init_ops_v2.RandomUniform(0, 7, seed=1)
self._partition_test(init)
class RandomNormalInitializerTest(InitializersTest): class RandomNormalInitializerTest(InitializersTest):
@ -217,6 +262,14 @@ class RandomNormalInitializerTest(InitializersTest):
init = init_ops_v2.RandomNormal(0.0, 1.0) init = init_ops_v2.RandomNormal(0.0, 1.0)
self._duplicated_test(init) self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
if test_util.is_xla_enabled():
self.skipTest(
"XLA ignores seeds for RandomNormal, skip xla-enabled test.")
init = init_ops_v2.RandomNormal(0, 7, seed=1)
self._partition_test(init)
class TruncatedNormalInitializerTest(InitializersTest): class TruncatedNormalInitializerTest(InitializersTest):
@ -247,6 +300,12 @@ class TruncatedNormalInitializerTest(InitializersTest):
init = init_ops_v2.TruncatedNormal(0.0, 1.0) init = init_ops_v2.TruncatedNormal(0.0, 1.0)
self._duplicated_test(init) self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
init = init_ops_v2.TruncatedNormal(0.0, 1.0, seed=1)
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testInvalidDataType(self): def testInvalidDataType(self):
init = init_ops_v2.TruncatedNormal(0.0, 1.0) init = init_ops_v2.TruncatedNormal(0.0, 1.0)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
@ -317,6 +376,24 @@ class VarianceScalingInitializerTest(InitializersTest):
self.assertNear(np.mean(x), expect_mean, err=1e-2) self.assertNear(np.mean(x), expect_mean, err=1e-2)
self.assertNear(np.var(x), expect_var, err=1e-2) self.assertNear(np.var(x), expect_var, err=1e-2)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
partition_shape = (100, 100)
shape = [1000, 100]
expect_mean = 0.
expect_var = 1. / shape[0]
init = init_ops_v2.VarianceScaling(distribution="untruncated_normal")
with test_util.use_gpu(), test.mock.patch.object(
random_ops, "random_normal",
wraps=random_ops.random_normal) as mock_random_normal:
x = self.evaluate(init(shape, partition_shape=partition_shape))
self.assertTrue(mock_random_normal.called)
self.assertEqual(x.shape, partition_shape)
self.assertNear(np.mean(x), expect_mean, err=1e-3)
self.assertNear(np.var(x), expect_var, err=1e-3)
class OrthogonalInitializerTest(InitializersTest): class OrthogonalInitializerTest(InitializersTest):
@ -386,6 +463,14 @@ class OrthogonalInitializerTest(InitializersTest):
self.assertAllClose( self.assertAllClose(
np.dot(t, t.T), np.eye(t.shape[0]), rtol=tol, atol=tol) np.dot(t, t.T), np.eye(t.shape[0]), rtol=tol, atol=tol)
@test_util.run_in_graph_and_eager_modes
def testPartition(self):
init = init_ops_v2.Orthogonal(seed=1)
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Orthogonal initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
class IdentityInitializerTest(InitializersTest): class IdentityInitializerTest(InitializersTest):
@ -439,6 +524,14 @@ class IdentityInitializerTest(InitializersTest):
self.assertAllClose(self.evaluate(init_custom(shape, dtype=dtype)), self.assertAllClose(self.evaluate(init_custom(shape, dtype=dtype)),
np.eye(*shape) * 0.9) np.eye(*shape) * 0.9)
@test_util.run_in_graph_and_eager_modes
def testPartition(self):
init = init_ops_v2.Identity()
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Identity initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
class GlorotInitializersTest(InitializersTest): class GlorotInitializersTest(InitializersTest):