Modify some v2 initializers so they can return a value corresponding to a partition of the entire value. This is useful for efficiently initializing sharded variables, where only one shard of the initial value is needed at a time.
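
For example, a minimal sketch of the new calling convention (the initializer, seed, and shapes here are illustrative; `partition_shape` and `partition_offset` are the keyword arguments added by this change):

```python
import tensorflow as tf

# A (30, 100) variable sharded along axis 0 into a (10, 100) shard and a
# (20, 100) shard; ask the initializer for the second shard only.
init = tf.keras.initializers.GlorotUniform(seed=1)
shard1 = init(shape=(30, 100), dtype=tf.float32,
              partition_shape=(20, 100), partition_offset=(10, 0))
print(shard1.shape)  # (20, 100); the full (30, 100) value is never built
```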

PiperOrigin-RevId: 338371904
Change-Id: Ib4320d73cbaec30f5a61793debe7755026175781
Authored by Chenkai Kuang on 2020-10-21 17:16:27 -07:00; committed by TensorFlower Gardener
parent d2c7a16c2d
commit 239fe406d3
4 changed files with 249 additions and 73 deletions


@@ -34,7 +34,7 @@ class Initializer(object):
signature:
```python
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
# returns a tensor of shape `shape` and dtype `dtype`
# containing values drawn from a distribution of your choice.
```
@@ -54,7 +54,7 @@ class Initializer(object):
self.mean = mean
self.stddev = stddev
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
return tf.random.normal(
shape, mean=self.mean, stddev=self.stddev, dtype=dtype)
@@ -68,12 +68,13 @@ class Initializer(object):
works fine.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor.
**kwargs: Additional keyword arguments.
"""
raise NotImplementedError
@@ -124,7 +125,7 @@ class Zeros(init_ops_v2.Zeros, Initializer):
>>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
@@ -133,8 +134,9 @@ class Zeros(init_ops_v2.Zeros, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[])
@@ -154,7 +156,7 @@ class Ones(init_ops_v2.Ones, Initializer):
>>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
@@ -163,8 +165,9 @@ class Ones(init_ops_v2.Ones, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Constant',
@@ -196,7 +199,7 @@ class Constant(Initializer):
def __init__(self, value=0):
self.value = value
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to `self.value`.
Args:
@@ -205,7 +208,9 @@ class Constant(Initializer):
`tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
del kwargs
return constant_op.constant(
self.value, dtype=_get_dtype(dtype), shape=shape)
@@ -241,7 +246,7 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
@@ -251,8 +256,10 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
`tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
return super(RandomUniform, self).__call__(shape, dtype=_get_dtype(dtype))
return super(RandomUniform, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.RandomNormal',
@@ -283,17 +290,19 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to random normal values.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
supported. If not specified, `tf.keras.backend.floatx()` is used, which
defaults to `float32` unless you configured it otherwise (via
`tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(RandomNormal, self).__call__(shape, dtype=_get_dtype(dtype))
return super(RandomNormal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.TruncatedNormal',
@@ -329,17 +338,19 @@ class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to random normal values (truncated).
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
supported. If not specified, `tf.keras.backend.floatx()` is used, which
defaults to `float32` unless you configured it otherwise (via
`tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(TruncatedNormal, self).__call__(shape, dtype=_get_dtype(dtype))
return super(TruncatedNormal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.VarianceScaling',
@@ -384,17 +395,19 @@ class VarianceScaling(init_ops_v2.VarianceScaling, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
supported. If not specified, `tf.keras.backend.floatx()` is used, which
defaults to `float32` unless you configured it otherwise (via
`tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(VarianceScaling, self).__call__(shape, dtype=_get_dtype(dtype))
return super(VarianceScaling, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Orthogonal',
@@ -436,7 +449,7 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
([pdf](https://arxiv.org/pdf/1312.6120.pdf))
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to an orthogonal matrix.
Args:
@@ -445,8 +458,10 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(Orthogonal, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Orthogonal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Identity',
@@ -473,7 +488,7 @@ class Identity(init_ops_v2.Identity, Initializer):
gain: Multiplicative factor to apply to the identity matrix.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to a 2D identity matrix.
Args:
@@ -482,8 +497,10 @@ class Identity(init_ops_v2.Identity, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(Identity, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Identity, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.GlorotUniform',


@@ -253,6 +253,34 @@ class KerasInitializersTest(test.TestCase):
initializer = initializers.deserialize(external_serialized_json)
self.assertEqual(initializer.distribution, 'truncated_normal')
def test_partition(self):
with self.cached_session():
partition_enabled_initializers = [
initializers.ZerosV2(),
initializers.OnesV2(),
initializers.RandomUniformV2(),
initializers.RandomNormalV2(),
initializers.TruncatedNormalV2(),
initializers.LecunUniformV2(),
initializers.GlorotUniformV2(),
initializers.HeUniformV2()
]
for initializer in partition_enabled_initializers:
got = initializer(
shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
self.assertEqual(got.shape, (2, 2))
partition_forbidden_initializers = [
initializers.OrthogonalV2(),
initializers.IdentityV2()
]
for initializer in partition_forbidden_initializers:
with self.assertRaisesRegex(
ValueError,
"initializer doesn't support partition-related arguments"):
initializer(
shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
if __name__ == '__main__':
test.main()


@@ -12,19 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations often used for initializing tensors.
All variable initializers returned by functions in this file should have the
following signature:
def _initializer(shape, dtype=dtypes.float32):
Args:
shape: List of `int` representing the shape of the output `Tensor`. Some
initializers may also be able to accept a `Tensor`.
dtype: (Optional) Type of the output `Tensor`.
Returns:
A `Tensor` of type `dtype` and `shape`.
"""
"""Initializers for TF 2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -44,18 +32,40 @@ from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.init_ops import _compute_fans
from tensorflow.python.util.tf_export import tf_export
_PARTITION_SHAPE = "partition_shape"
_PARTITION_OFFSET = "partition_offset"
class Initializer(object):
"""Initializer base class: all initializers inherit from this class.
Initializers should implement a `__call__` method with the following
signature:
```python
def __call__(self, shape, dtype=None, **kwargs):
# returns a tensor of shape `shape` and dtype `dtype`
# containing values drawn from a distribution of your choice.
```
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. If not provided, will return a tensor
of `tf.float32`.
**kwargs: Additional keyword arguments. Accepted values:
`partition_shape` and `partition_offset`. Used when creating a single
partition in a partitioned variable. `partition_shape` is the shape
of the partition (i.e. the shape of the returned tensor) and
`partition_offset` is a tuple of `int` specifying the offset of this
partition w.r.t. each axis. For example, a tensor of shape `(30, 100)`
can be partitioned into two partitions: `p0` of shape `(10, 100)` and
`p1` of shape `(20, 100)`; if the initializer is called with
`partition_shape=(20, 100)` and `partition_offset=(10, 0)`, it should
return the value for `p1`.
"""
raise NotImplementedError
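
To illustrate the contract above, here is a minimal sketch of a partition-aware initializer (a hypothetical `GlobalRowIndex` class written against the public TF API; it is not part of this change). Row `i` of the full value holds the number `i`, so a shard's values line up exactly with the corresponding slice of the full tensor:

```python
import tensorflow as tf

class GlobalRowIndex(tf.keras.initializers.Initializer):
  """Hypothetical 2-D initializer: row i of the full value holds i."""

  def __call__(self, shape, dtype=tf.float32, **kwargs):
    offset = kwargs.pop("partition_offset", (0,) * len(shape))
    shape = kwargs.pop("partition_shape", shape)
    # Shift the row range by the partition's offset along axis 0 so the
    # shard reproduces the matching slice of the full value.
    rows = tf.range(offset[0], offset[0] + shape[0], dtype=dtype)
    return tf.tile(rows[:, tf.newaxis], [1, shape[1]])

init = GlobalRowIndex()
full = init((30, 100))                          # shape (30, 100)
p1 = init((30, 100), partition_shape=(20, 100),
          partition_offset=(10, 0))             # shape (20, 100)
# p1 matches full[10:30, :]
```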
@@ -89,6 +99,14 @@ class Initializer(object):
config.pop("dtype", None)
return cls(**config)
def _validate_kwargs(self, kwargs, support_partition=True):
for kwarg in kwargs:
if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]:
raise TypeError("Unknown keyword arguments: %s" % kwarg)
elif not support_partition:
raise ValueError("%s initializer doesn't support partition-related"
" arguments" % self.__class__.__name__)
@tf_export("zeros_initializer", v1=[])
class Zeros(Initializer):
@@ -115,20 +133,24 @@ class Zeros(Initializer):
(<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
"""
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not numeric or boolean.
"""
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype)
if not dtype.is_numpy_compatible or dtype == dtypes.string:
raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
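# An all-zeros value is identical across slices, so only the partition's
# shape matters; partition_offset is accepted by validation but unused here.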
shape = kwargs[_PARTITION_SHAPE]
return array_ops.zeros(shape, dtype)
@@ -157,20 +179,24 @@ class Ones(Initializer):
(<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
"""
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not numeric or boolean.
"""
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype)
if not dtype.is_numpy_compatible or dtype == dtypes.string:
raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return array_ops.ones(shape, dtype)
@@ -245,22 +271,23 @@ class Constant(Initializer):
"tuple of values, or numpy.ndarray)." % type(value))
self.value = value
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. If not provided, the dtype of the
tensor created will be the type of the initial value.
**kwargs: Additional keyword arguments.
Raises:
TypeError: If the initializer cannot create a tensor of the requested
dtype.
"""
self._validate_kwargs(kwargs, support_partition=False)
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
return constant_op.constant(
self.value, dtype=dtype, shape=shape)
return constant_op.constant(self.value, dtype=dtype, shape=shape)
def get_config(self):
return {"value": self.value}
@@ -305,20 +332,24 @@ class RandomUniform(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point and integer
types are supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not numeric.
"""
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype)
if not dtype.is_floating and not dtype.is_integer:
raise ValueError("Expected float or integer dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.random_uniform(shape, self.minval,
self.maxval, dtype)
@@ -369,18 +400,22 @@ class RandomNormal(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point.
"""
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.random_normal(shape, self.mean, self.stddev,
dtype)
@@ -434,18 +469,22 @@ class TruncatedNormal(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point.
"""
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.truncated_normal(shape, self.mean,
self.stddev, dtype)
@@ -525,24 +564,24 @@ class VarianceScaling(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point.
"""
partition_info = None  # Keep logic so it can be re-added later if necessary
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype)
scale = self.scale
scale_shape = shape
if partition_info is not None:
scale_shape = partition_info.full_shape
fan_in, fan_out = _compute_fans(scale_shape)
fan_in, fan_out = _compute_fans(shape)
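# Note: fans are computed from the full shape, before the partition swap
# below, so each shard receives the same scaling as the unsharded variable.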
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
if self.mode == "fan_in":
scale /= max(1., fan_in)
elif self.mode == "fan_out":
@@ -616,18 +655,20 @@ class Orthogonal(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point or the input shape is not
valid.
"""
self._validate_kwargs(kwargs, support_partition=False)
dtype = _assert_float_dtype(dtype)
# Check the shape
if len(shape) < 2:
@@ -686,28 +727,25 @@ class Identity(Initializer):
def __init__(self, gain=1.0):
self.gain = gain
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point.
ValueError: If the requested shape does not have exactly two axes.
"""
partition_info = None  # Keep logic so it can be re-added later if necessary
self._validate_kwargs(kwargs, support_partition=False)
dtype = _assert_float_dtype(dtype)
full_shape = shape if partition_info is None else partition_info.full_shape
if len(full_shape) != 2:
if len(shape) != 2:
raise ValueError(
"Identity matrix initializer can only be used for 2D matrices.")
initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
if partition_info is not None:
initializer = array_ops.slice(initializer, partition_info.var_offset,
shape)
initializer = linalg_ops_impl.eye(*shape, dtype=dtype)
return self.gain * initializer
def get_config(self):


@@ -78,6 +78,21 @@ class InitializersTest(test.TestCase):
if target_min is not None:
self.assertGreater(lim, abs(output.min() - target_min))
def _partition_test(self, init):
full_shape = (4, 2)
partition_shape = (2, 2)
partition_offset = (0, 0)
full_value = self.evaluate(init(full_shape, dtype=dtypes.float32))
got = self.evaluate(
init(
full_shape,
dtype=dtypes.float32,
partition_shape=partition_shape,
partition_offset=partition_offset))
self.assertEqual(got.shape, partition_shape)
self.assertAllClose(
got, array_ops.slice(full_value, partition_offset, partition_shape))
class ConstantInitializersTest(InitializersTest):
@@ -86,11 +101,28 @@ class ConstantInitializersTest(InitializersTest):
self._range_test(init_ops_v2.Zeros(), shape=(4, 5),
target_mean=0., target_max=0.)
@test_util.run_in_graph_and_eager_modes
def testZerosPartition(self):
init = init_ops_v2.Zeros()
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testZerosInvalidKwargs(self):
init = init_ops_v2.Zeros()
with self.assertRaisesWithLiteralMatch(TypeError,
r"Unknown keyword arguments: dtpye"):
init((2, 2), dtpye=dtypes.float32)
@test_util.run_in_graph_and_eager_modes
def testOnes(self):
self._range_test(init_ops_v2.Ones(), shape=(4, 5),
target_mean=1., target_max=1.)
@test_util.run_in_graph_and_eager_modes
def testOnesPartition(self):
init = init_ops_v2.Ones()
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testConstantInt(self):
self._range_test(
@@ -100,6 +132,14 @@ class ConstantInitializersTest(InitializersTest):
target_max=2,
target_min=2)
@test_util.run_in_graph_and_eager_modes
def testConstantPartition(self):
init = init_ops_v2.Constant([1, 2, 3, 4])
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Constant initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
@test_util.run_in_graph_and_eager_modes
def testConstantTuple(self):
init = init_ops_v2.constant_initializer((10, 20, 30))
@@ -188,6 +228,11 @@ class RandomUniformInitializerTest(InitializersTest):
init = init_ops_v2.RandomUniform(0.0, 1.0)
self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
init = init_ops_v2.RandomUniform(0, 7, seed=1)
self._partition_test(init)
class RandomNormalInitializerTest(InitializersTest):
@@ -217,6 +262,14 @@ class RandomNormalInitializerTest(InitializersTest):
init = init_ops_v2.RandomNormal(0.0, 1.0)
self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
if test_util.is_xla_enabled():
self.skipTest(
"XLA ignores seeds for RandomNormal, skip xla-enabled test.")
init = init_ops_v2.RandomNormal(0, 7, seed=1)
self._partition_test(init)
class TruncatedNormalInitializerTest(InitializersTest):
@@ -247,6 +300,12 @@ class TruncatedNormalInitializerTest(InitializersTest):
init = init_ops_v2.TruncatedNormal(0.0, 1.0)
self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
init = init_ops_v2.TruncatedNormal(0.0, 1.0, seed=1)
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testInvalidDataType(self):
init = init_ops_v2.TruncatedNormal(0.0, 1.0)
with self.assertRaises(ValueError):
@@ -317,6 +376,24 @@ class VarianceScalingInitializerTest(InitializersTest):
self.assertNear(np.mean(x), expect_mean, err=1e-2)
self.assertNear(np.var(x), expect_var, err=1e-2)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
partition_shape = (100, 100)
shape = [1000, 100]
expect_mean = 0.
expect_var = 1. / shape[0]
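# The expected variance uses the full fan-in (shape[0] == 1000), not the
# partition's 100 rows: scaling is derived from the full shape.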
init = init_ops_v2.VarianceScaling(distribution="untruncated_normal")
with test_util.use_gpu(), test.mock.patch.object(
random_ops, "random_normal",
wraps=random_ops.random_normal) as mock_random_normal:
x = self.evaluate(init(shape, partition_shape=partition_shape))
self.assertTrue(mock_random_normal.called)
self.assertEqual(x.shape, partition_shape)
self.assertNear(np.mean(x), expect_mean, err=1e-3)
self.assertNear(np.var(x), expect_var, err=1e-3)
class OrthogonalInitializerTest(InitializersTest):
@@ -386,6 +463,14 @@ class OrthogonalInitializerTest(InitializersTest):
self.assertAllClose(
np.dot(t, t.T), np.eye(t.shape[0]), rtol=tol, atol=tol)
@test_util.run_in_graph_and_eager_modes
def testPartition(self):
init = init_ops_v2.Orthogonal(seed=1)
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Orthogonal initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
class IdentityInitializerTest(InitializersTest):
@@ -439,6 +524,14 @@ class IdentityInitializerTest(InitializersTest):
self.assertAllClose(self.evaluate(init_custom(shape, dtype=dtype)),
np.eye(*shape) * 0.9)
@test_util.run_in_graph_and_eager_modes
def testPartition(self):
init = init_ops_v2.Identity()
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Identity initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
class GlorotInitializersTest(InitializersTest):