Merge pull request #25538 from KyotoSunshine:patch-3

PiperOrigin-RevId: 234221477
This commit is contained in:
TensorFlower Gardener 2019-02-15 15:26:48 -08:00
commit 6de3cb0023

View File

@ -2308,7 +2308,7 @@ def layer_norm(inputs,
initializer=init_ops.ones_initializer(),
collections=gamma_collections,
trainable=trainable)
# Calculate the moments on the last axis (layer activations).
# By default, compute the moments across all the dimensions except the one with index 0.
norm_axes = list(range(begin_norm_axis, inputs_rank))
mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
# Compute layer normalization using the batch_normalization function.