Improve the docs for batch normalization.
PiperOrigin-RevId: 272004386
parent 8427f19013
commit 823ab85e60
@@ -41,12 +41,27 @@ from tensorflow.python.util.tf_export import keras_export

class BatchNormalizationBase(Layer):
"""Base class of Batch normalization layer (Ioffe and Szegedy, 2014).
r"""Normalize and scale inputs or activations. (Ioffe and Szegedy, 2014).

Normalize the activations of the previous layer at each batch,
i.e. applies a transformation that maintains the mean activation
close to 0 and the activation standard deviation close to 1.

Batch normalization differs from other layers in several key aspects:

1) Adding BatchNormalization with `training=True` to a model causes the
result of one example to depend on the contents of all other examples in a
minibatch. Be careful when padding batches or masking examples, as these can
change the minibatch statistics and affect other examples.

2) Updates to the weights (moving statistics) are based on the forward pass
of a model rather than the result of gradient computations.

3) When performing inference using a model containing batch normalization, it
is generally (though not always) desirable to use accumulated statistics
rather than mini-batch statistics. This is accomplished by passing
`training=False` when calling the model, or using `model.predict`.

Arguments:
axis: Integer, the axis that should be normalized
(typically the features axis).
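
A minimal sketch, not part of this commit, illustrating the three points above; it assumes TF 2.x eager execution and the layer's default `gamma`/`beta` initializers:

```python
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.BatchNormalization()
x = np.random.normal(loc=5.0, scale=2.0, size=(8, 4)).astype("float32")

# training=True: normalize with the current minibatch statistics, so each
# example's output depends on the rest of the batch (point 1); the moving
# statistics are also updated as a side effect of the forward pass (point 2).
y_train = layer(x, training=True)

# training=False: normalize with the accumulated moving statistics only (point 3).
y_infer = layer(x, training=False)

print(float(tf.reduce_mean(y_train)), float(tf.math.reduce_std(y_train)))  # roughly 0 and 1
print(float(tf.reduce_mean(y_infer)), float(tf.math.reduce_std(y_infer)))  # depends on moving stats
```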
@@ -124,11 +139,31 @@ class BatchNormalizationBase(Layer):
Output shape:
Same shape as input.

References:
- [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)

{{TRAINABLE_ATTRIBUTE_NOTE}}

Normalization equations:
Consider the intermediate activations \(x\) of a mini-batch of size
\(m\):

We can compute the mean and variance of the batch

\({\mu_B} = \frac{1}{m} \sum_{i=1}^{m} {x_i}\)

\({\sigma_B^2} = \frac{1}{m} \sum_{i=1}^{m} ({x_i} - {\mu_B})^2\)

and then compute a normalized \(x\), including a small factor
\({\epsilon}\) for numerical stability.

\(\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}\)

And finally \(\hat{x}\) is linearly transformed by \({\gamma}\)
and \({\beta}\), which are learned parameters:

\({y_i} = {\gamma * \hat{x_i} + \beta}\)

References:
- [Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
"""

# By default, the base class uses V2 behavior. The BatchNormalization V1
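
A small NumPy reference sketch of the normalization equations added in the docstring above; it is not part of this commit, `epsilon=1e-3` mirrors the Keras layer's default, and `gamma`/`beta` stand in for the learned scale and offset:

```python
import numpy as np

def batch_norm_reference(x, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Per-batch normalization of a 1-D array of activations x_1..x_m."""
    mu_b = np.mean(x)                                 # mu_B = (1/m) sum x_i
    sigma2_b = np.mean((x - mu_b) ** 2)               # sigma_B^2 = (1/m) sum (x_i - mu_B)^2
    x_hat = (x - mu_b) / np.sqrt(sigma2_b + epsilon)  # normalized activations
    return gamma * x_hat + beta                       # y_i = gamma * x_hat_i + beta

x = np.array([1.0, 2.0, 3.0, 4.0])
print(batch_norm_reference(x))  # mean ~0, std ~1 (slightly under 1 because of epsilon)
```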
@@ -849,7 +884,7 @@ def replace_in_base_docstring(replacements):
  string = BatchNormalizationBase.__doc__
  for old, new in replacements:
    assert old in string
    string.replace(old, new)
    string = string.replace(old, new)
  return string
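
The one-line change in this hunk fixes a real bug: `str.replace` returns a new string rather than modifying in place, so the unassigned call discarded every substitution. A hypothetical snippet, not from the commit, showing the difference with the `{{TRAINABLE_ATTRIBUTE_NOTE}}` placeholder used in the docstring above:

```python
doc = "...\n{{TRAINABLE_ATTRIBUTE_NOTE}}\n..."

doc.replace("{{TRAINABLE_ATTRIBUTE_NOTE}}", "note text")       # old pattern: result thrown away
print("{{TRAINABLE_ATTRIBUTE_NOTE}}" in doc)                    # True -- placeholder still there

doc = doc.replace("{{TRAINABLE_ATTRIBUTE_NOTE}}", "note text")  # fixed pattern: reassign
print("{{TRAINABLE_ATTRIBUTE_NOTE}}" in doc)                    # False -- substitution applied
```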