Internal changes.
PiperOrigin-RevId: 223060752
This commit is contained in:
parent
3409380f7c
commit
a69210ede1
@ -491,6 +491,9 @@ class BatchNormalization(Layer):
|
||||
|
||||
return (r, d, new_mean, new_variance)
|
||||
|
||||
def _moments(self, inputs, reduction_axes, keep_dims):
|
||||
return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
@ -562,7 +565,8 @@ class BatchNormalization(Layer):
|
||||
# Some of the computations here are not necessary when training==False
|
||||
# but not a constant. However, this makes the code simpler.
|
||||
keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
|
||||
mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
|
||||
mean, variance = self._moments(
|
||||
inputs, reduction_axes, keep_dims=keep_dims)
|
||||
|
||||
moving_mean = self.moving_mean
|
||||
moving_variance = self.moving_variance
|
||||
|
Loading…
Reference in New Issue
Block a user