Do non-fused mixed precision BatchNormalization in float32.

It's very easy for the variance to overflow in float16; when it does, the output of BatchNormalization becomes 0. Doing the computations in float32 and casting the output back to float16 solves the issue. The output itself has very little risk of underflow or overflow, since only the square root of the variance is used, which keeps the normalized values in a modest range.

PiperOrigin-RevId: 315599303
Change-Id: I0b7162b01ec748eb003a465a26215d1772d43cac
Author: Reed Wanderman-Milne, 2020-06-09 18:00:21 -07:00 (committed by TensorFlower Gardener)
parent ec2fb44030
commit de0a617f4e
2 changed files with 26 additions and 4 deletions
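
A minimal sketch (plain NumPy, not part of this commit) of the overflow described in the commit message: float16 tops out at 65504, so squaring even moderately large activations while reducing the variance in the input dtype already produces inf, and 1 / sqrt(inf + epsilon) then zeroes the normalized output.

    import numpy as np

    x = np.array([-1000., 1000.], dtype=np.float16)
    mean = x.mean()                                         # 0.0
    # Squared deviations are computed in float16; 1000**2 = 1e6 is far above
    # float16's maximum of 65504, so the variance overflows to inf.
    var_f16 = ((x - mean) ** 2).mean()                      # inf
    # Casting to float32 before the reduction keeps the variance finite.
    var_f32 = ((x.astype(np.float32) - mean) ** 2).mean()   # 1000000.0
    print(var_f16, var_f32)

This loosely mirrors how the non-fused path reduced the moments in the input dtype before this change; the diff below instead casts 16-bit inputs to float32 up front.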


@@ -724,6 +724,13 @@ class BatchNormalizationBase(Layer):
         outputs = undo_virtual_batching(outputs)
       return outputs
 
+    inputs_dtype = inputs.dtype.base_dtype
+    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
+      # Do all math in float32 if given 16-bit inputs for numeric stability.
+      # In particular, it's very easy for variance to overflow in float16 and
+      # for safety we also choose to cast bfloat16 to float32.
+      inputs = math_ops.cast(inputs, dtypes.float32)
+
     # Compute the axes along which to reduce the mean / variance
     input_shape = inputs.shape
     ndims = len(input_shape)
@@ -852,11 +859,12 @@ class BatchNormalizationBase(Layer):
       offset = math_ops.cast(offset, inputs.dtype)
     if scale is not None:
       scale = math_ops.cast(scale, inputs.dtype)
-    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
-    # math in float16 hurts validation accuracy of popular models like resnet.
     outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                      _broadcast(variance), offset, scale,
                                      self.epsilon)
+    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
+      outputs = math_ops.cast(outputs, inputs_dtype)
+
     # If some components of the shape got lost due to adjustments, fix that.
     outputs.set_shape(input_shape)
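
For context on why the cast back to inputs_dtype at the end is safe: tf.nn.batch_normalization only uses the variance through 1 / sqrt(variance + epsilon), so the float32 result is already roughly unit-scale when it is cast back to float16. A rough NumPy reference of that computation (not the TF kernel; epsilon chosen to match the layer's default of 1e-3):

    import numpy as np

    def batch_norm_reference(x, mean, variance, offset, scale, epsilon):
      # Same math as tf.nn.batch_normalization:
      # scale * (x - mean) / sqrt(variance + epsilon) + offset
      inv = scale / np.sqrt(variance + epsilon)
      return (x - mean) * inv + offset

    x = np.array([-1000., 1000.], dtype=np.float32)
    out = batch_norm_reference(x, x.mean(), x.var(), offset=0.0, scale=1.0,
                               epsilon=1e-3)
    # Even with variance == 1e6, the output is about [-1., 1.], well inside
    # float16's range, so casting back cannot overflow.
    print(out.astype(np.float16))

This is also the behavior the new test below checks: inputs of ±1000 normalize to approximately ±1 under the mixed_float16 policy.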


@@ -146,7 +146,7 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
         normalization_v2.BatchNormalization, dtype='float32')
 
   @keras_parameterized.run_all_keras_modes
-  def test_batchnorm_mixed_precision(self):
+  def test_batchnorm_float16(self):
     _run_batchnorm_correctness_test(
         normalization.BatchNormalization, dtype='float16')
     _run_batchnorm_correctness_test(
@@ -154,7 +154,7 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
 
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   @testing_utils.enable_v2_dtype_behavior
-  def test_batchnorm_policy(self):
+  def test_batchnorm_mixed_precision(self):
     norm = keras.layers.BatchNormalization(
         axis=-1,
         input_shape=(4, 4, 3),
@@ -166,6 +166,20 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
     self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')
 
+  @combinations.generate(combinations.combine(mode=['graph', 'eager'],
+                                              fused=[True, False]))
+  @testing_utils.enable_v2_dtype_behavior
+  def test_batchnorm_mixed_precision_does_not_overflow(self, fused):
+    norm = keras.layers.BatchNormalization(
+        axis=-1,
+        input_shape=(1, 1, 1),
+        fused=fused,
+        dtype=policy.Policy('mixed_float16'))
+    x = np.array([-1000., 1000.]).reshape((2, 1, 1, 1))
+    y = norm(x, training=True)
+    expected_y = np.array([-1.0, 1.0]).reshape((2, 1, 1, 1))
+    self.assertAllClose(keras.backend.eval(y), expected_y)
+
   @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def test_batchnorm_non_trainable_with_fit(self):
     # We use the same data shape for all the data we use in this test.