diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index aa8598d7319..2b360d21143 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -491,6 +491,9 @@ class BatchNormalization(Layer):
 
     return (r, d, new_mean, new_variance)
 
+  def _moments(self, inputs, reduction_axes, keep_dims):
+    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+
   def call(self, inputs, training=None):
     if training is None:
       training = K.learning_phase()
@@ -562,7 +565,8 @@ class BatchNormalization(Layer):
       # Some of the computations here are not necessary when training==False
       # but not a constant. However, this makes the code simpler.
       keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
-      mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+      mean, variance = self._moments(
+          inputs, reduction_axes, keep_dims=keep_dims)
 
       moving_mean = self.moving_mean
       moving_variance = self.moving_variance