Merging
This commit is contained in:
parent
259339c432
commit
2cf3cc466c
@ -248,6 +248,7 @@ class BatchNormalizationBase(Layer):
|
||||
axis = [self.axis] if isinstance(self.axis, int) else self.axis
|
||||
# Axis -3 is equivalent to 1, and axis -1 is equivalent to 3, because the
|
||||
# input rank is required to be 4 (which is checked later).
|
||||
# TODO(b/173253101): Once the input rank can be 5, update this check.
|
||||
if len(axis) > 1 or axis[0] not in (-3, -1, 1, 3):
|
||||
raise ValueError('Passing fused=True is only supported when axis is 1 '
|
||||
'or 3')
|
||||
@ -329,14 +330,19 @@ class BatchNormalizationBase(Layer):
|
||||
# TODO(yaozhang): if input is not 4D, reshape it to 4D and reshape the
|
||||
# output back to its original shape accordingly.
|
||||
if self._USE_V2_BEHAVIOR:
|
||||
# TODO(b/173253101): Using fused in the 5D case is currently disabled
|
||||
# due to a regression on UNet, so it is currently only supported in
|
||||
# the 4D case.
|
||||
if self.fused is None:
|
||||
self.fused = ndims in (4, 5)
|
||||
elif self.fused and ndims not in (4, 5):
|
||||
raise ValueError('Batch normalization layers with fused=True only '
|
||||
'support 4D or 5D input tensors.')
|
||||
self.fused = ndims == 4
|
||||
elif self.fused and ndims != 4:
|
||||
raise ValueError('Batch normalization layers with `fused=True` only '
|
||||
'support 4D or 5D input tensors. '
|
||||
'Received tensor with shape: %s' %
|
||||
(tuple(input_shape),))
|
||||
else:
|
||||
assert self.fused is not None
|
||||
self.fused = (ndims in (4, 5) and self._fused_can_be_used())
|
||||
self.fused = (ndims == 4 and self._fused_can_be_used())
|
||||
# TODO(chrisying): fused batch norm is currently not supported for
|
||||
# multi-axis batch norm and by extension virtual batches. In some cases,
|
||||
# it might be possible to use fused batch norm but would require reshaping
|
||||
|
@ -241,6 +241,31 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
|
||||
self.assertAllClose(model.bn.moving_mean.numpy(), [0.047], atol=3e-3)
|
||||
self.assertAllClose(model.bn.moving_variance.numpy(), [0.9], atol=3e-2)
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['eager']))
|
||||
def test_bessels_correction(self):
|
||||
# Bessel's correction is currently only used in the fused case. In the
|
||||
# future, it may be used in the nonfused case as well.
|
||||
|
||||
x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1])
|
||||
layer = normalization_v2.BatchNormalization(
|
||||
momentum=0.5, moving_variance_initializer='zeros')
|
||||
layer(x, training=True)
|
||||
self.assertTrue(layer.fused)
|
||||
# Since fused is used, Bessel's correction is used. The variance of [0, 2]
|
||||
# is 2 with Bessel's correction. Since the momentum is 0.5, the variance is
|
||||
# 2 * 0.5 == 1.
|
||||
self.assertAllEqual(self.evaluate(layer.moving_variance), [1.])
|
||||
|
||||
x = constant_op.constant([0., 2.], shape=[2, 1, 1, 1, 1])
|
||||
layer = normalization_v2.BatchNormalization(
|
||||
momentum=0.5, moving_variance_initializer='zeros')
|
||||
layer(x, training=True)
|
||||
self.assertFalse(layer.fused)
|
||||
# Since fused is not used, Bessel's correction is not used. The variance of
|
||||
# [0, 2] is 1 without Bessel's correction. Since the momentum is 0.5, the
|
||||
# variance is 1 * 0.5 == 0.5.
|
||||
self.assertAllEqual(self.evaluate(layer.moving_variance), [0.5])
|
||||
|
||||
|
||||
class BatchNormalizationV1Test(keras_parameterized.TestCase):
|
||||
|
||||
@ -291,6 +316,12 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
|
||||
norm(inp)
|
||||
self.assertEqual(norm.fused, False)
|
||||
|
||||
norm = normalization_v2.BatchNormalization()
|
||||
self.assertIsNone(norm.fused)
|
||||
inp = keras.layers.Input(shape=(4, 4, 4, 4))
|
||||
norm(inp)
|
||||
self.assertEqual(norm.fused, False)
|
||||
|
||||
norm = normalization_v2.BatchNormalization(virtual_batch_size=2)
|
||||
self.assertEqual(norm.fused, False)
|
||||
inp = keras.layers.Input(shape=(4, 4, 4))
|
||||
|
Loading…
Reference in New Issue
Block a user