Fix formatting of file
PiperOrigin-RevId: 312408716
Change-Id: I63f427c3453745008b159afc7a459df63b0ec8d0
parent 361470d24a
commit f8a797e13e
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Normalization layers.
-"""
+"""Normalization layers."""
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
@@ -43,7 +42,7 @@ from tensorflow.python.util.tf_export import keras_export


 class BatchNormalizationBase(Layer):
-  r"""Normalize and scale inputs or activations. (Ioffe and Szegedy, 2014).
+  r"""Normalize and scale inputs or activations.

   Normalize the activations of the previous layer at each batch,
   i.e. applies a transformation that maintains the mean activation
@@ -65,20 +64,16 @@ class BatchNormalizationBase(Layer):
     `training=False` when calling the model, or using `model.predict`.

   Arguments:
-    axis: Integer, the axis that should be normalized
-      (typically the features axis).
-      For instance, after a `Conv2D` layer with
-      `data_format="channels_first"`,
-      set `axis=1` in `BatchNormalization`.
+    axis: Integer, the axis that should be normalized (typically the features
+      axis). For instance, after a `Conv2D` layer with
+      `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
     momentum: Momentum for the moving average.
     epsilon: Small float added to variance to avoid dividing by zero.
-    center: If True, add offset of `beta` to normalized tensor.
-      If False, `beta` is ignored.
-    scale: If True, multiply by `gamma`.
-      If False, `gamma` is not used.
-      When the next layer is linear (also e.g. `nn.relu`),
-      this can be disabled since the scaling
-      will be done by the next layer.
+    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+      is ignored.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
+      next layer is linear (also e.g. `nn.relu`), this can be disabled since the
+      scaling will be done by the next layer.
     beta_initializer: Initializer for the beta weight.
     gamma_initializer: Initializer for the gamma weight.
     moving_mean_initializer: Initializer for the moving mean.
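The reflowed `axis` text above is easiest to read next to a concrete call. A minimal usage sketch (illustrative only, assuming the public `tf.keras` API; the model and shapes are made up):

```python
import tensorflow as tf

# Hypothetical channels-first model: BatchNormalization must be told that the
# features live on axis 1 rather than the default last axis.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, data_format="channels_first",
                           input_shape=(8, 32, 32)),  # (channels, height, width)
    tf.keras.layers.BatchNormalization(axis=1, momentum=0.99, epsilon=1e-3),
    tf.keras.layers.ReLU(),
])

# With channels-last data the default axis=-1 already points at the features.
x = tf.random.normal([4, 8, 32, 32])
y = model(x, training=True)   # normalizes with batch statistics
y = model(x, training=False)  # normalizes with the moving statistics
print(y.shape)
```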
@@ -117,54 +112,36 @@ class BatchNormalizationBase(Layer):
       example, if axis==-1,
         `adjustment = lambda shape: (
           tf.random.uniform(shape[-1:], 0.93, 1.07),
-          tf.random.uniform(shape[-1:], -0.1, 0.1))`
-      will scale the normalized value by up to 7% up or down, then shift the
-      result by up to 0.1 (with independent scaling and bias for each feature
-      but shared across all examples), and finally apply gamma and/or beta. If
+          tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
+      value by up to 7% up or down, then shift the result by up to 0.1
+      (with independent scaling and bias for each feature but shared
+      across all examples), and finally apply gamma and/or beta. If
       `None`, no adjustment is applied. Cannot be specified if
       virtual_batch_size is specified.

   Call arguments:
     inputs: Input tensor (of any rank).
     training: Python boolean indicating whether the layer should behave in
       training mode or in inference mode.
-      - `training=True`: The layer will normalize its inputs using the
-        mean and variance of the current batch of inputs.
-      - `training=False`: The layer will normalize its inputs using the
-        mean and variance of its moving statistics, learned during training.
-
-  Input shape:
-    Arbitrary. Use the keyword argument `input_shape`
-    (tuple of integers, does not include the samples axis)
-    when using this layer as the first layer in a model.
-
-  Output shape:
-    Same shape as input.
-
-  {{TRAINABLE_ATTRIBUTE_NOTE}}
-
-  Normalization equations:
-    Consider the intermediate activations \(x\) of a mini-batch of size
-    \\(m\\):
-
-    We can compute the mean and variance of the batch
-
-    \\({\mu_B} = \frac{1}{m} \sum_{i=1}^{m} {x_i}\\)
-
-    \\({\sigma_B^2} = \frac{1}{m} \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\)
-
-    and then compute a normalized \\(x\\), including a small factor
-    \\({\epsilon}\\) for numerical stability.
-
-    \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}\\)
-
-    And finally \\(\hat{x}\) is linearly transformed by \({\gamma}\\)
-    and \\({\beta}\\), which are learned parameters:
-
-    \\({y_i} = {\gamma * \hat{x_i} + \beta}\\)
+      - `training=True`: The layer will normalize its inputs using the mean and
+        variance of the current batch of inputs.
+      - `training=False`: The layer will normalize its inputs using the mean and
+        variance of its moving statistics, learned during training.
+
+  Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
+    integers, does not include the samples axis) when using this layer as the
+    first layer in a model.
+  Output shape: Same shape as input. {{TRAINABLE_ATTRIBUTE_NOTE}}
+  Normalization equations: Consider the intermediate activations \(x\) of a
+    mini-batch of size
+    \\(m\\): We can compute the mean and variance of the batch \\({\mu_B} =
+    \frac{1}{m} \sum_{i=1}^{m} {x_i}\\) \\({\sigma_B^2} = \frac{1}{m}
+    \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\) and then compute a normalized
+    \\(x\\), including a small factor \\({\epsilon}\\) for numerical
+    stability. \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 +
+    \epsilon}}\\) And finally \\(\hat{x}\) is linearly transformed by
+    \({\gamma}\\)
+    and \\({\beta}\\), which are learned parameters: \\({y_i} = {\gamma *
+    \hat{x_i} + \beta}\\)

   Reference:

     - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
   """
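The normalization equations collapsed into the paragraph above can be checked numerically. A small NumPy sketch of the same math (illustrative, not part of the change):

```python
import numpy as np

def batch_norm_reference(x, gamma, beta, epsilon=1e-3):
  """Reference for the docstring equations, normalizing over the batch axis."""
  mu = x.mean(axis=0)                        # \mu_B
  var = ((x - mu) ** 2).mean(axis=0)         # \sigma_B^2
  x_hat = (x - mu) / np.sqrt(var + epsilon)  # normalized activations
  return gamma * x_hat + beta                # y_i = gamma * x_hat_i + beta

x = np.random.randn(32, 4).astype(np.float32)  # mini-batch of size m=32
y = batch_norm_reference(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0), y.std(axis=0))  # roughly 0 and 1 per feature
```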
@@ -195,8 +172,7 @@ class BatchNormalizationBase(Layer):
                adjustment=None,
                name=None,
                **kwargs):
-    super(BatchNormalizationBase, self).__init__(
-        name=name, **kwargs)
+    super(BatchNormalizationBase, self).__init__(name=name, **kwargs)
     if isinstance(axis, (list, tuple)):
       self.axis = axis[:]
     elif isinstance(axis, int):
@@ -275,8 +251,8 @@ class BatchNormalizationBase(Layer):
     # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
     if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None):
       raise ValueError('Passing fused=True is only supported when the compute '
-                       'dtype is float16, bfloat16, or float32. Got dtype: %s'
-                       % (self._compute_dtype,))
+                       'dtype is float16, bfloat16, or float32. Got dtype: %s' %
+                       (self._compute_dtype,))

   def _fused_can_be_used(self):
     try:
@@ -380,13 +356,14 @@ class BatchNormalizationBase(Layer):
       param_shape = (list(axis_to_dim.values())[0],)
     else:
       # Parameter shape is the original shape but with 1 in all non-axis dims
-      param_shape = [axis_to_dim[i] if i in axis_to_dim
-                     else 1 for i in range(ndims)]
+      param_shape = [
+          axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
+      ]
     if self.virtual_batch_size is not None:
       # When using virtual batches, add an extra dim at index 1
       param_shape.insert(1, 1)
       for idx, x in enumerate(self.axis):
         self.axis[idx] = x + 1  # Account for added dimension

     if self.scale:
       self.gamma = self.add_weight(
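For context, the comprehension being rewrapped above builds a broadcastable parameter shape: it keeps the normalized axes and sets every other dimension to 1. A plain-Python illustration with made-up values:

```python
# Illustrative only: mirrors the list comprehension above for ndims=4,
# normalizing over axis 1 of an input shaped (batch, channels, height, width).
ndims = 4
axis_to_dim = {1: 16}  # hypothetical channel count
param_shape = [
    axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
]
print(param_shape)  # [1, 16, 1, 1] -> broadcasts against (N, 16, H, W)
```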
@@ -507,8 +484,7 @@ class BatchNormalizationBase(Layer):
       decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay')
       if decay.dtype != variable.dtype.base_dtype:
         decay = math_ops.cast(decay, variable.dtype.base_dtype)
-      update_delta = (
-          variable - math_ops.cast(value, variable.dtype)) * decay
+      update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
       if inputs_size is not None:
         update_delta = array_ops.where(inputs_size > 0, update_delta,
                                        K.zeros_like(update_delta))
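The rewrapped `update_delta` line is the exponential-moving-average update `variable -= (variable - value) * (1 - momentum)`. A NumPy sketch of that rule (names illustrative):

```python
import numpy as np

def assign_moving_average(moving_stat, batch_stat, momentum=0.99):
  # variable -= (variable - value) * (1 - momentum), as in the hunk above.
  decay = 1.0 - momentum
  return moving_stat - (moving_stat - batch_stat) * decay

moving_mean = np.zeros(3)
for _ in range(100):
  batch_mean = np.array([1.0, 2.0, 3.0])  # pretend per-batch statistic
  moving_mean = assign_moving_average(moving_mean, batch_mean)
print(moving_mean)  # converges toward [1, 2, 3]
```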
@@ -650,8 +626,9 @@ class BatchNormalizationBase(Layer):
     with ops.control_dependencies([r, d]):
       mean = array_ops.identity(mean)
       stddev = array_ops.identity(stddev)
-    rmin, rmax, dmax = [self.renorm_clipping.get(key)
-                        for key in ['rmin', 'rmax', 'dmax']]
+    rmin, rmax, dmax = [
+        self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
+    ]
     if rmin is not None:
       r = math_ops.maximum(r, rmin)
     if rmax is not None:
@@ -661,13 +638,13 @@ class BatchNormalizationBase(Layer):
       d = math_ops.minimum(d, dmax)
     # When not training, use r=1, d=0.
     r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
-    d = tf_utils.smart_cond(training,
-                            lambda: d,
+    d = tf_utils.smart_cond(training, lambda: d,
                             lambda: array_ops.zeros_like(d))

     def _update_renorm_variable(var, value, inputs_size):
       """Updates a moving average and weight, returns the unbiased value."""
       value = array_ops.identity(value)

       def _do_update():
         """Updates the var, returns the updated value."""
         new_var = self._assign_moving_average(var, value, self.renorm_momentum,
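The `r` and `d` tensors handled above are the batch-renormalization correction described in the `renorm_clipping` docstring: `corrected_value = normalized_value * r + d`, with `r` clipped to `[rmin, rmax]`, `d` clipped to `[-dmax, dmax]`, and `r=1, d=0` when not training. A NumPy sketch of that rule (values illustrative):

```python
import numpy as np

def renorm_correction(normalized, r, d, training,
                      rmin=None, rmax=None, dmax=None):
  # Clip the correction exactly as described in the renorm_clipping docstring.
  if rmin is not None:
    r = np.maximum(r, rmin)
  if rmax is not None:
    r = np.minimum(r, rmax)
  if dmax is not None:
    d = np.clip(d, -dmax, dmax)
  if not training:
    r, d = np.ones_like(r), np.zeros_like(d)  # no correction at inference
  return normalized * r + d

x_hat = np.array([0.5, -1.2])
print(renorm_correction(x_hat, r=np.array([3.0, 0.5]), d=np.array([0.2, -0.4]),
                        training=True, rmin=0.67, rmax=1.5, dmax=0.3))
```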
@@ -676,6 +653,7 @@ class BatchNormalizationBase(Layer):

       def _fake_update():
         return array_ops.identity(var)

       return tf_utils.smart_cond(training, _do_update, _fake_update)

     # TODO(yuefengz): colocate the operations
@@ -753,12 +731,13 @@ class BatchNormalizationBase(Layer):
     ndims = len(input_shape)
     reduction_axes = [i for i in range(ndims) if i not in self.axis]
     if self.virtual_batch_size is not None:
       del reduction_axes[1]  # Do not reduce along virtual batch dim

     # Broadcasting only necessary for single-axis batch norm where the axis is
     # not the last dimension
     broadcast_shape = [1] * ndims
     broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

     def _broadcast(v):
       if (v is not None and len(v.shape) != ndims and
           reduction_axes != list(range(ndims - 1))):
@@ -783,11 +762,9 @@ class BatchNormalizationBase(Layer):
       if self.adjustment:
         adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
         # Adjust only during training.
-        adj_scale = tf_utils.smart_cond(training,
-                                        lambda: adj_scale,
+        adj_scale = tf_utils.smart_cond(training, lambda: adj_scale,
                                         lambda: array_ops.ones_like(adj_scale))
-        adj_bias = tf_utils.smart_cond(training,
-                                       lambda: adj_bias,
+        adj_bias = tf_utils.smart_cond(training, lambda: adj_bias,
                                        lambda: array_ops.zeros_like(adj_bias))
         scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)

@@ -879,11 +856,8 @@ class BatchNormalizationBase(Layer):
       scale = math_ops.cast(scale, inputs.dtype)
       # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
       # math in float16 hurts validation accuracy of popular models like resnet.
-      outputs = nn.batch_normalization(inputs,
-                                       _broadcast(mean),
-                                       _broadcast(variance),
-                                       offset,
-                                       scale,
+      outputs = nn.batch_normalization(inputs, _broadcast(mean),
+                                       _broadcast(variance), offset, scale,
                                        self.epsilon)
       # If some components of the shape got lost due to adjustments, fix that.
       outputs.set_shape(input_shape)
@@ -897,21 +871,32 @@ class BatchNormalizationBase(Layer):

   def get_config(self):
     config = {
-        'axis': self.axis,
-        'momentum': self.momentum,
-        'epsilon': self.epsilon,
-        'center': self.center,
-        'scale': self.scale,
-        'beta_initializer': initializers.serialize(self.beta_initializer),
-        'gamma_initializer': initializers.serialize(self.gamma_initializer),
+        'axis':
+            self.axis,
+        'momentum':
+            self.momentum,
+        'epsilon':
+            self.epsilon,
+        'center':
+            self.center,
+        'scale':
+            self.scale,
+        'beta_initializer':
+            initializers.serialize(self.beta_initializer),
+        'gamma_initializer':
+            initializers.serialize(self.gamma_initializer),
         'moving_mean_initializer':
             initializers.serialize(self.moving_mean_initializer),
         'moving_variance_initializer':
             initializers.serialize(self.moving_variance_initializer),
-        'beta_regularizer': regularizers.serialize(self.beta_regularizer),
-        'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
-        'beta_constraint': constraints.serialize(self.beta_constraint),
-        'gamma_constraint': constraints.serialize(self.gamma_constraint)
+        'beta_regularizer':
+            regularizers.serialize(self.beta_regularizer),
+        'gamma_regularizer':
+            regularizers.serialize(self.gamma_regularizer),
+        'beta_constraint':
+            constraints.serialize(self.beta_constraint),
+        'gamma_constraint':
+            constraints.serialize(self.gamma_constraint)
     }
     # Only add TensorFlow-specific parameters if they are set, so as to preserve
     # model compatibility with external Keras.
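The `get_config` dictionary reindented above is what Keras serializes for this layer. A minimal round-trip sketch (assuming the public `tf.keras` API):

```python
import tensorflow as tf

layer = tf.keras.layers.BatchNormalization(axis=-1, momentum=0.9, center=False)
config = layer.get_config()  # plain dict, safe to serialize
print(config['momentum'], config['center'])

# Rebuilding an equivalent layer from the serialized config.
clone = tf.keras.layers.BatchNormalization.from_config(config)
assert clone.momentum == layer.momentum
```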
@@ -942,16 +927,14 @@ def replace_in_base_docstring(replacements):
 @keras_export(v1=['keras.layers.BatchNormalization'])  # pylint: disable=missing-docstring
 class BatchNormalization(BatchNormalizationBase):

-  __doc__ = replace_in_base_docstring(
-      [('''
+  __doc__ = replace_in_base_docstring([("""
     fused: if `True`, use a faster, fused implementation, or raise a ValueError
       if the fused implementation cannot be used. If `None`, use the faster
       implementation if possible. If False, do not used the fused
-      implementation.''',
-        '''
+      implementation.""", """
     fused: if `None` or `True`, use a faster, fused implementation if possible.
-      If `False`, use the system recommended implementation.'''),
+      If `False`, use the system recommended implementation."""),
        ('{{TRAINABLE_ATTRIBUTE_NOTE}}', '')])

   _USE_V2_BEHAVIOR = False

@@ -1048,37 +1031,30 @@ class LayerNormalization(Layer):


   Arguments:
-    axis: Integer or List/Tuple. The axis or axes
-      to normalize across. Typically this is the features axis/axes. The
-      left-out axes are typically the batch axis/axes.
-      This argument defaults to `-1`, the last dimension in the input.
-    epsilon: Small float added to variance to avoid dividing by zero.
-      Defaults to 1e-3
-    center: If True, add offset of `beta` to normalized tensor.
-      If False, `beta` is ignored. Defaults to True.
-    scale: If True, multiply by `gamma`.
-      If False, `gamma` is not used. Defaults to True.
-      When the next layer is linear (also e.g. `nn.relu`),
-      this can be disabled since the scaling
-      will be done by the next layer.
+    axis: Integer or List/Tuple. The axis or axes to normalize across. Typically
+      this is the features axis/axes. The left-out axes are typically the batch
+      axis/axes. This argument defaults to `-1`, the last dimension in the
+      input.
+    epsilon: Small float added to variance to avoid dividing by zero. Defaults
+      to 1e-3
+    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
+      is ignored. Defaults to True.
+    scale: If True, multiply by `gamma`. If False, `gamma` is not used. Defaults
+      to True. When the next layer is linear (also e.g. `nn.relu`), this can be
+      disabled since the scaling will be done by the next layer.
     beta_initializer: Initializer for the beta weight. Defaults to zeros.
     gamma_initializer: Initializer for the gamma weight. Defaults to ones.
     beta_regularizer: Optional regularizer for the beta weight. None by default.
-    gamma_regularizer: Optional regularizer for the gamma weight.
-      None by default.
+    gamma_regularizer: Optional regularizer for the gamma weight. None by
+      default.
     beta_constraint: Optional constraint for the beta weight. None by default.
     gamma_constraint: Optional constraint for the gamma weight. None by default.
     trainable: Boolean, if `True` the variables will be marked as trainable.
       Defaults to True.
-
-  Input shape:
-    Arbitrary. Use the keyword argument `input_shape`
-    (tuple of integers, does not include the samples axis)
-    when using this layer as the first layer in a model.
-
-  Output shape:
-    Same shape as input.
+  Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
+    integers, does not include the samples axis) when using this layer as the
+    first layer in a model.
+  Output shape: Same shape as input.

   Reference:
     - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
   """
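A short usage sketch for the reflowed `LayerNormalization` docstring (assuming the public `tf.keras` API; shapes are made up):

```python
import tensorflow as tf

# Normalizes each sample independently over the last (features) axis, so no
# batch statistics or moving averages are involved.
layer = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-3)
x = tf.random.normal([2, 5, 8])  # (batch, timesteps, features)
y = layer(x)
# Per-sample, per-timestep mean is ~0 after normalization.
print(tf.reduce_mean(y, axis=-1).numpy())
```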
@@ -1204,9 +1180,9 @@ class LayerNormalization(Layer):
     broadcast_shape = [1] * ndims
     for dim in self.axis:
       broadcast_shape[dim] = input_shape.dims[dim].value

     def _broadcast(v):
-      if (v is not None and len(v.shape) != ndims and
-          self.axis != [ndims - 1]):
+      if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
         return array_ops.reshape(v, broadcast_shape)
       return v
