diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 0737fe11712..29ab50c1b08 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -248,6 +248,7 @@ class BatchNormalizationBase(Layer):
     axis = [self.axis] if isinstance(self.axis, int) else self.axis
     # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
     # input rank is required to be 4 (which is checked later).
+    # TODO(b/173253101): Once the input rank can be 5, update this check.
     if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
       raise ValueError('Passing fused=True is only supported when axis is 1 '
                        'or 3')
@@ -329,14 +330,19 @@ class BatchNormalizationBase(Layer):
       # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
       # output back to its original shape accordingly.
       if self._USE_V2_BEHAVIOR:
+        # TODO(b/173253101): Using fused in the 5D case is currently disabled
+        # due to a regression on UNet, so it is only currently only supported in
+        # the 4D case.
         if self.fused is None:
-          self.fused = ndims in (4, 5)
-        elif self.fused and ndims not in (4, 5):
-          raise ValueError('Batch normalization layers with fused=True only '
-                           'support 4D or 5D input tensors.')
+          self.fused = ndims == 4
+        elif self.fused and ndims != 4:
+          raise ValueError('Batch normalization layers with `fused=True` only '
+                           'support 4D or 5D input tensors. '
+                           'Received tensor with shape: %s' %
+                           (tuple(input_shape),))
       else:
         assert self.fused is not None
-        self.fused = (ndims in (4, 5) and self._fused_can_be_used())
+        self.fused = (ndims == 4 and self._fused_can_be_used())
       # TODO(chrisying): fused batch norm is currently not supported for
       # multi-axis batch norm and by extension virtual batches. In some cases,
       # it might be possible to use fused batch norm but would require reshaping
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index a98db36ceea..d468e5d6db2 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -241,6 +241,31 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     self.assertAllClose(model.bn.moving_mean.numpy(), [0.047], atol=3e-3)
     self.assertAllClose(model.bn.moving_variance.numpy(), [0.9], atol=3e-2)
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_bessels_correction(self):
+    # Bessel's correction is currently only used in the fused case. In the
+    # future, it may be used in the nonfused case as well.
+
+    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1])
+    layer = normalization_v2.BatchNormalization(
+        momentum=0.5, moving_variance_initializer='zeros')
+    layer(x, training=True)
+    self.assertTrue(layer.fused)
+    # Since fused is used, Bessel's correction is used. The variance of [0, 2]
+    # is 2 with Bessel's correction. Since the momentum is 0.5, the variance is
+    # 2 * 0.5 == 1.
+    self.assertAllEqual(self.evaluate(layer.moving_variance), [1.])
+
+    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1, 1])
+    layer = normalization_v2.BatchNormalization(
+        momentum=0.5, moving_variance_initializer='zeros')
+    layer(x, training=True)
+    self.assertFalse(layer.fused)
+    # Since fused is not used, Bessel's correction is not used. The variance of
+    # [0, 2] is 1 without Bessel's correction. Since the momentum is 0.5, the
+    # variance is 1 * 0.5 == 0.5.
+    self.assertAllEqual(self.evaluate(layer.moving_variance), [0.5])
+
 
 class BatchNormalizationV1Test(keras_parameterized.TestCase):
 
@@ -291,6 +316,12 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm(inp)
     self.assertEqual(norm.fused, False)
 
+    norm = normalization_v2.BatchNormalization()
+    self.assertIsNone(norm.fused)
+    inp = keras.layers.Input(shape=(4, 4, 4, 4))
+    norm(inp)
+    self.assertEqual(norm.fused, False)
+
     norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))