Introduce Glorot initializers in core.
Change: 142728126
commit be60473c88 (parent c1e7b1580a)
@@ -34,6 +34,7 @@ from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -290,6 +291,152 @@ def uniform_unit_scaling_initializer(factor=1.0,
  return _initializer


def variance_scaling_initializer(scale=1.0,
                                 mode="fan_in",
                                 distribution="normal",
                                 seed=None,
                                 dtype=dtypes.float32):
  """Initializer capable of adapting its scale to the shape of weights tensors.

  With `distribution="normal"`, samples are drawn from a truncated normal
  distribution centered on zero, with `stddev = sqrt(scale / n)`
  where n is:
    - number of input units in the weight tensor, if mode = "fan_in"
    - number of output units, if mode = "fan_out"
    - average of the numbers of input and output units, if mode = "fan_avg"

  With `distribution="uniform"`, samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

  Arguments:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "normal", "uniform".
    seed: A Python integer. Used to create random seeds. See
      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
      for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer.

  Raises:
    ValueError: In case of an invalid value for the "scale", "mode" or
      "distribution" arguments.
  """
  # In order to be able to use the `scale` variable inside the inner
  # `_initializer` function, we must temporarily rename it: for an unknown
  # reason, the name `scale` prevents the variable from appearing in the inner
  # scope. We note that this sort of access to the outer scope by the inner
  # function is unsafe (as illustrated by this strange issue) and should
  # be removed in the future, by refactoring initializers as classes.
  scale_ = scale
  if scale <= 0.:
    raise ValueError("`scale` must be a positive float.")
  if mode not in {"fan_in", "fan_out", "fan_avg"}:
    raise ValueError("Invalid `mode` argument:", mode)
  distribution = distribution.lower()
  if distribution not in {"normal", "uniform"}:
    raise ValueError("Invalid `distribution` argument:", distribution)

  def _initializer(shape, dtype=_assert_float_dtype(dtype),
                   partition_info=None):
    scale = scale_
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape
    fan_in, fan_out = _compute_fans(scale_shape)
    if mode == "fan_in":
      scale /= max(1., fan_in)
    elif mode == "fan_out":
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if distribution == "normal":
      stddev = math.sqrt(scale)
      return random_ops.truncated_normal(shape, 0.0, stddev, dtype, seed=seed)
    else:
      limit = math.sqrt(3.0 * scale)
      return random_ops.random_uniform(shape, -limit, limit,
                                       dtype, seed=seed)

  return _initializer
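
# A quick standalone sketch of the scaling rule above (illustrative, not part
# of the diff): it mirrors the arithmetic of `_initializer`'s normal branch
# with plain `math`, using assumed fans for a 256x128 dense weight matrix.
import math

def _expected_stddev(scale, fan_in, fan_out, mode):
  # Divide `scale` by the chosen fan, then take the square root, exactly as
  # the normal branch of `_initializer` does.
  if mode == "fan_in":
    n = max(1., fan_in)
  elif mode == "fan_out":
    n = max(1., fan_out)
  else:  # "fan_avg"
    n = max(1., (fan_in + fan_out) / 2.)
  return math.sqrt(scale / n)

# With the defaults (scale=1.0, mode="fan_in"), the truncated normal for a
# 256x128 matrix is drawn with stddev = sqrt(1 / 256) = 0.0625.
assert _expected_stddev(1.0, 256., 128., "fan_in") == 0.0625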

def glorot_uniform_initializer(seed=None, dtype=dtypes.float32):
  """The Glorot uniform initializer, also called Xavier uniform initializer.

  It draws samples from a uniform distribution within [-limit, limit]
  where `limit` is `sqrt(6 / (fan_in + fan_out))`,
  `fan_in` is the number of input units in the weight tensor,
  and `fan_out` is the number of output units in the weight tensor.

  Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf

  Arguments:
    seed: A Python integer. Used to create random seeds. See
      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
      for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer.
  """
  return variance_scaling_initializer(scale=1.0,
                                      mode="fan_avg",
                                      distribution="uniform",
                                      seed=seed,
                                      dtype=dtype)
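
# Why those arguments reproduce the Glorot limit (illustrative check, not
# part of the diff): with mode="fan_avg", n = (fan_in + fan_out) / 2, so the
# uniform limit sqrt(3 * scale / n) = sqrt(6 / (fan_in + fan_out)).
import math
fan_in, fan_out = 256., 128.  # assumed example fans
limit = math.sqrt(3. * 1.0 / ((fan_in + fan_out) / 2.))
assert abs(limit - math.sqrt(6. / (fan_in + fan_out))) < 1e-12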

def glorot_normal_initializer(seed=None, dtype=dtypes.float32):
  """The Glorot normal initializer, also called Xavier normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with `stddev = sqrt(2 / (fan_in + fan_out))`,
  where `fan_in` is the number of input units in the weight tensor
  and `fan_out` is the number of output units in the weight tensor.

  Reference: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf

  Arguments:
    seed: A Python integer. Used to create random seeds. See
      [`set_random_seed`](../../api_docs/python/constant_op.md#set_random_seed)
      for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer.
  """
  return variance_scaling_initializer(scale=1.0,
                                      mode="fan_avg",
                                      distribution="normal",
                                      seed=seed,
                                      dtype=dtype)
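
# The same substitution for the normal case (illustrative, not part of the
# diff): with n = (fan_in + fan_out) / 2, stddev = sqrt(scale / n)
# = sqrt(2 / (fan_in + fan_out)).
import math
fan_in, fan_out = 256., 128.  # assumed example fans
stddev = math.sqrt(1.0 / ((fan_in + fan_out) / 2.))
assert abs(stddev - math.sqrt(2. / (fan_in + fan_out))) < 1e-12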

def _compute_fans(shape):
  """Computes the number of input and output units for a weight shape.

  Arguments:
    shape: Integer shape tuple or TF tensor shape.

  Returns:
    A tuple of scalars (fan_in, fan_out).
  """
  if len(shape) == 2:
    fan_in = shape[0]
    fan_out = shape[1]
  else:
    # Assuming convolution kernels (2D, 3D, or more).
    # kernel shape: (..., input_depth, depth)
    receptive_field_size = 1.
    for dim in shape[:-2]:
      receptive_field_size *= dim
    fan_in = shape[-2] * receptive_field_size
    fan_out = shape[-1] * receptive_field_size
  return fan_in, fan_out
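
# How `_compute_fans` reads a shape (illustrative, assuming the function above
# is defined): a 2-D matrix maps directly to its two dimensions, while a
# convolution kernel multiplies the channel counts by the receptive field.
assert _compute_fans((256, 128)) == (256, 128)
# 3x3 kernel, 64 input channels, 128 output channels:
# fan_in = 3 * 3 * 64 = 576, fan_out = 3 * 3 * 128 = 1152.
assert _compute_fans((3, 3, 64, 128)) == (576., 1152.)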

# TODO(vrv): Unhide when we are ready to expose this publicly.
def _random_walk(shape, nonlinearity, dtype=dtypes.float32, seed=None,
                 name="random_walk"):