Do non-fused mixed precision BatchNormalization in float32.
It's very easy for variance to overflow in float16. This causes the output of BatchNormalization to be 0. Doing the computations in float32 and casting the output back to float16 solves the issue. The output has very little risk of underflow or overflow as the square root of variance is taken before being used. PiperOrigin-RevId: 315599303 Change-Id: I0b7162b01ec748eb003a465a26215d1772d43cac
This commit is contained in:
parent
ec2fb44030
commit
de0a617f4e
@ -724,6 +724,13 @@ class BatchNormalizationBase(Layer):
|
||||
outputs = undo_virtual_batching(outputs)
|
||||
return outputs
|
||||
|
||||
inputs_dtype = inputs.dtype.base_dtype
|
||||
if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
|
||||
# Do all math in float32 if given 16-bit inputs for numeric stability.
|
||||
# In particular, it's very easy for variance to overflow in float16 and
|
||||
# for safety we also choose to cast bfloat16 to float32.
|
||||
inputs = math_ops.cast(inputs, dtypes.float32)
|
||||
|
||||
# Compute the axes along which to reduce the mean / variance
|
||||
input_shape = inputs.shape
|
||||
ndims = len(input_shape)
|
||||
@ -852,11 +859,12 @@ class BatchNormalizationBase(Layer):
|
||||
offset = math_ops.cast(offset, inputs.dtype)
|
||||
if scale is not None:
|
||||
scale = math_ops.cast(scale, inputs.dtype)
|
||||
# TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
|
||||
# math in float16 hurts validation accuracy of popular models like resnet.
|
||||
outputs = nn.batch_normalization(inputs, _broadcast(mean),
|
||||
_broadcast(variance), offset, scale,
|
||||
self.epsilon)
|
||||
if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
|
||||
outputs = math_ops.cast(outputs, inputs_dtype)
|
||||
|
||||
# If some components of the shape got lost due to adjustments, fix that.
|
||||
outputs.set_shape(input_shape)
|
||||
|
||||
|
@ -146,7 +146,7 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
|
||||
normalization_v2.BatchNormalization, dtype='float32')
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_batchnorm_mixed_precision(self):
|
||||
def test_batchnorm_float16(self):
|
||||
_run_batchnorm_correctness_test(
|
||||
normalization.BatchNormalization, dtype='float16')
|
||||
_run_batchnorm_correctness_test(
|
||||
@ -154,7 +154,7 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_batchnorm_policy(self):
|
||||
def test_batchnorm_mixed_precision(self):
|
||||
norm = keras.layers.BatchNormalization(
|
||||
axis=-1,
|
||||
input_shape=(4, 4, 3),
|
||||
@ -166,6 +166,20 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
|
||||
self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
|
||||
self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')
|
||||
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager'],
|
||||
fused=[True, False]))
|
||||
@testing_utils.enable_v2_dtype_behavior
|
||||
def test_batchnorm_mixed_precision_does_not_overflow(self, fused):
|
||||
norm = keras.layers.BatchNormalization(
|
||||
axis=-1,
|
||||
input_shape=(1, 1, 1),
|
||||
fused=fused,
|
||||
dtype=policy.Policy('mixed_float16'))
|
||||
x = np.array([-1000., 1000.]).reshape((2, 1, 1, 1))
|
||||
y = norm(x, training=True)
|
||||
expected_y = np.array([-1.0, 1.0]).reshape((2, 1, 1, 1))
|
||||
self.assertAllClose(keras.backend.eval(y), expected_y)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
def test_batchnorm_non_trainable_with_fit(self):
|
||||
# We use the same data shape for all the data we use in this test.
|
||||
|
Loading…
Reference in New Issue
Block a user