Modify some v2 initializers so they can return a value corresponding to a partition of the entire value. This is useful for efficiently initializing sharded variables, where only one shard of the initial value is needed at a time.
PiperOrigin-RevId: 338371904
Change-Id: Ib4320d73cbaec30f5a61793debe7755026175781
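For illustration only (not part of this commit's diff), a minimal sketch of the call pattern this change enables, following the `(30, 100)` example from the updated `Initializer.__call__` docstring; it assumes a TF build that already contains this change:

```python
import tensorflow as tf

# A (30, 100) value split row-wise into two shards, as in the updated
# docstring: p0 covers rows 0-9, p1 covers rows 10-29. Each call returns
# only the requested shard instead of materializing the full tensor.
init = tf.random_normal_initializer(mean=0.0, stddev=0.05)

p0 = init(shape=(30, 100), dtype=tf.float32,
          partition_shape=(10, 100), partition_offset=(0, 0))
p1 = init(shape=(30, 100), dtype=tf.float32,
          partition_shape=(20, 100), partition_offset=(10, 0))

print(p0.shape, p1.shape)  # (10, 100) (20, 100)
```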
parent d2c7a16c2d
commit 239fe406d3
tensorflow/python
@@ -34,7 +34,7 @@ class Initializer(object):
   signature:
 
   ```python
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     # returns a tensor of shape `shape` and dtype `dtype`
     # containing values drawn from a distribution of your choice.
   ```
@@ -54,7 +54,7 @@ class Initializer(object):
       self.mean = mean
       self.stddev = stddev
 
-    def __call__(self, shape, dtype=None):
+    def __call__(self, shape, dtype=None, **kwargs):
       return tf.random.normal(
           shape, mean=self.mean, stddev=self.stddev, dtype=dtype)
 
@@ -68,12 +68,13 @@ class Initializer(object):
   works fine.
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor.
+      **kwargs: Additional keyword arguments.
     """
     raise NotImplementedError
 
@@ -124,7 +125,7 @@ class Zeros(init_ops_v2.Zeros, Initializer):
   >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
@@ -133,8 +134,9 @@ class Zeros(init_ops_v2.Zeros, Initializer):
         supported. If not specified, `tf.keras.backend.floatx()` is used,
         which default to `float32` unless you configured it otherwise
         (via `tf.keras.backend.set_floatx(float_dtype)`).
+      **kwargs: Additional keyword arguments.
     """
-    return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[])
@@ -154,7 +156,7 @@ class Ones(init_ops_v2.Ones, Initializer):
   >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
@@ -163,8 +165,9 @@ class Ones(init_ops_v2.Ones, Initializer):
         supported. If not specified, `tf.keras.backend.floatx()` is used,
         which default to `float32` unless you configured it otherwise
         (via `tf.keras.backend.set_floatx(float_dtype)`).
+      **kwargs: Additional keyword arguments.
     """
-    return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.Constant',
@@ -196,7 +199,7 @@ class Constant(Initializer):
   def __init__(self, value=0):
     self.value = value
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to `self.value`.
 
     Args:
@@ -205,7 +208,9 @@ class Constant(Initializer):
         `tf.keras.backend.floatx()` is used,
         which default to `float32` unless you configured it otherwise
         (via `tf.keras.backend.set_floatx(float_dtype)`).
+      **kwargs: Additional keyword arguments.
     """
+    del kwargs
     return constant_op.constant(
         self.value, dtype=_get_dtype(dtype), shape=shape)
 
@@ -241,7 +246,7 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
   always produce the same random tensor for a given shape and dtype.
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
@@ -251,8 +256,10 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
         `tf.keras.backend.floatx()` is used,
         which default to `float32` unless you configured it otherwise
         (via `tf.keras.backend.set_floatx(float_dtype)`).
+      **kwargs: Additional keyword arguments.
     """
-    return super(RandomUniform, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(RandomUniform, self).__call__(
+        shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.RandomNormal',
@@ -283,17 +290,19 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer):
   always produce the same random tensor for a given shape and dtype.
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to random normal values.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
-        supported. If not specified, `tf.keras.backend.floatx()` is used,
-        which default to `float32` unless you configured it otherwise
-        (via `tf.keras.backend.set_floatx(float_dtype)`)
+        supported. If not specified, `tf.keras.backend.floatx()` is used, which
+        default to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`)
+      **kwargs: Additional keyword arguments.
     """
-    return super(RandomNormal, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(RandomNormal, self).__call__(
+        shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.TruncatedNormal',
@@ -329,17 +338,19 @@ class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer):
   always produce the same random tensor for a given shape and dtype.
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to random normal values (truncated).
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
-        supported. If not specified, `tf.keras.backend.floatx()` is used,
-        which default to `float32` unless you configured it otherwise
-        (via `tf.keras.backend.set_floatx(float_dtype)`)
+        supported. If not specified, `tf.keras.backend.floatx()` is used, which
+        default to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`)
+      **kwargs: Additional keyword arguments.
     """
-    return super(TruncatedNormal, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(TruncatedNormal, self).__call__(
+        shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.VarianceScaling',
@@ -384,17 +395,19 @@ class VarianceScaling(init_ops_v2.VarianceScaling, Initializer):
   always produce the same random tensor for a given shape and dtype.
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
-        supported. If not specified, `tf.keras.backend.floatx()` is used,
-        which default to `float32` unless you configured it otherwise
-        (via `tf.keras.backend.set_floatx(float_dtype)`)
+        supported. If not specified, `tf.keras.backend.floatx()` is used, which
+        default to `float32` unless you configured it otherwise (via
+        `tf.keras.backend.set_floatx(float_dtype)`)
+      **kwargs: Additional keyword arguments.
     """
-    return super(VarianceScaling, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(VarianceScaling, self).__call__(
+        shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.Orthogonal',
@@ -436,7 +449,7 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
   ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to an orthogonal matrix.
 
     Args:
@@ -445,8 +458,10 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
         supported. If not specified, `tf.keras.backend.floatx()` is used,
         which default to `float32` unless you configured it otherwise
         (via `tf.keras.backend.set_floatx(float_dtype)`)
+      **kwargs: Additional keyword arguments.
     """
-    return super(Orthogonal, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(Orthogonal, self).__call__(
+        shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.Identity',
@@ -473,7 +488,7 @@ class Identity(init_ops_v2.Identity, Initializer):
     gain: Multiplicative factor to apply to the identity matrix.
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized to a 2D identity matrix.
 
     Args:
@@ -482,8 +497,10 @@ class Identity(init_ops_v2.Identity, Initializer):
         supported. If not specified, `tf.keras.backend.floatx()` is used,
         which default to `float32` unless you configured it otherwise
         (via `tf.keras.backend.set_floatx(float_dtype)`)
+      **kwargs: Additional keyword arguments.
     """
-    return super(Identity, self).__call__(shape, dtype=_get_dtype(dtype))
+    return super(Identity, self).__call__(
+        shape, dtype=_get_dtype(dtype), **kwargs)
 
 
 @keras_export('keras.initializers.GlorotUniform',
@@ -253,6 +253,34 @@ class KerasInitializersTest(test.TestCase):
     initializer = initializers.deserialize(external_serialized_json)
     self.assertEqual(initializer.distribution, 'truncated_normal')
 
+  def test_partition(self):
+    with self.cached_session():
+      partition_enabled_initializers = [
+          initializers.ZerosV2(),
+          initializers.OnesV2(),
+          initializers.RandomUniformV2(),
+          initializers.RandomNormalV2(),
+          initializers.TruncatedNormalV2(),
+          initializers.LecunUniformV2(),
+          initializers.GlorotUniformV2(),
+          initializers.HeUniformV2()
+      ]
+      for initializer in partition_enabled_initializers:
+        got = initializer(
+            shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+        self.assertEqual(got.shape, (2, 2))
+
+      partition_forbidden_initializers = [
+          initializers.OrthogonalV2(),
+          initializers.IdentityV2()
+      ]
+      for initializer in partition_forbidden_initializers:
+        with self.assertRaisesRegex(
+            ValueError,
+            "initializer doesn't support partition-related arguments"):
+          initializer(
+              shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
+
 
 if __name__ == '__main__':
   test.main()
@@ -12,19 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Operations often used for initializing tensors.
-
-All variable initializers returned by functions in this file should have the
-following signature:
-
-def _initializer(shape, dtype=dtypes.float32):
-  Args:
-    shape: List of `int` representing the shape of the output `Tensor`. Some
-      initializers may also be able to accept a `Tensor`.
-    dtype: (Optional) Type of the output `Tensor`.
-  Returns:
-    A `Tensor` of type `dtype` and `shape`.
-"""
+"""Initializers for TF 2."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -44,18 +32,40 @@ from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.ops.init_ops import _compute_fans
 from tensorflow.python.util.tf_export import tf_export
 
+_PARTITION_SHAPE = "partition_shape"
+_PARTITION_OFFSET = "partition_offset"
+
 
 class Initializer(object):
   """Initializer base class: all initializers inherit from this class.
+
+  Initializers should implement a `__call__` method with the following
+  signature:
+
+  ```python
+  def __call__(self, shape, dtype=None, **kwargs):
+    # returns a tensor of shape `shape` and dtype `dtype`
+    # containing values drawn from a distribution of your choice.
+  ```
   """
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. If not provided will return tensor
-       of `tf.float32`.
+        of `tf.float32`.
+      **kwargs: Additional keyword arguments. Accepted values:
+        `partition_shape` and `partition_offset`. Used when creating a single
+        partition in a partitioned variable. `partition_shape` is the shape
+        of the partition (i.e. the shape of the returned tensor) and
+        `partition_offset` is a tuple of `int` specifying the offset of this
+        partition w.r.t each axis. For example, a tensor of shape `(30, 100)`
+        can be partitioned into two partitions: `p0` of shape `(10, 100)` and
+        `p1` of shape `(20, 100)`; if the initializer is called with
+        `partition_shape=(20, 100)` and `partition_offset=(10, 0)`, it should
+        return the value for `p1`.
     """
     raise NotImplementedError
 
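To make the contract above concrete, here is a hypothetical example (not part of the commit) of a partition-aware initializer that slices a deterministic full value; `RangeInitializer` and its internals are illustrative names only:

```python
import numpy as np
import tensorflow as tf

# Hypothetical initializer honoring the partition contract described above
# by slicing a full value; assumes a TF build containing this change.
class RangeInitializer(tf.keras.initializers.Initializer):

  def __call__(self, shape, dtype=tf.float32, **kwargs):
    # Deterministic "full" value: 0, 1, 2, ... reshaped to `shape`.
    full = tf.reshape(tf.range(np.prod(shape), dtype=dtype), shape)
    partition_shape = kwargs.get("partition_shape")
    partition_offset = kwargs.get("partition_offset")
    if partition_shape is not None:
      # Return only the requested shard of the full value.
      return tf.slice(full, partition_offset, partition_shape)
    return full

init = RangeInitializer()
p1 = init((30, 100), partition_shape=(20, 100), partition_offset=(10, 0))
print(p1.shape)  # (20, 100) -- the value for shard `p1` in the docstring example
```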
@@ -89,6 +99,14 @@ class Initializer(object):
     config.pop("dtype", None)
     return cls(**config)
 
+  def _validate_kwargs(self, kwargs, support_partition=True):
+    for kwarg in kwargs:
+      if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]:
+        raise TypeError("Unknown keyword arguments: %s" % kwarg)
+      elif not support_partition:
+        raise ValueError("%s initializer doesn't support partition-related"
+                         " arguments" % self.__class__.__name__)
+
 
 @tf_export("zeros_initializer", v1=[])
 class Zeros(Initializer):
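A hedged illustration of how the new `_validate_kwargs` behaves from the public API, assuming a TF build that contains this change (error messages taken from the code above):

```python
import tensorflow as tf

# Partition kwargs are accepted by partition-aware initializers.
zeros = tf.zeros_initializer()
zeros((4, 2), partition_shape=(2, 2), partition_offset=(0, 0))  # returns a (2, 2) tensor

# Any other keyword is rejected with a TypeError.
try:
  zeros((4, 2), partition_shapes=(2, 2))  # misspelled keyword
except TypeError as e:
  print(e)  # Unknown keyword arguments: partition_shapes

# Initializers validated with support_partition=False reject partition kwargs.
ortho = tf.keras.initializers.Orthogonal()
try:
  ortho((4, 2), partition_shape=(2, 2))
except ValueError as e:
  print(e)  # Orthogonal initializer doesn't support partition-related arguments
```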
@@ -115,20 +133,24 @@ class Zeros(Initializer):
   (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
   """
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
         supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValuesError: If the dtype is not numeric or boolean.
     """
+    self._validate_kwargs(kwargs)
     dtype = dtypes.as_dtype(dtype)
     if not dtype.is_numpy_compatible or dtype == dtypes.string:
       raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     return array_ops.zeros(shape, dtype)
 
 
@@ -157,20 +179,24 @@ class Ones(Initializer):
   (<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
   """
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
-       supported.
+        supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValuesError: If the dtype is not numeric or boolean.
     """
+    self._validate_kwargs(kwargs)
     dtype = dtypes.as_dtype(dtype)
     if not dtype.is_numpy_compatible or dtype == dtypes.string:
       raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     return array_ops.ones(shape, dtype)
 
 
@@ -245,22 +271,23 @@ class Constant(Initializer):
                        "tuple of values, or numpy.ndarray)." % type(value))
     self.value = value
 
-  def __call__(self, shape, dtype=None):
+  def __call__(self, shape, dtype=None, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. If not provided the dtype of the
-       tensor created will be the type of the inital value.
+        tensor created will be the type of the inital value.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       TypeError: If the initializer cannot create a tensor of the requested
        dtype.
     """
+    self._validate_kwargs(kwargs, support_partition=False)
     if dtype is not None:
       dtype = dtypes.as_dtype(dtype)
-    return constant_op.constant(
-        self.value, dtype=dtype, shape=shape)
+    return constant_op.constant(self.value, dtype=dtype, shape=shape)
 
   def get_config(self):
     return {"value": self.value}
@@ -305,20 +332,24 @@ class RandomUniform(Initializer):
     self.seed = seed
     self._random_generator = _RandomGenerator(seed)
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point and integer
-       types are supported.
+        types are supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValueError: If the dtype is not numeric.
     """
+    self._validate_kwargs(kwargs)
     dtype = dtypes.as_dtype(dtype)
     if not dtype.is_floating and not dtype.is_integer:
       raise ValueError("Expected float or integer dtype, got %s." % dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     return self._random_generator.random_uniform(shape, self.minval,
                                                  self.maxval, dtype)
 
@@ -369,18 +400,22 @@ class RandomNormal(Initializer):
     self.seed = seed
     self._random_generator = _RandomGenerator(seed)
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
-       supported.
+        supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValueError: If the dtype is not floating point
     """
+    self._validate_kwargs(kwargs)
     dtype = _assert_float_dtype(dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     return self._random_generator.random_normal(shape, self.mean, self.stddev,
                                                 dtype)
 
@@ -434,18 +469,22 @@ class TruncatedNormal(Initializer):
     self.seed = seed
     self._random_generator = _RandomGenerator(seed)
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
-       supported.
+        supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValueError: If the dtype is not floating point
     """
+    self._validate_kwargs(kwargs)
     dtype = _assert_float_dtype(dtype)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     return self._random_generator.truncated_normal(shape, self.mean,
                                                    self.stddev, dtype)
 
@@ -525,24 +564,24 @@ class VarianceScaling(Initializer):
     self.seed = seed
     self._random_generator = _RandomGenerator(seed)
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
-       supported.
+        supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValueError: If the dtype is not floating point
     """
-    partition_info = None  # Keeps logic so can be readded later if necessary
+    self._validate_kwargs(kwargs)
     dtype = _assert_float_dtype(dtype)
     scale = self.scale
-    scale_shape = shape
-    if partition_info is not None:
-      scale_shape = partition_info.full_shape
-    fan_in, fan_out = _compute_fans(scale_shape)
+    fan_in, fan_out = _compute_fans(shape)
+    if _PARTITION_SHAPE in kwargs:
+      shape = kwargs[_PARTITION_SHAPE]
     if self.mode == "fan_in":
       scale /= max(1., fan_in)
     elif self.mode == "fan_out":
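A small sketch (mine, not from the commit) of why `_compute_fans` runs on the full `shape` before it is swapped for `partition_shape`: the scaling must be derived from the whole variable so that every shard is drawn from the same distribution the unpartitioned initializer would use. The helper below is a simplified stand-in, assuming `scale=1.0`, `mode="fan_in"`, and `distribution="untruncated_normal"`:

```python
import numpy as np

def variance_scaling_shard(full_shape, partition_shape, seed=0):
  # Fans come from the *full* shape; only the sampling uses the shard shape.
  fan_in = full_shape[0]                    # simplified stand-in for _compute_fans
  stddev = np.sqrt(1.0 / max(1.0, fan_in))
  rng = np.random.default_rng(seed)
  return rng.normal(0.0, stddev, size=partition_shape)

shard = variance_scaling_shard(full_shape=(1000, 100), partition_shape=(100, 100))
print(shard.shape, shard.var())  # (100, 100), variance ~1/1000, as in the test below
```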
@@ -616,18 +655,20 @@ class Orthogonal(Initializer):
     self.seed = seed
     self._random_generator = _RandomGenerator(seed)
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
        supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValueError: If the dtype is not floating point or the input shape is not
        valid.
     """
+    self._validate_kwargs(kwargs, support_partition=False)
     dtype = _assert_float_dtype(dtype)
     # Check the shape
     if len(shape) < 2:
@@ -686,28 +727,25 @@ class Identity(Initializer):
   def __init__(self, gain=1.0):
     self.gain = gain
 
-  def __call__(self, shape, dtype=dtypes.float32):
+  def __call__(self, shape, dtype=dtypes.float32, **kwargs):
     """Returns a tensor object initialized as specified by the initializer.
 
     Args:
       shape: Shape of the tensor.
       dtype: Optional dtype of the tensor. Only floating point types are
        supported.
+      **kwargs: Additional keyword arguments.
 
     Raises:
       ValueError: If the dtype is not floating point
       ValueError: If the requested shape does not have exactly two axes.
     """
-    partition_info = None  # Keeps logic so can be readded later if necessary
+    self._validate_kwargs(kwargs, support_partition=False)
     dtype = _assert_float_dtype(dtype)
-    full_shape = shape if partition_info is None else partition_info.full_shape
-    if len(full_shape) != 2:
+    if len(shape) != 2:
       raise ValueError(
           "Identity matrix initializer can only be used for 2D matrices.")
-    initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
-    if partition_info is not None:
-      initializer = array_ops.slice(initializer, partition_info.var_offset,
-                                    shape)
+    initializer = linalg_ops_impl.eye(*shape, dtype=dtype)
     return self.gain * initializer
 
   def get_config(self):
@@ -78,6 +78,21 @@ class InitializersTest(test.TestCase):
     if target_min is not None:
       self.assertGreater(lim, abs(output.min() - target_min))
 
+  def _partition_test(self, init):
+    full_shape = (4, 2)
+    partition_shape = (2, 2)
+    partition_offset = (0, 0)
+    full_value = self.evaluate(init(full_shape, dtype=dtypes.float32))
+    got = self.evaluate(
+        init(
+            full_shape,
+            dtype=dtypes.float32,
+            partition_shape=partition_shape,
+            partition_offset=partition_offset))
+    self.assertEqual(got.shape, partition_shape)
+    self.assertAllClose(
+        got, array_ops.slice(full_value, partition_offset, partition_shape))
+
 
 class ConstantInitializersTest(InitializersTest):
 
@@ -86,11 +101,28 @@ class ConstantInitializersTest(InitializersTest):
     self._range_test(init_ops_v2.Zeros(), shape=(4, 5),
                      target_mean=0., target_max=0.)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testZerosPartition(self):
+    init = init_ops_v2.Zeros()
+    self._partition_test(init)
+
+  @test_util.run_in_graph_and_eager_modes
+  def testZerosInvalidKwargs(self):
+    init = init_ops_v2.Zeros()
+    with self.assertRaisesWithLiteralMatch(TypeError,
+                                           r"Unknown keyword arguments: dtpye"):
+      init((2, 2), dtpye=dtypes.float32)
+
   @test_util.run_in_graph_and_eager_modes
   def testOnes(self):
     self._range_test(init_ops_v2.Ones(), shape=(4, 5),
                      target_mean=1., target_max=1.)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testOnesPartition(self):
+    init = init_ops_v2.Ones()
+    self._partition_test(init)
+
   @test_util.run_in_graph_and_eager_modes
   def testConstantInt(self):
     self._range_test(
@@ -100,6 +132,14 @@ class ConstantInitializersTest(InitializersTest):
         target_max=2,
         target_min=2)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testConstantPartition(self):
+    init = init_ops_v2.Constant([1, 2, 3, 4])
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        r"Constant initializer doesn't support partition-related arguments"):
+      init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
+
   @test_util.run_in_graph_and_eager_modes
   def testConstantTuple(self):
     init = init_ops_v2.constant_initializer((10, 20, 30))
@@ -188,6 +228,11 @@ class RandomUniformInitializerTest(InitializersTest):
     init = init_ops_v2.RandomUniform(0.0, 1.0)
     self._duplicated_test(init)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testInitializePartition(self):
+    init = init_ops_v2.RandomUniform(0, 7, seed=1)
+    self._partition_test(init)
+
 
 class RandomNormalInitializerTest(InitializersTest):
 
@@ -217,6 +262,14 @@ class RandomNormalInitializerTest(InitializersTest):
     init = init_ops_v2.RandomNormal(0.0, 1.0)
     self._duplicated_test(init)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testInitializePartition(self):
+    if test_util.is_xla_enabled():
+      self.skipTest(
+          "XLA ignores seeds for RandomNormal, skip xla-enabled test.")
+    init = init_ops_v2.RandomNormal(0, 7, seed=1)
+    self._partition_test(init)
+
 
 class TruncatedNormalInitializerTest(InitializersTest):
 
@@ -247,6 +300,12 @@ class TruncatedNormalInitializerTest(InitializersTest):
     init = init_ops_v2.TruncatedNormal(0.0, 1.0)
     self._duplicated_test(init)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testInitializePartition(self):
+    init = init_ops_v2.TruncatedNormal(0.0, 1.0, seed=1)
+    self._partition_test(init)
+
   @test_util.run_in_graph_and_eager_modes
   def testInvalidDataType(self):
     init = init_ops_v2.TruncatedNormal(0.0, 1.0)
     with self.assertRaises(ValueError):
@@ -317,6 +376,24 @@ class VarianceScalingInitializerTest(InitializersTest):
     self.assertNear(np.mean(x), expect_mean, err=1e-2)
     self.assertNear(np.var(x), expect_var, err=1e-2)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testInitializePartition(self):
+    partition_shape = (100, 100)
+    shape = [1000, 100]
+    expect_mean = 0.
+    expect_var = 1. / shape[0]
+    init = init_ops_v2.VarianceScaling(distribution="untruncated_normal")
+
+    with test_util.use_gpu(), test.mock.patch.object(
+        random_ops, "random_normal",
+        wraps=random_ops.random_normal) as mock_random_normal:
+      x = self.evaluate(init(shape, partition_shape=partition_shape))
+      self.assertTrue(mock_random_normal.called)
+
+    self.assertEqual(x.shape, partition_shape)
+    self.assertNear(np.mean(x), expect_mean, err=1e-3)
+    self.assertNear(np.var(x), expect_var, err=1e-3)
+
 
 class OrthogonalInitializerTest(InitializersTest):
 
@@ -386,6 +463,14 @@ class OrthogonalInitializerTest(InitializersTest):
       self.assertAllClose(
           np.dot(t, t.T), np.eye(t.shape[0]), rtol=tol, atol=tol)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testPartition(self):
+    init = init_ops_v2.Orthogonal(seed=1)
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        r"Orthogonal initializer doesn't support partition-related arguments"):
+      init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
+
 
 class IdentityInitializerTest(InitializersTest):
 
@@ -439,6 +524,14 @@ class IdentityInitializerTest(InitializersTest):
     self.assertAllClose(self.evaluate(init_custom(shape, dtype=dtype)),
                         np.eye(*shape) * 0.9)
 
+  @test_util.run_in_graph_and_eager_modes
+  def testPartition(self):
+    init = init_ops_v2.Identity()
+    with self.assertRaisesWithLiteralMatch(
+        ValueError,
+        r"Identity initializer doesn't support partition-related arguments"):
+      init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
+
 
 class GlorotInitializersTest(InitializersTest):
 