diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 403b522ce45..9d9524e4e4b 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2308,7 +2308,7 @@ def layer_norm(inputs, initializer=init_ops.ones_initializer(), collections=gamma_collections, trainable=trainable) - # Calculate the moments on the last axis (layer activations). + # By default, compute the moments across all the dimensions except the one with index 0. norm_axes = list(range(begin_norm_axis, inputs_rank)) mean, variance = nn.moments(inputs, norm_axes, keep_dims=True) # Compute layer normalization using the batch_normalization function.