Merge pull request #46321 from yongtang:46064-Attention-mixed_policy

PiperOrigin-RevId: 353172550
Change-Id: I9d6a98b89eb3a3948d39a15ef155e324a31ace47
TensorFlower Gardener 2021-01-21 22:10:00 -08:00
commit 43cbed5c96
3 changed files with 38 additions and 11 deletions


@@ -126,7 +126,11 @@ class BaseDenseAttention(Layer):
     if scores_mask is not None:
       padding_mask = math_ops.logical_not(scores_mask)
       # Bias so padding positions do not contribute to attention distribution.
-      scores -= 1.e9 * math_ops.cast(padding_mask, dtype=K.floatx())
+      # Note 65504. is the max float16 value.
+      if scores.dtype is dtypes.float16:
+        scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype)
+      else:
+        scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype)
     if training is None:
       training = K.learning_phase()
     weights = nn.softmax(scores)

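The change above is easier to follow with a quick illustration of the two failure modes it avoids. This is a minimal sketch, not part of the commit, assuming TF 2.x eager execution: under the mixed_float16 policy the layer computes scores in float16 while K.floatx() stays 'float32', so the old bias mixes dtypes, and naively casting 1e9 to float16 overflows (the largest finite float16 value is 65504).

# Minimal sketch (not from the commit) of why the old mask bias fails under
# mixed_float16. Assumes TF 2.x eager execution.
import numpy as np
import tensorflow as tf

scores = tf.constant([[0.5, 1.0]], dtype=tf.float16)  # float16 under mixed_float16
padding_mask = tf.constant([[False, True]])

# Old behaviour: the bias is built in K.floatx() (float32 by default), so the
# subtraction mixes float16 and float32 and TF rejects it.
bias_f32 = 1.e9 * tf.cast(padding_mask, tf.float32)
try:
  scores - bias_f32
except (TypeError, tf.errors.InvalidArgumentError) as e:
  print('dtype mismatch:', type(e).__name__)

# Casting 1e9 straight to float16 is no better: it overflows to inf.
print(np.float16(1.e9))    # inf
print(np.float16(65504.))  # 65504.0, the largest finite float16 value

# New behaviour: the bias uses 65504. and scores.dtype, so the dtypes match and
# the masked position is pushed low enough that softmax gives it ~0 weight.
bias_f16 = 65504. * tf.cast(padding_mask, scores.dtype)
print(tf.nn.softmax(scores - bias_f16))  # second position gets ~0 attention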

@@ -24,9 +24,13 @@ import numpy as np
 from tensorflow.python import keras
 from tensorflow.python.eager import context
 from tensorflow.python.keras import combinations
+from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.layers import core
 from tensorflow.python.keras.layers import dense_attention
+from tensorflow.python.keras.mixed_precision import policy
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -757,6 +761,17 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
     new_layer = dense_attention.AdditiveAttention.from_config(config)
     self.assertEqual(new_layer.use_scale, True)
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_mixed_float16_policy(self):
+    # Test case for GitHub issue:
+    # https://github.com/tensorflow/tensorflow/issues/46064
+    with policy.policy_scope('mixed_float16'):
+      q = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=1), 'float16')
+      v = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=2), 'float16')
+      k = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=3), 'float16')
+      layer = dense_attention.AdditiveAttention(causal=True)
+      _ = layer([q, v, k])
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class LowerTriangularMaskTest(test.TestCase, parameterized.TestCase):

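For readers who want to exercise the same path without the internal test utilities, a rough public-API equivalent is sketched below. It is not part of the commit, and assumes TF 2.4+ where tf.keras.mixed_precision.set_global_policy and the causal constructor argument are available.

# Rough public-API equivalent of test_mixed_float16_policy above (not part of
# the commit). Assumes TF 2.4+.
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')
try:
  q = tf.random.uniform((2, 3, 4), seed=1)
  v = tf.random.uniform((2, 3, 4), seed=2)
  k = tf.random.uniform((2, 3, 4), seed=3)

  layer = tf.keras.layers.AdditiveAttention(causal=True)
  out = layer([q, v, k])   # inputs are cast to the float16 compute dtype
  print(out.dtype)         # float16; before this commit the masking step failed here
finally:
  tf.keras.mixed_precision.set_global_policy('float32')  # restore the default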

@@ -100,12 +100,13 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('Activation', lambda: core.Activation('sigmoid'), (2, 2)),
       ('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)),
       ('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)),
-      ('Attention', dense_attention.Attention,
-       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
-      ('AdditiveAttention', dense_attention.AdditiveAttention,
-       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
-      ('Embedding', lambda: embeddings.Embedding(4, 4), (2, 4), 2e-3, 2e-3,
-       np.random.randint(4, size=(2, 4))),
+      ('Attention', dense_attention.Attention, [(2, 2, 3), (2, 3, 3),
+                                                (2, 3, 3)]),
+      ('AdditiveAttention', dense_attention.AdditiveAttention, [(2, 2, 3),
+                                                                (2, 3, 3),
+                                                                (2, 3, 3)]),
+      ('Embedding', lambda: embeddings.Embedding(4, 4),
+       (2, 4), 2e-3, 2e-3, np.random.randint(4, size=(2, 4))),
       ('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2), (2, 2, 1)),
       ('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2),
        (2, 2, 2, 1)),
@@ -120,8 +121,8 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)),
       ('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)),
       ('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)),
-      ('BatchNormalization', normalization_v2.BatchNormalization, (2, 2),
-       1e-2, 1e-2),
+      ('BatchNormalization', normalization_v2.BatchNormalization,
+       (2, 2), 1e-2, 1e-2),
       ('LayerNormalization', normalization.LayerNormalization, (2, 2)),
       ('LayerNormalizationUnfused',
        lambda: normalization.LayerNormalization(axis=1), (2, 2, 2)),
@@ -129,8 +130,8 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)),
       ('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)),
       ('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D, (2, 2, 2, 1)),
-      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4), (4, 4, 4),
-       1e-2, 1e-2),
+      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4),
+       (4, 4, 4), 1e-2, 1e-2),
       ('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)),
       ('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)),
       ('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)),
@@ -139,6 +140,13 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
        (2, 2, 2)),
       ('Bidirectional',
        lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)), (2, 2, 2)),
+      ('AttentionLayerCausal', lambda: dense_attention.Attention(causal=True), [
+          (2, 2, 3), (2, 3, 3), (2, 3, 3)
+      ]),
+      ('AdditiveAttentionLayerCausal',
+       lambda: dense_attention.AdditiveAttention(causal=True), [(2, 3, 4),
+                                                                (2, 3, 4),
+                                                                (2, 3, 4)]),
   )
   def test_layer(self, f32_layer_fn, input_shape, rtol=2e-3, atol=2e-3,
                  input_data=None):
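
The two new named_parameters entries plug into test_layer, which compares a layer's float32 output against its mixed_float16 output within rtol/atol. Below is a simplified sketch of that comparison, not the actual harness (the real test also builds and trains full models); the helper name check_layer and the explicit weight-copying step are illustrative assumptions.

# Simplified, hypothetical sketch of the check behind each parameter tuple
# (not the actual test body). check_layer is an illustrative helper name.
import numpy as np
import tensorflow as tf

def check_layer(f32_layer_fn, input_shapes, rtol=2e-3, atol=2e-3):
  inputs = [np.random.random(s).astype('float32') for s in input_shapes]

  # Run the layer with the default float32 policy.
  f32_layer = f32_layer_fn()
  f32_out = f32_layer(inputs)

  # Run a fresh copy of the layer under mixed_float16 with the same weights.
  tf.keras.mixed_precision.set_global_policy('mixed_float16')
  try:
    mp_layer = f32_layer_fn()
    mp_layer(inputs)                               # build the layer
    mp_layer.set_weights(f32_layer.get_weights())  # copy the float32 weights
    mp_out = mp_layer(inputs)
  finally:
    tf.keras.mixed_precision.set_global_policy('float32')

  # Outputs must agree within the given tolerances.
  np.testing.assert_allclose(
      tf.cast(mp_out, 'float32').numpy(), np.array(f32_out),
      rtol=rtol, atol=atol)

# The new causal entries correspond to checks like:
check_layer(lambda: tf.keras.layers.Attention(causal=True),
            [(2, 2, 3), (2, 3, 3), (2, 3, 3)])
check_layer(lambda: tf.keras.layers.AdditiveAttention(causal=True),
            [(2, 3, 4), (2, 3, 4), (2, 3, 4)])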