From 2cf3cc466c187fac46eb35577f107aa49276f1dc Mon Sep 17 00:00:00 2001
From: Reed Wanderman-Milne
Date: Mon, 16 Nov 2020 15:01:07 -0800
Subject: [PATCH] Merging

---
 .../python/keras/layers/normalization.py      | 16 +++++++---
 .../python/keras/layers/normalization_test.py | 31 +++++++++++++++++++
 2 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 0737fe11712..29ab50c1b08 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -248,6 +248,7 @@ class BatchNormalizationBase(Layer):
     axis = [self.axis] if isinstance(self.axis, int) else self.axis
     # Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
     # input rank is required to be 4 (which is checked later).
+    # TODO(b/173253101): Once the input rank can be 5, update this check.
     if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
       raise ValueError('Passing fused=True is only supported when axis is 1 '
                        'or 3')
@@ -329,14 +330,19 @@ class BatchNormalizationBase(Layer):
     # TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
     # output back to its original shape accordingly.
     if self._USE_V2_BEHAVIOR:
+      # TODO(b/173253101): Using fused in the 5D case is currently disabled
+      # due to a regression on UNet, so it is currently only supported in the
+      # 4D case.
       if self.fused is None:
-        self.fused = ndims in (4, 5)
-      elif self.fused and ndims not in (4, 5):
-        raise ValueError('Batch normalization layers with fused=True only '
-                         'support 4D or 5D input tensors.')
+        self.fused = ndims == 4
+      elif self.fused and ndims != 4:
+        raise ValueError('Batch normalization layers with `fused=True` only '
+                         'support 4D input tensors. '
+                         'Received tensor with shape: %s' %
+                         (tuple(input_shape),))
     else:
       assert self.fused is not None
-      self.fused = (ndims in (4, 5) and self._fused_can_be_used())
+      self.fused = (ndims == 4 and self._fused_can_be_used())
     # TODO(chrisying): fused batch norm is currently not supported for
     # multi-axis batch norm and by extension virtual batches. In some cases,
     # it might be possible to use fused batch norm but would require reshaping
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index a98db36ceea..d468e5d6db2 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -241,6 +241,31 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     self.assertAllClose(model.bn.moving_mean.numpy(), [0.047], atol=3e-3)
     self.assertAllClose(model.bn.moving_variance.numpy(), [0.9], atol=3e-2)
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_bessels_correction(self):
+    # Bessel's correction is currently only used in the fused case. In the
+    # future, it may be used in the nonfused case as well.
+
+    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1])
+    layer = normalization_v2.BatchNormalization(
+        momentum=0.5, moving_variance_initializer='zeros')
+    layer(x, training=True)
+    self.assertTrue(layer.fused)
+    # Since fused is used, Bessel's correction is used. The variance of [0, 2]
+    # is 2 with Bessel's correction. Since the momentum is 0.5, the moving
+    # variance is 2 * 0.5 == 1.
+    self.assertAllEqual(self.evaluate(layer.moving_variance), [1.])
+
+    x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1, 1])
+    layer = normalization_v2.BatchNormalization(
+        momentum=0.5, moving_variance_initializer='zeros')
+    layer(x, training=True)
+    self.assertFalse(layer.fused)
+    # Since fused is not used, Bessel's correction is not used. The variance
+    # of [0, 2] is 1 without Bessel's correction. Since the momentum is 0.5,
+    # the moving variance is 1 * 0.5 == 0.5.
+    self.assertAllEqual(self.evaluate(layer.moving_variance), [0.5])
+
 
 class BatchNormalizationV1Test(keras_parameterized.TestCase):
 
@@ -291,6 +316,12 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm(inp)
     self.assertEqual(norm.fused, False)
 
+    norm = normalization_v2.BatchNormalization()
+    self.assertIsNone(norm.fused)
+    inp = keras.layers.Input(shape=(4, 4, 4, 4))
+    norm(inp)
+    self.assertEqual(norm.fused, False)
+
     norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
     self.assertEqual(norm.fused, False)
     inp = keras.layers.Input(shape=(4, 4, 4))
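
Note on the numbers asserted in test_bessels_correction: they can be checked
outside of TensorFlow. The snippet below is an illustrative sketch, not part
of the patch. It assumes only what the test comments state: the fused kernel
computes the batch variance with Bessel's correction (dividing by n - 1), the
nonfused path divides by n, and the moving variance is updated as
moving * momentum + batch_variance * (1 - momentum).

import numpy as np

x = np.array([0., 2.])
momentum = 0.5
moving_variance = 0.0  # matches moving_variance_initializer='zeros'

# Fused path (chosen for 4D inputs): Bessel-corrected variance. The squared
# deviations from the mean (1) sum to 2; dividing by n - 1 == 1 gives 2.0.
fused_var = x.var(ddof=1)

# Nonfused path (chosen for 5D inputs): population variance. The same sum
# divided by n == 2 gives 1.0.
nonfused_var = x.var(ddof=0)

# One training step updates the moving variance:
print(moving_variance * momentum + fused_var * (1 - momentum))     # 1.0
print(moving_variance * momentum + nonfused_var * (1 - momentum))  # 0.5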