Reed Wanderman-Milne 2020-11-16 15:01:07 -08:00 committed by Geeta Chavan
parent 259339c432
commit 2cf3cc466c
2 changed files with 42 additions and 5 deletions


@@ -248,6 +248,7 @@ class BatchNormalizationBase(Layer):
     axis = [self.axis] if isinstance(self.axis, int) else self.axis
     # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
     # input rank is required to be 4 (which is checked later).
+    # TODO(b/173253101): Once the input rank can be 5, update this check.
     if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
       raise ValueError('Passing fused=True is only supported when axis is 1 '
                        'or 3')
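
Not part of the diff: a minimal sketch of what this check means for callers, assuming (as in recent tf.keras releases) that it runs at construction time when fused=True is passed explicitly. Only the channel axis of a rank-4 tensor (1/-3 for NCHW, 3/-1 for NHWC) is accepted.

import tensorflow as tf

# Channels-last (axis=-1, i.e. 3 for a 4D input) is accepted with an explicit fused=True.
tf.keras.layers.BatchNormalization(axis=-1, fused=True)

# Any other axis is rejected with the ValueError shown above.
try:
  tf.keras.layers.BatchNormalization(axis=2, fused=True)
except ValueError as e:
  print(e)  # Passing fused=True is only supported when axis is 1 or 3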
@@ -329,14 +330,19 @@ class BatchNormalizationBase(Layer):
       # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
       # output back to its original shape accordingly.
       if self._USE_V2_BEHAVIOR:
+        # TODO(b/173253101): Using fused in the 5D case is currently disabled
+        # due to a regression on UNet, so it is currently only supported in
+        # the 4D case.
         if self.fused is None:
-          self.fused = ndims in (4, 5)
-        elif self.fused and ndims not in (4, 5):
-          raise ValueError('Batch normalization layers with fused=True only '
-                           'support 4D or 5D input tensors.')
+          self.fused = ndims == 4
+        elif self.fused and ndims != 4:
+          raise ValueError('Batch normalization layers with `fused=True` only '
+                           'support 4D or 5D input tensors. '
+                           'Received tensor with shape: %s' %
+                           (tuple(input_shape),))
       else:
         assert self.fused is not None
-        self.fused = (ndims in (4, 5) and self._fused_can_be_used())
+        self.fused = (ndims == 4 and self._fused_can_be_used())
       # TODO(chrisying): fused batch norm is currently not supported for
       # multi-axis batch norm and by extension virtual batches. In some cases,
       # it might be possible to use fused batch norm but would require reshaping
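
Not part of the diff: a short sketch of the resulting default behavior on the V2 path, assuming the public tf.keras.layers.BatchNormalization alias resolves to this layer. With fused left unset, a 4D input still selects the fused kernel, while a 5D input now falls back to the non-fused implementation.

import numpy as np
import tensorflow as tf

bn4d = tf.keras.layers.BatchNormalization()
bn4d(np.zeros((2, 8, 8, 3), np.float32), training=True)
print(bn4d.fused)  # True: 4D NHWC input keeps the fused kernel

bn5d = tf.keras.layers.BatchNormalization()
bn5d(np.zeros((2, 4, 8, 8, 3), np.float32), training=True)
print(bn5d.fused)  # False: 5D input no longer uses fused batch norm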


@@ -241,6 +241,31 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     self.assertAllClose(model.bn.moving_mean.numpy(), [0.047], atol=3e-3)
     self.assertAllClose(model.bn.moving_variance.numpy(), [0.9], atol=3e-2)
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_bessels_correction(self):
+    # Bessel's correction is currently only used in the fused case. In the
+    # future, it may be used in the nonfused case as well.
+    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1])
+    layer = normalization_v2.BatchNormalization(
+        momentum=0.5, moving_variance_initializer='zeros')
+    layer(x, training=True)
+    self.assertTrue(layer.fused)
+    # Since fused is used, Bessel's correction is used. The variance of [0, 2]
+    # is 2 with Bessel's correction. Since the momentum is 0.5, the variance is
+    # 2 * 0.5 == 1.
+    self.assertAllEqual(self.evaluate(layer.moving_variance), [1.])
+
+    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1, 1])
+    layer = normalization_v2.BatchNormalization(
+        momentum=0.5, moving_variance_initializer='zeros')
+    layer(x, training=True)
+    self.assertFalse(layer.fused)
+    # Since fused is not used, Bessel's correction is not used. The variance of
+    # [0, 2] is 1 without Bessel's correction. Since the momentum is 0.5, the
+    # variance is 1 * 0.5 == 0.5.
+    self.assertAllEqual(self.evaluate(layer.moving_variance), [0.5])
+
 
 class BatchNormalizationV1Test(keras_parameterized.TestCase):
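
Not part of the diff: the arithmetic behind the two test_bessels_correction assertions, written out with NumPy and the standard Keras moving-average update (new = momentum * old + (1 - momentum) * batch_stat).

import numpy as np

x = np.array([0., 2.])
momentum = 0.5

var_bessel = x.var(ddof=1)  # 2.0, sample variance (Bessel's correction), fused path
var_plain = x.var(ddof=0)   # 1.0, population variance, non-fused path

# moving_variance starts at 0 because of the 'zeros' initializer.
print(momentum * 0.0 + (1 - momentum) * var_bessel)  # 1.0, matches the 4D/fused case
print(momentum * 0.0 + (1 - momentum) * var_plain)   # 0.5, matches the 5D/non-fused case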
@@ -291,6 +316,12 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm(inp)
     self.assertEqual(norm.fused, False)
 
+    norm = normalization_v2.BatchNormalization()
+    self.assertIsNone(norm.fused)
+    inp = keras.layers.Input(shape=(4, 4, 4, 4))
+    norm(inp)
+    self.assertEqual(norm.fused, False)
+
     norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
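
Not part of the diff: since the build-time check above now requires rank 4 whenever fused=True is passed explicitly, a 5D input fails fast instead of silently falling back. A hypothetical snippet:

import numpy as np
import tensorflow as tf

layer = tf.keras.layers.BatchNormalization(fused=True)
try:
  layer(np.zeros((2, 4, 4, 4, 4), np.float32), training=True)
except ValueError as e:
  print(e)  # Batch normalization layers with `fused=True` only support ...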