Merge pull request #25538 from KyotoSunshine:patch-3
PiperOrigin-RevId: 234221477
This commit is contained in:
commit
6de3cb0023
@ -2308,7 +2308,7 @@ def layer_norm(inputs,
|
||||
initializer=init_ops.ones_initializer(),
|
||||
collections=gamma_collections,
|
||||
trainable=trainable)
|
||||
# Calculate the moments on the last axis (layer activations).
|
||||
# By default, compute the moments across all the dimensions except the one with index 0.
|
||||
norm_axes = list(range(begin_norm_axis, inputs_rank))
|
||||
mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
|
||||
# Compute layer normalization using the batch_normalization function.
|
||||
|
Loading…
Reference in New Issue
Block a user