diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index aa8598d7319..2b360d21143 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -491,6 +491,9 @@ class BatchNormalization(Layer):
 
     return (r, d, new_mean, new_variance)
 
+  def _moments(self, inputs, reduction_axes, keep_dims):
+    return nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+
   def call(self, inputs, training=None):
     if training is None:
       training = K.learning_phase()
@@ -562,7 +565,8 @@ class BatchNormalization(Layer):
       # Some of the computations here are not necessary when training==False
       # but not a constant. However, this makes the code simpler.
       keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1
-      mean, variance = nn.moments(inputs, reduction_axes, keep_dims=keep_dims)
+      mean, variance = self._moments(
+          inputs, reduction_axes, keep_dims=keep_dims)
 
       moving_mean = self.moving_mean
       moving_variance = self.moving_variance