diff --git a/tensorflow/python/keras/layers/dense_attention.py b/tensorflow/python/keras/layers/dense_attention.py
index 34879524b64..54f657b113d 100644
--- a/tensorflow/python/keras/layers/dense_attention.py
+++ b/tensorflow/python/keras/layers/dense_attention.py
@@ -126,7 +126,11 @@ class BaseDenseAttention(Layer):
     if scores_mask is not None:
       padding_mask = math_ops.logical_not(scores_mask)
       # Bias so padding positions do not contribute to attention distribution.
-      scores -= 1.e9 * math_ops.cast(padding_mask, dtype=K.floatx())
+      # Note 65504. is the max float16 value.
+      if scores.dtype is dtypes.float16:
+        scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype)
+      else:
+        scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype)
     if training is None:
       training = K.learning_phase()
     weights = nn.softmax(scores)
diff --git a/tensorflow/python/keras/layers/dense_attention_test.py b/tensorflow/python/keras/layers/dense_attention_test.py
index 8570f41b34f..f025f4d17ef 100644
--- a/tensorflow/python/keras/layers/dense_attention_test.py
+++ b/tensorflow/python/keras/layers/dense_attention_test.py
@@ -24,9 +24,13 @@
 import numpy as np
 
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.keras import combinations
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.layers import dense_attention
+from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -757,6 +761,17 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
     new_layer = dense_attention.AdditiveAttention.from_config(config)
     self.assertEqual(new_layer.use_scale, True)
 
+  @testing_utils.enable_v2_dtype_behavior
+  def test_mixed_float16_policy(self):
+    # Test case for GitHub issue:
+    # https://github.com/tensorflow/tensorflow/issues/46064
+    with policy.policy_scope('mixed_float16'):
+      q = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=1), 'float16')
+      v = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=2), 'float16')
+      k = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=3), 'float16')
+      layer = dense_attention.AdditiveAttention(causal=True)
+      _ = layer([q, v, k])
+
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class LowerTriangularMaskTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/keras/mixed_precision/layer_correctness_test.py b/tensorflow/python/keras/mixed_precision/layer_correctness_test.py
index bbccc8721cd..5a76fd65ab3 100644
--- a/tensorflow/python/keras/mixed_precision/layer_correctness_test.py
+++ b/tensorflow/python/keras/mixed_precision/layer_correctness_test.py
@@ -100,12 +100,13 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('Activation', lambda: core.Activation('sigmoid'), (2, 2)),
       ('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)),
       ('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)),
-      ('Attention', dense_attention.Attention,
-       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
-      ('AdditiveAttention', dense_attention.AdditiveAttention,
-       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
-      ('Embedding', lambda: embeddings.Embedding(4, 4), (2, 4), 2e-3, 2e-3,
-       np.random.randint(4, size=(2, 4))),
+      ('Attention', dense_attention.Attention, [(2, 2, 3), (2, 3, 3),
+                                                (2, 3, 3)]),
+      ('AdditiveAttention', dense_attention.AdditiveAttention, [(2, 2, 3),
+                                                                (2, 3, 3),
+                                                                (2, 3, 3)]),
+      ('Embedding', lambda: embeddings.Embedding(4, 4),
+       (2, 4), 2e-3, 2e-3, np.random.randint(4, size=(2, 4))),
       ('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2), (2, 2, 1)),
       ('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2),
        (2, 2, 2, 1)),
@@ -120,8 +121,8 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)),
       ('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)),
       ('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)),
-      ('BatchNormalization', normalization_v2.BatchNormalization, (2, 2),
-       1e-2, 1e-2),
+      ('BatchNormalization', normalization_v2.BatchNormalization,
+       (2, 2), 1e-2, 1e-2),
       ('LayerNormalization', normalization.LayerNormalization, (2, 2)),
       ('LayerNormalizationUnfused',
        lambda: normalization.LayerNormalization(axis=1), (2, 2, 2)),
@@ -129,8 +130,8 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)),
       ('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)),
       ('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D, (2, 2, 2, 1)),
-      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4), (4, 4, 4),
-       1e-2, 1e-2),
+      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4),
+       (4, 4, 4), 1e-2, 1e-2),
       ('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)),
       ('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)),
       ('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)),
@@ -139,6 +140,13 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
        (2, 2, 2)),
       ('Bidirectional',
        lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)), (2, 2, 2)),
+      ('AttentionLayerCausal', lambda: dense_attention.Attention(causal=True), [
+          (2, 2, 3), (2, 3, 3), (2, 3, 3)
+      ]),
+      ('AdditiveAttentionLayerCausal',
+       lambda: dense_attention.AdditiveAttention(causal=True), [(2, 3, 4),
+                                                                (2, 3, 4),
+                                                                (2, 3, 4)]),
   )
   def test_layer(self, f32_layer_fn, input_shape, rtol=2e-3, atol=2e-3,
                  input_data=None):
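
Note on the BaseDenseAttention change above: under the mixed_float16 policy the attention scores are float16, whose largest finite value is 65504, so a fixed -1.e9 bias cannot be represented in that dtype; the patch therefore casts the mask bias to scores.dtype and picks the bias magnitude per dtype. A minimal standalone sketch of that reasoning using public TensorFlow and NumPy ops follows (the shapes and values are illustrative, not taken from the patch):

import numpy as np
import tensorflow as tf

# Why 1e9 is unusable as a float16 mask bias: the largest finite float16
# value is 65504, so 1e9 overflows to inf, and inf * 0 is NaN, which would
# corrupt the unmasked positions if the bias were simply cast to float16.
print(np.finfo(np.float16).max)          # 65504.0
print(np.float16(1e9))                   # inf
print(np.float16(1e9) * np.float16(0.))  # nan

# Sketch of the patched masking logic with illustrative scores and mask.
scores = tf.constant([[0.1, 0.2, 0.3]], dtype=tf.float16)
padding_mask = tf.constant([[False, True, True]])

if scores.dtype == tf.float16:
  # 65504. keeps the biased logits finite while still pushing the masked
  # attention weights to ~0 after softmax.
  scores -= 65504. * tf.cast(padding_mask, dtype=scores.dtype)
else:
  scores -= 1.e9 * tf.cast(padding_mask, dtype=scores.dtype)

print(tf.nn.softmax(scores))  # ~[[1., 0., 0.]]: masked positions get no weight

This only illustrates the numeric rationale; the layer keeps the 1.e9 bias for wider float dtypes, where it is comfortably representable and far beyond any realistic score.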