diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 61e134e3d94..e5723a3ef98 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -724,6 +724,13 @@ class BatchNormalizationBase(Layer):
         outputs = undo_virtual_batching(outputs)
       return outputs

+    inputs_dtype = inputs.dtype.base_dtype
+    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
+      # Do all math in float32 if given 16-bit inputs for numeric stability.
+      # In particular, it's very easy for variance to overflow in float16 and
+      # for safety we also choose to cast bfloat16 to float32.
+      inputs = math_ops.cast(inputs, dtypes.float32)
+
     # Compute the axes along which to reduce the mean / variance
     input_shape = inputs.shape
     ndims = len(input_shape)
@@ -852,11 +859,12 @@
       offset = math_ops.cast(offset, inputs.dtype)
     if scale is not None:
       scale = math_ops.cast(scale, inputs.dtype)
-    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
-    # math in float16 hurts validation accuracy of popular models like resnet.
     outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                      _broadcast(variance), offset, scale,
                                      self.epsilon)
+    if inputs_dtype in (dtypes.float16, dtypes.bfloat16):
+      outputs = math_ops.cast(outputs, inputs_dtype)
+
     # If some components of the shape got lost due to adjustments, fix that.
     outputs.set_shape(input_shape)

diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index e14977edfc4..ef43bcf5d22 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -146,7 +146,7 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
         normalization_v2.BatchNormalization, dtype='float32')

   @keras_parameterized.run_all_keras_modes
-  def test_batchnorm_mixed_precision(self):
+  def test_batchnorm_float16(self):
     _run_batchnorm_correctness_test(
         normalization.BatchNormalization, dtype='float16')
     _run_batchnorm_correctness_test(
@@ -154,7 +154,7 @@ class BatchNormalizationTest(keras_parameterized.TestCase):

   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
   @testing_utils.enable_v2_dtype_behavior
-  def test_batchnorm_policy(self):
+  def test_batchnorm_mixed_precision(self):
     norm = keras.layers.BatchNormalization(
         axis=-1,
         input_shape=(4, 4, 3),
@@ -166,6 +166,20 @@ class BatchNormalizationTest(keras_parameterized.TestCase):
     self.assertEqual(norm.beta.dtype.base_dtype, 'float32')
     self.assertEqual(norm.gamma.dtype.base_dtype, 'float32')

+  @combinations.generate(combinations.combine(mode=['graph', 'eager'],
+                                              fused=[True, False]))
+  @testing_utils.enable_v2_dtype_behavior
+  def test_batchnorm_mixed_precision_does_not_overflow(self, fused):
+    norm = keras.layers.BatchNormalization(
+        axis=-1,
+        input_shape=(1, 1, 1),
+        fused=fused,
+        dtype=policy.Policy('mixed_float16'))
+    x = np.array([-1000., 1000.]).reshape((2, 1, 1, 1))
+    y = norm(x, training=True)
+    expected_y = np.array([-1.0, 1.0]).reshape((2, 1, 1, 1))
+    self.assertAllClose(keras.backend.eval(y), expected_y)
+
   @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
   def test_batchnorm_non_trainable_with_fit(self):
     # We use the same data shape for all the data we use in this test.
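For context (not part of the patch): the +/-1000 inputs in the new test_batchnorm_mixed_precision_does_not_overflow test are exactly the values the float32 cast guards against. float16 tops out around 65504, so squaring a deviation of 1000 during the variance computation already overflows. A minimal sketch of that failure mode, assuming plain NumPy semantics (NumPy accumulates in the input dtype for floating-point arrays):

import numpy as np

x16 = np.array([-1000., 1000.], dtype=np.float16)
mean16 = x16.mean()                       # 0.0; the mean itself is representable
var16 = ((x16 - mean16) ** 2).mean()      # 1000**2 = 1e6 > float16 max (~65504) -> inf

# Casting to float32 before the math, as the patch now does, keeps it stable.
x32 = x16.astype(np.float32)
var32 = ((x32 - x32.mean()) ** 2).mean()  # 1000000.0

print(var16, var32)                       # inf 1000000.0

Without the cast, the batch variance computed by the non-fused path would overflow in the same way, which is what the new test's expected output of +/-1.0 checks for.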