Merge pull request #46321 from yongtang:46064-Attention-mixed_policy
PiperOrigin-RevId: 353172550
Change-Id: I9d6a98b89eb3a3948d39a15ef155e324a31ace47
commit 43cbed5c96
@@ -126,7 +126,11 @@ class BaseDenseAttention(Layer):
     if scores_mask is not None:
       padding_mask = math_ops.logical_not(scores_mask)
       # Bias so padding positions do not contribute to attention distribution.
-      scores -= 1.e9 * math_ops.cast(padding_mask, dtype=K.floatx())
+      # Note 65504. is the max float16 value.
+      if scores.dtype is dtypes.float16:
+        scores -= 65504. * math_ops.cast(padding_mask, dtype=scores.dtype)
+      else:
+        scores -= 1.e9 * math_ops.cast(padding_mask, dtype=scores.dtype)
     if training is None:
       training = K.learning_phase()
     weights = nn.softmax(scores)
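The change above addresses two problems that only show up under a mixed_float16 policy: the scores tensor is float16 while `K.floatx()` remains 'float32', so the mask bias was built with a mismatched dtype, and the 1e9 constant itself does not fit in float16. A standalone NumPy sketch of the numeric side (illustration only, not part of the patch):

# Standalone NumPy illustration (not part of the patched file) of why the bias
# constant is dtype-dependent: 1e9 overflows float16, while 65504. is the largest
# finite float16 value and still drives masked positions to ~zero weight.
import numpy as np

print(np.float16(1e9))            # inf -- 1e9 does not fit in float16
print(np.finfo(np.float16).max)   # 65504.0 -- the constant used for float16 scores

scores = np.zeros((1, 3), dtype=np.float16)
padding_mask = np.array([[False, False, True]])

# Subtract the float16-max bias only at masked positions, as the patch does for
# float16 scores; the result stays finite.
biased = scores - np.float16(65504.) * padding_mask.astype(np.float16)

# Softmax over the biased scores: the masked position gets ~0 attention weight.
exp = np.exp(biased - biased.max(axis=-1, keepdims=True))
print(exp / exp.sum(axis=-1, keepdims=True))   # ~[[0.5, 0.5, 0.]]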
@@ -24,9 +24,13 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.keras import combinations
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import dense_attention
from tensorflow.python.keras.mixed_precision import policy
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test

@@ -757,6 +761,17 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
     new_layer = dense_attention.AdditiveAttention.from_config(config)
     self.assertEqual(new_layer.use_scale, True)
+
+  @testing_utils.enable_v2_dtype_behavior
+  def test_mixed_float16_policy(self):
+    # Test case for GitHub issue:
+    # https://github.com/tensorflow/tensorflow/issues/46064
+    with policy.policy_scope('mixed_float16'):
+      q = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=1), 'float16')
+      v = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=2), 'float16')
+      k = math_ops.cast(random_ops.random_uniform((2, 3, 4), seed=3), 'float16')
+      layer = dense_attention.AdditiveAttention(causal=True)
+      _ = layer([q, v, k])
 
 
 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
 class LowerTriangularMaskTest(test.TestCase, parameterized.TestCase):
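The new test_mixed_float16_policy covers the report in GitHub issue #46064. For reference, the same scenario can be written against the public tf.keras API; the sketch below assumes a TensorFlow 2.4+ build where `tf.keras.mixed_precision.set_global_policy` and the `causal=True` constructor argument are available.

# Sketch of the end-user scenario from GitHub issue #46064 using the public
# tf.keras API (assumption: TensorFlow 2.4+ with this fix applied).
import tensorflow as tf

tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Float16 inputs, matching the shapes used in test_mixed_float16_policy.
q = tf.cast(tf.random.uniform((2, 3, 4), seed=1), 'float16')
v = tf.cast(tf.random.uniform((2, 3, 4), seed=2), 'float16')
k = tf.cast(tf.random.uniform((2, 3, 4), seed=3), 'float16')

# Before this change, the causal/padding bias was cast with K.floatx() and a
# 1e9 constant, which broke under the mixed_float16 policy.
layer = tf.keras.layers.AdditiveAttention(causal=True)
out = layer([q, v, k])
print(out.dtype)  # float16 under mixed_float16

tf.keras.mixed_precision.set_global_policy('float32')  # restore the default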
@@ -100,12 +100,13 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('Activation', lambda: core.Activation('sigmoid'), (2, 2)),
       ('Reshape', lambda: core.Reshape((1, 4, 1)), (2, 2, 2)),
       ('Permute', lambda: core.Permute((2, 1)), (2, 2, 2)),
-      ('Attention', dense_attention.Attention,
-       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
-      ('AdditiveAttention', dense_attention.AdditiveAttention,
-       [(2, 2, 3), (2, 3, 3), (2, 3, 3)]),
-      ('Embedding', lambda: embeddings.Embedding(4, 4), (2, 4), 2e-3, 2e-3,
-       np.random.randint(4, size=(2, 4))),
+      ('Attention', dense_attention.Attention, [(2, 2, 3), (2, 3, 3),
+                                                (2, 3, 3)]),
+      ('AdditiveAttention', dense_attention.AdditiveAttention, [(2, 2, 3),
+                                                                (2, 3, 3),
+                                                                (2, 3, 3)]),
+      ('Embedding', lambda: embeddings.Embedding(4, 4),
+       (2, 4), 2e-3, 2e-3, np.random.randint(4, size=(2, 4))),
       ('LocallyConnected1D', lambda: local.LocallyConnected1D(2, 2), (2, 2, 1)),
       ('LocallyConnected2D', lambda: local.LocallyConnected2D(2, 2),
        (2, 2, 2, 1)),
@@ -120,8 +121,8 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('GaussianNoise', lambda: noise.GaussianNoise(0.5), (2, 2)),
       ('GaussianDropout', lambda: noise.GaussianDropout(0.5), (2, 2)),
       ('AlphaDropout', lambda: noise.AlphaDropout(0.5), (2, 2)),
-      ('BatchNormalization', normalization_v2.BatchNormalization, (2, 2),
-       1e-2, 1e-2),
+      ('BatchNormalization', normalization_v2.BatchNormalization,
+       (2, 2), 1e-2, 1e-2),
       ('LayerNormalization', normalization.LayerNormalization, (2, 2)),
       ('LayerNormalizationUnfused',
        lambda: normalization.LayerNormalization(axis=1), (2, 2, 2)),
@@ -129,8 +130,8 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
       ('AveragePooling2D', pooling.AveragePooling2D, (2, 2, 2, 1)),
       ('GlobalMaxPooling2D', pooling.GlobalMaxPooling2D, (2, 2, 2, 1)),
       ('GlobalAveragePooling2D', pooling.GlobalAveragePooling2D, (2, 2, 2, 1)),
-      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4), (4, 4, 4),
-       1e-2, 1e-2),
+      ('SimpleRNN', lambda: recurrent.SimpleRNN(units=4),
+       (4, 4, 4), 1e-2, 1e-2),
       ('GRU', lambda: recurrent.GRU(units=4), (4, 4, 4)),
       ('LSTM', lambda: recurrent.LSTM(units=4), (4, 4, 4)),
       ('GRUV2', lambda: recurrent_v2.GRU(units=4), (4, 4, 4)),
@@ -139,6 +140,13 @@ class LayerCorrectnessTest(keras_parameterized.TestCase):
        (2, 2, 2)),
       ('Bidirectional',
        lambda: wrappers.Bidirectional(recurrent.SimpleRNN(units=4)), (2, 2, 2)),
+      ('AttentionLayerCausal', lambda: dense_attention.Attention(causal=True), [
+          (2, 2, 3), (2, 3, 3), (2, 3, 3)
+      ]),
+      ('AdditiveAttentionLayerCausal',
+       lambda: dense_attention.AdditiveAttention(causal=True), [(2, 3, 4),
+                                                                (2, 3, 4),
+                                                                (2, 3, 4)]),
   )
   def test_layer(self, f32_layer_fn, input_shape, rtol=2e-3, atol=2e-3,
                  input_data=None):
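Each entry above is a (testcase name, f32_layer_fn, input_shape, and optional rtol/atol/input_data) tuple that test_layer runs under both float32 and mixed_float16 and compares within the given tolerances. The sketch below is a simplified, assumed version of that comparison, written against the public tf.keras API rather than the internal modules; the real harness also covers model-level and training paths.

# Simplified sketch (not the actual LayerCorrectnessTest harness) of the
# float32-vs-mixed_float16 comparison that the new 'AttentionLayerCausal' and
# 'AdditiveAttentionLayerCausal' entries go through.
import numpy as np
import tensorflow as tf

def check_layer(layer_fn, input_shapes, rtol=2e-3, atol=2e-3):
  inputs = [np.random.rand(*shape).astype('float32') for shape in input_shapes]

  # Reference output with the default float32 policy.
  expected = layer_fn()(inputs)

  # Same layer under mixed_float16; the computation runs in float16.
  tf.keras.mixed_precision.set_global_policy('mixed_float16')
  try:
    actual = tf.cast(layer_fn()(inputs), 'float32')
  finally:
    tf.keras.mixed_precision.set_global_policy('float32')

  np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)

# Attention(causal=True) has no trainable weights by default, so this sketch needs
# no weight syncing; layers with weights would additionally have to copy weights
# between the float32 and mixed-precision copies.
check_layer(lambda: tf.keras.layers.Attention(causal=True),
            [(2, 2, 3), (2, 3, 3), (2, 3, 3)])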