Internal changes.
PiperOrigin-RevId: 223060752
This commit is contained in:
parent
3409380f7c
commit
a69210ede1
@ -491,6 +491,9 @@ class BatchNormalization(Layer):
|
|||||||
|
|
||||||
return (r, d, new_mean, new_variance)
|
return (r, d, new_mean, new_variance)
|
||||||
|
|
||||||
|
def _moments(self, inputs, reduction_axes, keep_dims):
|
||||||
|
return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
|
||||||
|
|
||||||
def call(self, inputs, training=None):
|
def call(self, inputs, training=None):
|
||||||
if training is None:
|
if training is None:
|
||||||
training = K.learning_phase()
|
training = K.learning_phase()
|
||||||
@ -562,7 +565,8 @@ class BatchNormalization(Layer):
|
|||||||
# Some of the computations here are not necessary when training==False
|
# Some of the computations here are not necessary when training==False
|
||||||
# but not a constant. However, this makes the code simpler.
|
# but not a constant. However, this makes the code simpler.
|
||||||
keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
|
keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
|
||||||
mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
|
mean, variance = self._moments(
|
||||||
|
inputs, reduction_axes, keep_dims=keep_dims)
|
||||||
|
|
||||||
moving_mean = self.moving_mean
|
moving_mean = self.moving_mean
|
||||||
moving_variance = self.moving_variance
|
moving_variance = self.moving_variance
|
||||||
|
Loading…
Reference in New Issue
Block a user