Fix formatting of file

PiperOrigin-RevId: 312408716
Change-Id: I63f427c3453745008b159afc7a459df63b0ec8d0
This commit is contained in:
Gaurav Jain 2020-05-19 20:31:31 -07:00 committed by TensorFlower Gardener
parent 361470d24a
commit f8a797e13e

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Normalization layers. """Normalization layers."""
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -43,7 +42,7 @@ from tensorflow.python.util.tf_export import keras_export
class BatchNormalizationBase(Layer): class BatchNormalizationBase(Layer):
r"""Normalize and scale inputs or activations. (Ioffe and Szegedy, 2014). r"""Normalize and scale inputs or activations.
Normalize the activations of the previous layer at each batch, Normalize the activations of the previous layer at each batch,
i.e. applies a transformation that maintains the mean activation i.e. applies a transformation that maintains the mean activation
@ -65,20 +64,16 @@ class BatchNormalizationBase(Layer):
`training=False` when calling the model, or using `model.predict`. `training=False` when calling the model, or using `model.predict`.
Arguments: Arguments:
axis: Integer, the axis that should be normalized axis: Integer, the axis that should be normalized (typically the features
(typically the features axis). axis). For instance, after a `Conv2D` layer with
For instance, after a `Conv2D` layer with `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
`data_format="channels_first"`,
set `axis=1` in `BatchNormalization`.
momentum: Momentum for the moving average. momentum: Momentum for the moving average.
epsilon: Small float added to variance to avoid dividing by zero. epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor. center: If True, add offset of `beta` to normalized tensor. If False, `beta`
If False, `beta` is ignored. is ignored.
scale: If True, multiply by `gamma`. scale: If True, multiply by `gamma`. If False, `gamma` is not used. When the
If False, `gamma` is not used. next layer is linear (also e.g. `nn.relu`), this can be disabled since the
When the next layer is linear (also e.g. `nn.relu`), scaling will be done by the next layer.
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight. beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight. gamma_initializer: Initializer for the gamma weight.
moving_mean_initializer: Initializer for the moving mean. moving_mean_initializer: Initializer for the moving mean.
@ -91,15 +86,15 @@ class BatchNormalizationBase(Layer):
https://arxiv.org/abs/1702.03275). This adds extra variables during https://arxiv.org/abs/1702.03275). This adds extra variables during
training. The inference is the same for either value of this parameter. training. The inference is the same for either value of this parameter.
renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
scalar `Tensors` used to clip the renorm correction. The correction scalar `Tensors` used to clip the renorm correction. The correction `(r,
`(r, d)` is used as `corrected_value = normalized_value * r + d`, with d)` is used as `corrected_value = normalized_value * r + d`, with `r`
`r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin, clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
dmax are set to inf, 0, inf, respectively. dmax are set to inf, 0, inf, respectively.
renorm_momentum: Momentum used to update the moving means and standard renorm_momentum: Momentum used to update the moving means and standard
deviations with renorm. Unlike `momentum`, this affects training deviations with renorm. Unlike `momentum`, this affects training and
and should be neither too small (which would add noise) nor too large should be neither too small (which would add noise) nor too large (which
(which would give stale estimates). Note that `momentum` is still applied would give stale estimates). Note that `momentum` is still applied to get
to get the means and variances for inference. the means and variances for inference.
fused: if `True`, use a faster, fused implementation, or raise a ValueError fused: if `True`, use a faster, fused implementation, or raise a ValueError
if the fused implementation cannot be used. If `None`, use the faster if the fused implementation cannot be used. If `None`, use the faster
implementation if possible. If False, do not used the fused implementation if possible. If False, do not used the fused
@ -117,54 +112,36 @@ class BatchNormalizationBase(Layer):
example, if axis==-1, example, if axis==-1,
`adjustment = lambda shape: ( `adjustment = lambda shape: (
tf.random.uniform(shape[-1:], 0.93, 1.07), tf.random.uniform(shape[-1:], 0.93, 1.07),
tf.random.uniform(shape[-1:], -0.1, 0.1))` tf.random.uniform(shape[-1:], -0.1, 0.1))` will scale the normalized
will scale the normalized value by up to 7% up or down, then shift the value by up to 7% up or down, then shift the result by up to 0.1
result by up to 0.1 (with independent scaling and bias for each feature (with independent scaling and bias for each feature but shared
but shared across all examples), and finally apply gamma and/or beta. If across all examples), and finally apply gamma and/or beta. If
`None`, no adjustment is applied. Cannot be specified if `None`, no adjustment is applied. Cannot be specified if
virtual_batch_size is specified. virtual_batch_size is specified.
Call arguments: Call arguments:
inputs: Input tensor (of any rank). inputs: Input tensor (of any rank).
training: Python boolean indicating whether the layer should behave in training: Python boolean indicating whether the layer should behave in
training mode or in inference mode. training mode or in inference mode.
- `training=True`: The layer will normalize its inputs using the - `training=True`: The layer will normalize its inputs using the mean and
mean and variance of the current batch of inputs. variance of the current batch of inputs.
- `training=False`: The layer will normalize its inputs using the - `training=False`: The layer will normalize its inputs using the mean and
mean and variance of its moving statistics, learned during training. variance of its moving statistics, learned during training.
Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
Input shape: integers, does not include the samples axis) when using this layer as the
Arbitrary. Use the keyword argument `input_shape` first layer in a model.
(tuple of integers, does not include the samples axis) Output shape: Same shape as input. {{TRAINABLE_ATTRIBUTE_NOTE}}
when using this layer as the first layer in a model. Normalization equations: Consider the intermediate activations \(x\) of a
mini-batch of size
Output shape: \\(m\\): We can compute the mean and variance of the batch \\({\mu_B} =
Same shape as input. \frac{1}{m} \sum_{i=1}^{m} {x_i}\\) \\({\sigma_B^2} = \frac{1}{m}
\sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\) and then compute a normalized
{{TRAINABLE_ATTRIBUTE_NOTE}} \\(x\\), including a small factor \\({\epsilon}\\) for numerical
stability. \\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 +
Normalization equations: \epsilon}}\\) And finally \\(\hat{x}\) is linearly transformed by
Consider the intermediate activations \(x\) of a mini-batch of size \({\gamma}\\)
\\(m\\): and \\({\beta}\\), which are learned parameters: \\({y_i} = {\gamma *
\hat{x_i} + \beta}\\)
We can compute the mean and variance of the batch
\\({\mu_B} = \frac{1}{m} \sum_{i=1}^{m} {x_i}\\)
\\({\sigma_B^2} = \frac{1}{m} \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\\)
and then compute a normalized \\(x\\), including a small factor
\\({\epsilon}\\) for numerical stability.
\\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}\\)
And finally \\(\hat{x}\) is linearly transformed by \({\gamma}\\)
and \\({\beta}\\), which are learned parameters:
\\({y_i} = {\gamma * \hat{x_i} + \beta}\\)
Reference: Reference:
- [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167). - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).
""" """
@ -195,8 +172,7 @@ class BatchNormalizationBase(Layer):
adjustment=None, adjustment=None,
name=None, name=None,
**kwargs): **kwargs):
super(BatchNormalizationBase, self).__init__( super(BatchNormalizationBase, self).__init__(name=name, **kwargs)
name=name, **kwargs)
if isinstance(axis, (list, tuple)): if isinstance(axis, (list, tuple)):
self.axis = axis[:] self.axis = axis[:]
elif isinstance(axis, int): elif isinstance(axis, int):
@ -275,8 +251,8 @@ class BatchNormalizationBase(Layer):
# TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check. # TODO(reedwm): Support fp64 in FusedBatchNorm then remove this check.
if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None): if self._compute_dtype not in ('float16', 'bfloat16', 'float32', None):
raise ValueError('Passing fused=True is only supported when the compute ' raise ValueError('Passing fused=True is only supported when the compute '
'dtype is float16, bfloat16, or float32. Got dtype: %s' 'dtype is float16, bfloat16, or float32. Got dtype: %s' %
% (self._compute_dtype,)) (self._compute_dtype,))
def _fused_can_be_used(self): def _fused_can_be_used(self):
try: try:
@ -380,8 +356,9 @@ class BatchNormalizationBase(Layer):
param_shape = (list(axis_to_dim.values())[0],) param_shape = (list(axis_to_dim.values())[0],)
else: else:
# Parameter shape is the original shape but with 1 in all non-axis dims # Parameter shape is the original shape but with 1 in all non-axis dims
param_shape = [axis_to_dim[i] if i in axis_to_dim param_shape = [
else 1 for i in range(ndims)] axis_to_dim[i] if i in axis_to_dim else 1 for i in range(ndims)
]
if self.virtual_batch_size is not None: if self.virtual_batch_size is not None:
# When using virtual batches, add an extra dim at index 1 # When using virtual batches, add an extra dim at index 1
param_shape.insert(1, 1) param_shape.insert(1, 1)
@ -507,8 +484,7 @@ class BatchNormalizationBase(Layer):
decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay') decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay')
if decay.dtype != variable.dtype.base_dtype: if decay.dtype != variable.dtype.base_dtype:
decay = math_ops.cast(decay, variable.dtype.base_dtype) decay = math_ops.cast(decay, variable.dtype.base_dtype)
update_delta = ( update_delta = (variable - math_ops.cast(value, variable.dtype)) * decay
variable - math_ops.cast(value, variable.dtype)) * decay
if inputs_size is not None: if inputs_size is not None:
update_delta = array_ops.where(inputs_size > 0, update_delta, update_delta = array_ops.where(inputs_size > 0, update_delta,
K.zeros_like(update_delta)) K.zeros_like(update_delta))
@ -650,8 +626,9 @@ class BatchNormalizationBase(Layer):
with ops.control_dependencies([r, d]): with ops.control_dependencies([r, d]):
mean = array_ops.identity(mean) mean = array_ops.identity(mean)
stddev = array_ops.identity(stddev) stddev = array_ops.identity(stddev)
rmin, rmax, dmax = [self.renorm_clipping.get(key) rmin, rmax, dmax = [
for key in ['rmin', 'rmax', 'dmax']] self.renorm_clipping.get(key) for key in ['rmin', 'rmax', 'dmax']
]
if rmin is not None: if rmin is not None:
r = math_ops.maximum(r, rmin) r = math_ops.maximum(r, rmin)
if rmax is not None: if rmax is not None:
@ -661,13 +638,13 @@ class BatchNormalizationBase(Layer):
d = math_ops.minimum(d, dmax) d = math_ops.minimum(d, dmax)
# When not training, use r=1, d=0. # When not training, use r=1, d=0.
r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r)) r = tf_utils.smart_cond(training, lambda: r, lambda: array_ops.ones_like(r))
d = tf_utils.smart_cond(training, d = tf_utils.smart_cond(training, lambda: d,
lambda: d,
lambda: array_ops.zeros_like(d)) lambda: array_ops.zeros_like(d))
def _update_renorm_variable(var, value, inputs_size): def _update_renorm_variable(var, value, inputs_size):
"""Updates a moving average and weight, returns the unbiased value.""" """Updates a moving average and weight, returns the unbiased value."""
value = array_ops.identity(value) value = array_ops.identity(value)
def _do_update(): def _do_update():
"""Updates the var, returns the updated value.""" """Updates the var, returns the updated value."""
new_var = self._assign_moving_average(var, value, self.renorm_momentum, new_var = self._assign_moving_average(var, value, self.renorm_momentum,
@ -676,6 +653,7 @@ class BatchNormalizationBase(Layer):
def _fake_update(): def _fake_update():
return array_ops.identity(var) return array_ops.identity(var)
return tf_utils.smart_cond(training, _do_update, _fake_update) return tf_utils.smart_cond(training, _do_update, _fake_update)
# TODO(yuefengz): colocate the operations # TODO(yuefengz): colocate the operations
@ -759,6 +737,7 @@ class BatchNormalizationBase(Layer):
# not the last dimension # not the last dimension
broadcast_shape = [1] * ndims broadcast_shape = [1] * ndims
broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
def _broadcast(v): def _broadcast(v):
if (v is not None and len(v.shape) != ndims and if (v is not None and len(v.shape) != ndims and
reduction_axes != list(range(ndims - 1))): reduction_axes != list(range(ndims - 1))):
@ -783,11 +762,9 @@ class BatchNormalizationBase(Layer):
if self.adjustment: if self.adjustment:
adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs)) adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
# Adjust only during training. # Adjust only during training.
adj_scale = tf_utils.smart_cond(training, adj_scale = tf_utils.smart_cond(training, lambda: adj_scale,
lambda: adj_scale,
lambda: array_ops.ones_like(adj_scale)) lambda: array_ops.ones_like(adj_scale))
adj_bias = tf_utils.smart_cond(training, adj_bias = tf_utils.smart_cond(training, lambda: adj_bias,
lambda: adj_bias,
lambda: array_ops.zeros_like(adj_bias)) lambda: array_ops.zeros_like(adj_bias))
scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset)
@ -879,11 +856,8 @@ class BatchNormalizationBase(Layer):
scale = math_ops.cast(scale, inputs.dtype) scale = math_ops.cast(scale, inputs.dtype)
# TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
# math in float16 hurts validation accuracy of popular models like resnet. # math in float16 hurts validation accuracy of popular models like resnet.
outputs = nn.batch_normalization(inputs, outputs = nn.batch_normalization(inputs, _broadcast(mean),
_broadcast(mean), _broadcast(variance), offset, scale,
_broadcast(variance),
offset,
scale,
self.epsilon) self.epsilon)
# If some components of the shape got lost due to adjustments, fix that. # If some components of the shape got lost due to adjustments, fix that.
outputs.set_shape(input_shape) outputs.set_shape(input_shape)
@ -897,21 +871,32 @@ class BatchNormalizationBase(Layer):
def get_config(self): def get_config(self):
config = { config = {
'axis': self.axis, 'axis':
'momentum': self.momentum, self.axis,
'epsilon': self.epsilon, 'momentum':
'center': self.center, self.momentum,
'scale': self.scale, 'epsilon':
'beta_initializer': initializers.serialize(self.beta_initializer), self.epsilon,
'gamma_initializer': initializers.serialize(self.gamma_initializer), 'center':
self.center,
'scale':
self.scale,
'beta_initializer':
initializers.serialize(self.beta_initializer),
'gamma_initializer':
initializers.serialize(self.gamma_initializer),
'moving_mean_initializer': 'moving_mean_initializer':
initializers.serialize(self.moving_mean_initializer), initializers.serialize(self.moving_mean_initializer),
'moving_variance_initializer': 'moving_variance_initializer':
initializers.serialize(self.moving_variance_initializer), initializers.serialize(self.moving_variance_initializer),
'beta_regularizer': regularizers.serialize(self.beta_regularizer), 'beta_regularizer':
'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), regularizers.serialize(self.beta_regularizer),
'beta_constraint': constraints.serialize(self.beta_constraint), 'gamma_regularizer':
'gamma_constraint': constraints.serialize(self.gamma_constraint) regularizers.serialize(self.gamma_regularizer),
'beta_constraint':
constraints.serialize(self.beta_constraint),
'gamma_constraint':
constraints.serialize(self.gamma_constraint)
} }
# Only add TensorFlow-specific parameters if they are set, so as to preserve # Only add TensorFlow-specific parameters if they are set, so as to preserve
# model compatibility with external Keras. # model compatibility with external Keras.
@ -942,15 +927,13 @@ def replace_in_base_docstring(replacements):
@keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring @keras_export(v1=['keras.layers.BatchNormalization']) # pylint: disable=missing-docstring
class BatchNormalization(BatchNormalizationBase): class BatchNormalization(BatchNormalizationBase):
__doc__ = replace_in_base_docstring( __doc__ = replace_in_base_docstring([("""
[('''
fused: if `True`, use a faster, fused implementation, or raise a ValueError fused: if `True`, use a faster, fused implementation, or raise a ValueError
if the fused implementation cannot be used. If `None`, use the faster if the fused implementation cannot be used. If `None`, use the faster
implementation if possible. If False, do not used the fused implementation if possible. If False, do not used the fused
implementation.''', implementation.""", """
'''
fused: if `None` or `True`, use a faster, fused implementation if possible. fused: if `None` or `True`, use a faster, fused implementation if possible.
If `False`, use the system recommended implementation.'''), If `False`, use the system recommended implementation."""),
('{{TRAINABLE_ATTRIBUTE_NOTE}}', '')]) ('{{TRAINABLE_ATTRIBUTE_NOTE}}', '')])
_USE_V2_BEHAVIOR = False _USE_V2_BEHAVIOR = False
@ -1048,37 +1031,30 @@ class LayerNormalization(Layer):
Arguments: Arguments:
axis: Integer or List/Tuple. The axis or axes axis: Integer or List/Tuple. The axis or axes to normalize across. Typically
to normalize across. Typically this is the features axis/axes. The this is the features axis/axes. The left-out axes are typically the batch
left-out axes are typically the batch axis/axes. axis/axes. This argument defaults to `-1`, the last dimension in the
This argument defaults to `-1`, the last dimension in the input. input.
epsilon: Small float added to variance to avoid dividing by zero. epsilon: Small float added to variance to avoid dividing by zero. Defaults
Defaults to 1e-3 to 1e-3
center: If True, add offset of `beta` to normalized tensor. center: If True, add offset of `beta` to normalized tensor. If False, `beta`
If False, `beta` is ignored. Defaults to True. is ignored. Defaults to True.
scale: If True, multiply by `gamma`. scale: If True, multiply by `gamma`. If False, `gamma` is not used. Defaults
If False, `gamma` is not used. Defaults to True. to True. When the next layer is linear (also e.g. `nn.relu`), this can be
When the next layer is linear (also e.g. `nn.relu`), disabled since the scaling will be done by the next layer.
this can be disabled since the scaling
will be done by the next layer.
beta_initializer: Initializer for the beta weight. Defaults to zeros. beta_initializer: Initializer for the beta weight. Defaults to zeros.
gamma_initializer: Initializer for the gamma weight. Defaults to ones. gamma_initializer: Initializer for the gamma weight. Defaults to ones.
beta_regularizer: Optional regularizer for the beta weight. None by default. beta_regularizer: Optional regularizer for the beta weight. None by default.
gamma_regularizer: Optional regularizer for the gamma weight. gamma_regularizer: Optional regularizer for the gamma weight. None by
None by default. default.
beta_constraint: Optional constraint for the beta weight. None by default. beta_constraint: Optional constraint for the beta weight. None by default.
gamma_constraint: Optional constraint for the gamma weight. None by default. gamma_constraint: Optional constraint for the gamma weight. None by default.
trainable: Boolean, if `True` the variables will be marked as trainable. trainable: Boolean, if `True` the variables will be marked as trainable.
Defaults to True. Defaults to True.
Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of
Input shape: integers, does not include the samples axis) when using this layer as the
Arbitrary. Use the keyword argument `input_shape` first layer in a model.
(tuple of integers, does not include the samples axis) Output shape: Same shape as input.
when using this layer as the first layer in a model.
Output shape:
Same shape as input.
Reference: Reference:
- [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450). - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).
""" """
@ -1204,9 +1180,9 @@ class LayerNormalization(Layer):
broadcast_shape = [1] * ndims broadcast_shape = [1] * ndims
for dim in self.axis: for dim in self.axis:
broadcast_shape[dim] = input_shape.dims[dim].value broadcast_shape[dim] = input_shape.dims[dim].value
def _broadcast(v): def _broadcast(v):
if (v is not None and len(v.shape) != ndims and if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
self.axis != [ndims - 1]):
return array_ops.reshape(v, broadcast_shape) return array_ops.reshape(v, broadcast_shape)
return v return v