Internal changes.

PiperOrigin-RevId: 223060752
This commit is contained in:
Mingxing Tan 2018-11-27 14:41:24 -08:00 committed by TensorFlower Gardener
parent 3409380f7c
commit a69210ede1

View File

@ -491,6 +491,9 @@ class BatchNormalization(Layer):
return (r, d, new_mean, new_variance)
def _moments(self, inputs, reduction_axes, keep_dims):
return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
@ -562,7 +565,8 @@ class BatchNormalization(Layer):
# Some of the computations here are not necessary when training==False
# but not a constant. However, this makes the code simpler.
keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
mean, variance = self._moments(
inputs, reduction_axes, keep_dims=keep_dims)
moving_mean = self.moving_mean
moving_variance = self.moving_variance