Modify some v2 initializers so they can return a value that corresponds to a partition of the entire value. This is useful for efficiently initializing sharded variables, where only one shard of the initial value is needed at a time.
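For example (an illustrative sketch, not part of the commit; it assumes a build that includes this change and goes through the Keras-exported initializers):

```python
import tensorflow as tf

# Suppose a (30, 100) variable is sharded into p0 = (10, 100) and
# p1 = (20, 100). An initializer can now be asked for just p1:
init = tf.keras.initializers.GlorotUniform(seed=1)
p1 = init(shape=(30, 100),
          partition_shape=(20, 100),
          partition_offset=(10, 0))
print(p1.shape)  # (20, 100): only the shard is materialized
```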

PiperOrigin-RevId: 338371904
Change-Id: Ib4320d73cbaec30f5a61793debe7755026175781
Chenkai Kuang 2020-10-21 17:16:27 -07:00 committed by TensorFlower Gardener
parent d2c7a16c2d
commit 239fe406d3
4 changed files with 249 additions and 73 deletions


@@ -34,7 +34,7 @@ class Initializer(object):
signature:
```python
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
# returns a tensor of shape `shape` and dtype `dtype`
# containing values drawn from a distribution of your choice.
```
@@ -54,7 +54,7 @@ class Initializer(object):
self.mean = mean
self.stddev = stddev
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
return tf.random.normal(
shape, mean=self.mean, stddev=self.stddev, dtype=dtype)
@@ -68,12 +68,13 @@ class Initializer(object):
works fine.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor.
**kwargs: Additional keyword arguments.
"""
raise NotImplementedError
@@ -124,7 +125,7 @@ class Zeros(init_ops_v2.Zeros, Initializer):
>>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
@@ -133,8 +134,9 @@ class Zeros(init_ops_v2.Zeros, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Zeros, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[])
@@ -154,7 +156,7 @@ class Ones(init_ops_v2.Ones, Initializer):
>>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
@@ -163,8 +165,9 @@ class Ones(init_ops_v2.Ones, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Ones, self).__call__(shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Constant',
@@ -196,7 +199,7 @@ class Constant(Initializer):
def __init__(self, value=0):
self.value = value
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to `self.value`.
Args:
@@ -205,7 +208,9 @@ class Constant(Initializer):
`tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
del kwargs
return constant_op.constant(
self.value, dtype=_get_dtype(dtype), shape=shape)
@@ -241,7 +246,7 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
@@ -251,8 +256,10 @@ class RandomUniform(init_ops_v2.RandomUniform, Initializer):
`tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
return super(RandomUniform, self).__call__(shape, dtype=_get_dtype(dtype))
return super(RandomUniform, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.RandomNormal',
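The Keras wrappers above simply forward the new keyword arguments to their `init_ops_v2` base classes, so partitioning is reachable through the public Keras API as well. A minimal sketch, mirroring the `test_partition` case added below:

```python
import tensorflow as tf

init = tf.keras.initializers.RandomUniform(seed=1)
shard = init(shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
assert shard.shape == (2, 2)  # only the requested partition is returned
```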
@@ -283,17 +290,19 @@ class RandomNormal(init_ops_v2.RandomNormal, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to random normal values.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
supported. If not specified, `tf.keras.backend.floatx()` is used, which
defaults to `float32` unless you configured it otherwise (via
`tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(RandomNormal, self).__call__(shape, dtype=_get_dtype(dtype))
return super(RandomNormal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.TruncatedNormal',
@@ -329,17 +338,19 @@ class TruncatedNormal(init_ops_v2.TruncatedNormal, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to random normal values (truncated).
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
supported. If not specified, `tf.keras.backend.floatx()` is used, which
defaults to `float32` unless you configured it otherwise (via
`tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(TruncatedNormal, self).__call__(shape, dtype=_get_dtype(dtype))
return super(TruncatedNormal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.VarianceScaling',
@@ -384,17 +395,19 @@ class VarianceScaling(init_ops_v2.VarianceScaling, Initializer):
always produce the same random tensor for a given shape and dtype.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
supported. If not specified, `tf.keras.backend.floatx()` is used, which
defaults to `float32` unless you configured it otherwise (via
`tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(VarianceScaling, self).__call__(shape, dtype=_get_dtype(dtype))
return super(VarianceScaling, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Orthogonal',
@@ -436,7 +449,7 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
([pdf](https://arxiv.org/pdf/1312.6120.pdf))
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to an orthogonal matrix.
Args:
@@ -445,8 +458,10 @@ class Orthogonal(init_ops_v2.Orthogonal, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(Orthogonal, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Orthogonal, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.Identity',
@@ -473,7 +488,7 @@ class Identity(init_ops_v2.Identity, Initializer):
gain: Multiplicative factor to apply to the identity matrix.
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized to a 2D identity matrix.
Args:
@@ -482,8 +497,10 @@ class Identity(init_ops_v2.Identity, Initializer):
supported. If not specified, `tf.keras.backend.floatx()` is used,
which defaults to `float32` unless you configured it otherwise
(via `tf.keras.backend.set_floatx(float_dtype)`)
**kwargs: Additional keyword arguments.
"""
return super(Identity, self).__call__(shape, dtype=_get_dtype(dtype))
return super(Identity, self).__call__(
shape, dtype=_get_dtype(dtype), **kwargs)
@keras_export('keras.initializers.GlorotUniform',


@@ -253,6 +253,34 @@ class KerasInitializersTest(test.TestCase):
initializer = initializers.deserialize(external_serialized_json)
self.assertEqual(initializer.distribution, 'truncated_normal')
def test_partition(self):
with self.cached_session():
partition_enabled_initializers = [
initializers.ZerosV2(),
initializers.OnesV2(),
initializers.RandomUniformV2(),
initializers.RandomNormalV2(),
initializers.TruncatedNormalV2(),
initializers.LecunUniformV2(),
initializers.GlorotUniformV2(),
initializers.HeUniformV2()
]
for initializer in partition_enabled_initializers:
got = initializer(
shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
self.assertEqual(got.shape, (2, 2))
partition_forbidden_initializers = [
initializers.OrthogonalV2(),
initializers.IdentityV2()
]
for initializer in partition_forbidden_initializers:
with self.assertRaisesRegex(
ValueError,
"initializer doesn't support partition-related arguments"):
initializer(
shape=(4, 2), partition_shape=(2, 2), partition_offset=(0, 0))
if __name__ == '__main__':
test.main()


@@ -12,19 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations often used for initializing tensors.
All variable initializers returned by functions in this file should have the
following signature:
def _initializer(shape, dtype=dtypes.float32):
Args:
shape: List of `int` representing the shape of the output `Tensor`. Some
initializers may also be able to accept a `Tensor`.
dtype: (Optional) Type of the output `Tensor`.
Returns:
A `Tensor` of type `dtype` and `shape`.
"""
"""Initializers for TF 2."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -44,18 +32,40 @@ from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.init_ops import _compute_fans
from tensorflow.python.util.tf_export import tf_export
_PARTITION_SHAPE = "partition_shape"
_PARTITION_OFFSET = "partition_offset"
class Initializer(object):
"""Initializer base class: all initializers inherit from this class.
Initializers should implement a `__call__` method with the following
signature:
```python
def __call__(self, shape, dtype=None, **kwargs):
# returns a tensor of shape `shape` and dtype `dtype`
# containing values drawn from a distribution of your choice.
```
"""
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. If not provided, a tensor of
`tf.float32` is returned.
**kwargs: Additional keyword arguments. Accepted values:
`partition_shape` and `partition_offset`. Used when creating a single
partition in a partitioned variable. `partition_shape` is the shape
of the partition (i.e. the shape of the returned tensor) and
`partition_offset` is a tuple of `int` specifying the offset of this
partition w.r.t. each axis. For example, a tensor of shape `(30, 100)`
can be partitioned into two partitions: `p0` of shape `(10, 100)` and
`p1` of shape `(20, 100)`; if the initializer is called with
`partition_shape=(20, 100)` and `partition_offset=(10, 0)`, it should
return the value for `p1`.
"""
raise NotImplementedError
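A minimal sketch of a subclass honoring this contract (a hypothetical `RangeInitializer`, not part of the commit; slicing a fully materialized value is correct but gives up the efficiency win that motivates the feature):

```python
import tensorflow as tf

class RangeInitializer(tf.keras.initializers.Initializer):  # hypothetical
  """Fills with 0..n-1 in row-major order; honors the partition kwargs."""

  def __call__(self, shape, dtype=tf.float32, **kwargs):
    num = 1
    for dim in shape:
      num *= dim
    full = tf.reshape(tf.range(num, dtype=dtype), shape)
    if "partition_shape" in kwargs:
      # Cut the requested shard out of the full value.
      return tf.slice(full, kwargs["partition_offset"],
                      kwargs["partition_shape"])
    return full
```

For a `(4, 2)` full shape, `partition_shape=(2, 2)` with `partition_offset=(2, 0)` returns the bottom two rows, i.e. the value the second shard should hold.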
@@ -89,6 +99,14 @@ class Initializer(object):
config.pop("dtype", None)
return cls(**config)
def _validate_kwargs(self, kwargs, support_partition=True):
for kwarg in kwargs:
if kwarg not in [_PARTITION_SHAPE, _PARTITION_OFFSET]:
raise TypeError("Unknown keyword arguments: %s" % kwarg)
elif not support_partition:
raise ValueError("%s initializer doesn't support partition-related"
" arguments" % self.__class__.__name__)
@tf_export("zeros_initializer", v1=[])
class Zeros(Initializer):
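`_validate_kwargs` gives two distinct failure modes, mirrored by the tests further down (illustrative calls; each call raises):

```python
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops_v2

# A keyword that is neither partition_shape nor partition_offset:
init_ops_v2.Zeros()((2, 2), dtpye=dtypes.float32)
# -> TypeError: Unknown keyword arguments: dtpye

# Partition arguments passed to an initializer that forbids them:
init_ops_v2.Constant(3)((4, 2), partition_shape=(2, 2))
# -> ValueError: Constant initializer doesn't support partition-related arguments
```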
@@ -115,20 +133,24 @@ class Zeros(Initializer):
(<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
"""
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not numeric or boolean.
"""
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype)
if not dtype.is_numpy_compatible or dtype == dtypes.string:
raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return array_ops.zeros(shape, dtype)
@@ -157,20 +179,24 @@ class Ones(Initializer):
(<tf.Variable...shape=(4,) dtype=float32...>, <tf.Variable...shape=(4, 4) ...
"""
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not numeric or boolean.
"""
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype)
if not dtype.is_numpy_compatible or dtype == dtypes.string:
raise ValueError("Expected numeric or boolean dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return array_ops.ones(shape, dtype)
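For the constant-fill initializers the partition offset is irrelevant (every shard looks the same), so honoring `partition_shape` is just a shape swap, as the two hunks above show. Sketch:

```python
from tensorflow.python.ops import init_ops_v2

shard = init_ops_v2.Ones()((4, 2), partition_shape=(2, 2))
print(shard.shape)  # (2, 2), identical for any partition_offset
```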
@@ -245,22 +271,23 @@ class Constant(Initializer):
"tuple of values, or numpy.ndarray)." % type(value))
self.value = value
def __call__(self, shape, dtype=None):
def __call__(self, shape, dtype=None, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. If not provided, the dtype of the
tensor created will be the type of the initial value.
**kwargs: Additional keyword arguments.
Raises:
TypeError: If the initializer cannot create a tensor of the requested
dtype.
"""
self._validate_kwargs(kwargs, support_partition=False)
if dtype is not None:
dtype = dtypes.as_dtype(dtype)
return constant_op.constant(
self.value, dtype=dtype, shape=shape)
return constant_op.constant(self.value, dtype=dtype, shape=shape)
def get_config(self):
return {"value": self.value}
@@ -305,20 +332,24 @@ class RandomUniform(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point and integer
types are supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not numeric.
"""
self._validate_kwargs(kwargs)
dtype = dtypes.as_dtype(dtype)
if not dtype.is_floating and not dtype.is_integer:
raise ValueError("Expected float or integer dtype, got %s." % dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.random_uniform(shape, self.minval,
self.maxval, dtype)
@@ -369,18 +400,22 @@ class RandomNormal(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point
"""
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.random_normal(shape, self.mean, self.stddev,
dtype)
@@ -434,18 +469,22 @@ class TruncatedNormal(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point
"""
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
return self._random_generator.truncated_normal(shape, self.mean,
self.stddev, dtype)
@@ -525,24 +564,24 @@ class VarianceScaling(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point
"""
partition_info = None # Keeps logic so can be readded later if necessary
self._validate_kwargs(kwargs)
dtype = _assert_float_dtype(dtype)
scale = self.scale
scale_shape = shape
if partition_info is not None:
scale_shape = partition_info.full_shape
fan_in, fan_out = _compute_fans(scale_shape)
fan_in, fan_out = _compute_fans(shape)
if _PARTITION_SHAPE in kwargs:
shape = kwargs[_PARTITION_SHAPE]
if self.mode == "fan_in":
scale /= max(1., fan_in)
elif self.mode == "fan_out":
@@ -616,18 +655,20 @@ class Orthogonal(Initializer):
self.seed = seed
self._random_generator = _RandomGenerator(seed)
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point or the input shape is not
valid.
"""
self._validate_kwargs(kwargs, support_partition=False)
dtype = _assert_float_dtype(dtype)
# Check the shape
if len(shape) < 2:
@@ -686,28 +727,25 @@ class Identity(Initializer):
def __init__(self, gain=1.0):
self.gain = gain
def __call__(self, shape, dtype=dtypes.float32):
def __call__(self, shape, dtype=dtypes.float32, **kwargs):
"""Returns a tensor object initialized as specified by the initializer.
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only floating point types are
supported.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the dtype is not floating point
ValueError: If the requested shape does not have exactly two axes.
"""
partition_info = None # Keeps logic so can be readded later if necessary
self._validate_kwargs(kwargs, support_partition=False)
dtype = _assert_float_dtype(dtype)
full_shape = shape if partition_info is None else partition_info.full_shape
if len(full_shape) != 2:
if len(shape) != 2:
raise ValueError(
"Identity matrix initializer can only be used for 2D matrices.")
initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
if partition_info is not None:
initializer = array_ops.slice(initializer, partition_info.var_offset,
shape)
initializer = linalg_ops_impl.eye(*shape, dtype=dtype)
return self.gain * initializer
def get_config(self):


@@ -78,6 +78,21 @@ class InitializersTest(test.TestCase):
if target_min is not None:
self.assertGreater(lim, abs(output.min() - target_min))
def _partition_test(self, init):
full_shape = (4, 2)
partition_shape = (2, 2)
partition_offset = (0, 0)
full_value = self.evaluate(init(full_shape, dtype=dtypes.float32))
got = self.evaluate(
init(
full_shape,
dtype=dtypes.float32,
partition_shape=partition_shape,
partition_offset=partition_offset))
self.assertEqual(got.shape, partition_shape)
self.assertAllClose(
got, array_ops.slice(full_value, partition_offset, partition_shape))
class ConstantInitializersTest(InitializersTest):
@@ -86,11 +101,28 @@ class ConstantInitializersTest(InitializersTest):
self._range_test(init_ops_v2.Zeros(), shape=(4, 5),
target_mean=0., target_max=0.)
@test_util.run_in_graph_and_eager_modes
def testZerosPartition(self):
init = init_ops_v2.Zeros()
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testZerosInvalidKwargs(self):
init = init_ops_v2.Zeros()
with self.assertRaisesWithLiteralMatch(TypeError,
r"Unknown keyword arguments: dtpye"):
init((2, 2), dtpye=dtypes.float32)
@test_util.run_in_graph_and_eager_modes
def testOnes(self):
self._range_test(init_ops_v2.Ones(), shape=(4, 5),
target_mean=1., target_max=1.)
@test_util.run_in_graph_and_eager_modes
def testOnesPartition(self):
init = init_ops_v2.Ones()
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testConstantInt(self):
self._range_test(
@@ -100,6 +132,14 @@ class ConstantInitializersTest(InitializersTest):
target_max=2,
target_min=2)
@test_util.run_in_graph_and_eager_modes
def testConstantPartition(self):
init = init_ops_v2.Constant([1, 2, 3, 4])
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Constant initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
@test_util.run_in_graph_and_eager_modes
def testConstantTuple(self):
init = init_ops_v2.constant_initializer((10, 20, 30))
@@ -188,6 +228,11 @@ class RandomUniformInitializerTest(InitializersTest):
init = init_ops_v2.RandomUniform(0.0, 1.0)
self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
init = init_ops_v2.RandomUniform(0, 7, seed=1)
self._partition_test(init)
class RandomNormalInitializerTest(InitializersTest):
@@ -217,6 +262,14 @@ class RandomNormalInitializerTest(InitializersTest):
init = init_ops_v2.RandomNormal(0.0, 1.0)
self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
if test_util.is_xla_enabled():
self.skipTest(
"XLA ignores seeds for RandomNormal, skip xla-enabled test.")
init = init_ops_v2.RandomNormal(0, 7, seed=1)
self._partition_test(init)
class TruncatedNormalInitializerTest(InitializersTest):
@@ -247,6 +300,12 @@ class TruncatedNormalInitializerTest(InitializersTest):
init = init_ops_v2.TruncatedNormal(0.0, 1.0)
self._duplicated_test(init)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
init = init_ops_v2.TruncatedNormal(0.0, 1.0, seed=1)
self._partition_test(init)
@test_util.run_in_graph_and_eager_modes
def testInvalidDataType(self):
init = init_ops_v2.TruncatedNormal(0.0, 1.0)
with self.assertRaises(ValueError):
@@ -317,6 +376,24 @@ class VarianceScalingInitializerTest(InitializersTest):
self.assertNear(np.mean(x), expect_mean, err=1e-2)
self.assertNear(np.var(x), expect_var, err=1e-2)
@test_util.run_in_graph_and_eager_modes
def testInitializePartition(self):
partition_shape = (100, 100)
shape = [1000, 100]
expect_mean = 0.
expect_var = 1. / shape[0]
init = init_ops_v2.VarianceScaling(distribution="untruncated_normal")
with test_util.use_gpu(), test.mock.patch.object(
random_ops, "random_normal",
wraps=random_ops.random_normal) as mock_random_normal:
x = self.evaluate(init(shape, partition_shape=partition_shape))
self.assertTrue(mock_random_normal.called)
self.assertEqual(x.shape, partition_shape)
self.assertNear(np.mean(x), expect_mean, err=1e-3)
self.assertNear(np.var(x), expect_var, err=1e-3)
class OrthogonalInitializerTest(InitializersTest):
@@ -386,6 +463,14 @@ class OrthogonalInitializerTest(InitializersTest):
self.assertAllClose(
np.dot(t, t.T), np.eye(t.shape[0]), rtol=tol, atol=tol)
@test_util.run_in_graph_and_eager_modes
def testPartition(self):
init = init_ops_v2.Orthogonal(seed=1)
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Orthogonal initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
class IdentityInitializerTest(InitializersTest):
@@ -439,6 +524,14 @@ class IdentityInitializerTest(InitializersTest):
self.assertAllClose(self.evaluate(init_custom(shape, dtype=dtype)),
np.eye(*shape) * 0.9)
@test_util.run_in_graph_and_eager_modes
def testPartition(self):
init = init_ops_v2.Identity()
with self.assertRaisesWithLiteralMatch(
ValueError,
r"Identity initializer doesn't support partition-related arguments"):
init((4, 2), dtype=dtypes.float32, partition_shape=(2, 2))
class GlorotInitializersTest(InitializersTest):