Keep return_attention_scores consistent with MultiHeadAttention, before TF 2.4 release.
PiperOrigin-RevId: 336016413
Change-Id: If9cbc68586fddf1da221cfda35dd4b72b3b68897
parent 37cf5b43c3
commit fca5923fab
@@ -209,6 +209,10 @@ # Release 2.4.0
 *   Improvements to Keras preprocessing layers:
     *   TextVectorization can now accept a vocabulary list or file as an
         init arg.
+*   In `Attention` and `AdditiveAttention` layers, the `call()` method now
+    accepts a `return_attention_scores` argument. When set to
+    True, the layer returns the attention scores as an additional output
+    argument.
 *   `tf.function` / AutoGraph:
 
     *   Added `experimental_follow_type_hints` argument for `tf.function`. When
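A minimal usage sketch of the call-time `return_attention_scores` argument described in the release note above (assumes TF 2.4 or later; the tensor shapes are illustrative, not taken from this diff):

import tensorflow as tf

# Toy inputs: [batch, Tq, dim] query and [batch, Tv, dim] value.
query = tf.random.normal([2, 4, 8])
value = tf.random.normal([2, 6, 8])

layer = tf.keras.layers.Attention()
# The flag is now passed per call instead of to the constructor.
result, scores = layer([query, value], return_attention_scores=True)
print(result.shape)  # (2, 4, 8)
print(scores.shape)  # (2, 4, 6) -- [batch_size, Tq, Tv]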
@@ -49,8 +49,6 @@ class BaseDenseAttention(Layer):
       flow of information from the future towards the past.
     dropout: Float between 0 and 1. Fraction of the units to drop for the
       attention scores.
-    return_attention_scores: bool, it `True`, returns the attention scores
-      (after masking and softmax) as an additional output argument.
 
   Call Arguments:
 
@@ -69,6 +67,8 @@ class BaseDenseAttention(Layer):
       `mask==False` do not contribute to the result.
     training: Python boolean indicating whether the layer should behave in
       training mode (adding dropout) or in inference mode (no dropout).
+    return_attention_scores: bool, it `True`, returns the attention scores
+      (after masking and softmax) as an additional output argument.
 
   Output:
 
@@ -77,12 +77,11 @@ class BaseDenseAttention(Layer):
       `[batch_size, Tq, Tv]`.
   """
 
-  def __init__(self, causal=False, dropout=0.0, return_attention_scores=False,
+  def __init__(self, causal=False, dropout=0.0,
                **kwargs):
     super(BaseDenseAttention, self).__init__(**kwargs)
     self.causal = causal
     self.dropout = dropout
-    self.return_attention_scores = return_attention_scores
     self.supports_masking = True
 
   def _calculate_scores(self, query, key):
@@ -140,7 +139,11 @@ class BaseDenseAttention(Layer):
     return math_ops.matmul(weights, value), weights
 
   # TODO(b/125916026): Consider exposing a __call__ method with named args.
-  def call(self, inputs, mask=None, training=None):
+  def call(self,
+           inputs,
+           mask=None,
+           training=None,
+           return_attention_scores=False):
     self._validate_call_args(inputs=inputs, mask=mask)
     q = inputs[0]
     v = inputs[1]
@@ -170,7 +173,7 @@ class BaseDenseAttention(Layer):
       # Mask of shape [batch_size, Tq, 1].
       q_mask = array_ops.expand_dims(q_mask, axis=-1)
       result *= math_ops.cast(q_mask, dtype=result.dtype)
-    if self.return_attention_scores:
+    if return_attention_scores:
       return result, attention_scores
     return result
 
@@ -209,7 +212,6 @@ class BaseDenseAttention(Layer):
     config = {
         'causal': self.causal,
         'dropout': self.dropout,
-        'return_attention_scores': self.return_attention_scores,
     }
     base_config = super(BaseDenseAttention, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
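Because the flag is dropped from `get_config()` above, a serialized layer no longer records it; score output is requested per call instead. A small sketch of the resulting config round-trip (assumes TF 2.4 or later):

import tensorflow as tf

layer = tf.keras.layers.Attention(dropout=0.1)
config = layer.get_config()
# The base config now carries constructor state such as `causal` and
# `dropout`; `return_attention_scores` is no longer a key.
print('return_attention_scores' in config)  # False
restored = tf.keras.layers.Attention.from_config(config)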
@@ -239,8 +241,6 @@ class Attention(BaseDenseAttention):
       flow of information from the future towards the past.
     dropout: Float between 0 and 1. Fraction of the units to drop for the
       attention scores.
-    return_attention_scores: bool, it `True`, returns the attention scores
-      (after masking and softmax) as an additional output argument.
 
   Call Arguments:
 
@@ -257,6 +257,8 @@ class Attention(BaseDenseAttention):
       * value_mask: A boolean mask `Tensor` of shape `[batch_size, Tv]`.
         If given, will apply the mask such that values at positions where
         `mask==False` do not contribute to the result.
+    return_attention_scores: bool, it `True`, returns the attention scores
+      (after masking and softmax) as an additional output argument.
     training: Python boolean indicating whether the layer should behave in
       training mode (adding dropout) or in inference mode (no dropout).
 
@@ -378,8 +380,6 @@ class AdditiveAttention(BaseDenseAttention):
       flow of information from the future towards the past.
     dropout: Float between 0 and 1. Fraction of the units to drop for the
       attention scores.
-    return_attention_scores: bool, it `True`, returns the attention scores
-      (after masking and softmax) as an additional output argument.
 
   Call Arguments:
 
@@ -398,6 +398,8 @@ class AdditiveAttention(BaseDenseAttention):
       `mask==False` do not contribute to the result.
     training: Python boolean indicating whether the layer should behave in
       training mode (adding dropout) or in inference mode (no dropout).
+    return_attention_scores: bool, it `True`, returns the attention scores
+      (after masking and softmax) as an additional output argument.
 
   Output:
 
@@ -82,8 +82,8 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
     # => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
     #    softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
     #    softmax_scores002 = 0
-    expected_scores = np.array(
-        [[[0.73105857863, 0.26894142137, 0.]]], dtype=np.float32)
+    expected_scores = np.array([[[0.73105857863, 0.26894142137, 0.]]],
+                               dtype=np.float32)
     self.assertAllClose(expected_scores, actual_scores)
     # Expected tensor of shape [1, 1, 1].
     # expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8
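The constants in the test comments above can be reproduced directly; a short check of the masked softmax (illustrative only):

import numpy as np

# Softmax over scores [1., 0.] with the third position masked out.
scores = np.array([1., 0.])
weights = np.exp(scores) / np.exp(scores).sum()
print(weights)  # ~[0.73105858, 0.26894142]; the masked position contributes 0.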
@@ -187,8 +187,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 
   def test_calculate_scores_multi_dim(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Key tensor of shape [1, 3, 4]
     k = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -204,8 +203,8 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
     # expected010 = 2.*1.5+2.1*1.6+2.2*1.7+2.3*1.8 = 14.24
     # expected011 = 2.*2.5+2.1*2.6+2.2*2.7+2.3*2.8 = 22.84
     # expected012 = 2.*3.5+2.1*3.6+2.2*3.7+2.3*3.8 = 31.44
-    expected = np.array(
-        [[[7.64, 12.24, 16.84], [14.24, 22.84, 31.44]]], dtype=np.float32)
+    expected = np.array([[[7.64, 12.24, 16.84], [14.24, 22.84, 31.44]]],
+                        dtype=np.float32)
     self.assertAllClose(expected, actual)
 
   def test_calculate_scores_one_dim_batch_size_two(self):
@@ -241,8 +240,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 
   def test_shape(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Value tensor of shape [1, 3, 4]
     v = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -257,8 +255,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 
   def test_shape_with_key(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Value tensor of shape [1, 3, 4]
     v = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -342,12 +339,16 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
     q_mask = np.array([[True, False]], dtype=np.bool_)
     # Value mask tensor of shape [1, 3]
     v_mask = np.array([[True, True, False]], dtype=np.bool_)
-    attention_layer = dense_attention.Attention(
-        return_attention_scores=return_attention_scores)
+    attention_layer = dense_attention.Attention()
     if return_attention_scores:
-      actual, actual_scores = attention_layer([q, v], mask=[q_mask, v_mask])
+      actual, actual_scores = attention_layer(
+          [q, v],
+          mask=[q_mask, v_mask],
+          return_attention_scores=return_attention_scores)
     else:
-      actual = attention_layer([q, v], mask=[q_mask, v_mask])
+      actual = attention_layer([q, v],
+                               mask=[q_mask, v_mask],
+                               return_attention_scores=return_attention_scores)
 
     # Expected scores of shape [1, 2, 3]
     # scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8], [-0.5*1.6, -0.5*0.7, 0.5*0.8]]]
@@ -365,10 +366,9 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
     #   = 0.61063923394
     #   attention_distribution012 = 0
     if return_attention_scores:
-      expected_scores = np.array(
-          [[[0.72908792234, 0.27091207765, 0.],
-            [0.38936076605, 0.61063923394, 0.]]],
-          dtype=np.float32)
+      expected_scores = np.array([[[0.72908792234, 0.27091207765, 0.],
+                                   [0.38936076605, 0.61063923394, 0.]]],
+                                  dtype=np.float32)
       self.assertAllClose(expected_scores, actual_scores)
     # Expected tensor of shape [1, 2, 1] with zeros where q_mask == False.
     # expected000 = 0.72908792234 * 1.6 + 0.27091207765 * 0.7 - 0 * 0.8
@@ -406,12 +406,13 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
   def test_self_attention_causal(self, return_attention_scores):
     # Query-value tensor of shape [1, 3, 1]
     q = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32)
-    attention_layer = dense_attention.Attention(
-        causal=True, return_attention_scores=return_attention_scores)
+    attention_layer = dense_attention.Attention(causal=True)
     if return_attention_scores:
-      actual, actual_scores = attention_layer([q, q])
+      actual, actual_scores = attention_layer(
+          [q, q], return_attention_scores=return_attention_scores)
     else:
-      actual = attention_layer([q, q])
+      actual = attention_layer([q, q],
+                               return_attention_scores=return_attention_scores)
 
     # Expected scores of shape [1, 3, 3]
     # scores = [[0.25, 0.4, -0.15], [0.4, 0.64, -0.24], [-0.15, -0.24, 0.09]]
@@ -426,8 +427,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
     #   = [0.31395396638, 0.28693232061, 0.399113713]
     if return_attention_scores:
       expected_scores = np.array(
-          [[[1., 0., 0.],
-            [0.44028635073, 0.55971364926, 0.],
+          [[[1., 0., 0.], [0.44028635073, 0.55971364926, 0.],
             [0.31395396638, 0.28693232061, 0.399113713]]],
           dtype=np.float32)
       self.assertAllClose(expected_scores, actual_scores)
@@ -437,8 +437,8 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
     #               = 0.66791409477
     # expected020 = 0.31395396638 * 0.5 +0.28693232061 * 0.8 -0.399113713 * 0.3
     #               = 0.26678872577
-    expected = np.array(
-        [[[0.5], [0.66791409477], [0.26678872577]]], dtype=np.float32)
+    expected = np.array([[[0.5], [0.66791409477], [0.26678872577]]],
+                        dtype=np.float32)
     self.assertAllClose(expected, actual)
 
   def test_inputs_not_list(self):
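The causal expectations above follow from a lower-triangular mask: row i of the scores is softmaxed only over positions 0..i. A small check of row 1 (illustrative only):

import numpy as np

q = np.array([0.5, 0.8, -0.3])
# Row 1 attends to positions 0 and 1 only: scores = [0.8*0.5, 0.8*0.8].
row1 = q[1] * q[:2]
weights = np.exp(row1) / np.exp(row1).sum()
print(weights)  # ~[0.44028635, 0.55971365], matching expected_scores above.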
@@ -501,24 +501,20 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
     self.assertAllClose([[[0], [1]]], actual)
 
   @parameterized.named_parameters(
-      ('', False, False),
-      ('use_scale', True, False),
-      ('return_attention_scores', False, True),
+      ('', False),
+      ('use_scale', True),
   )
-  def test_serialization(self, use_scale, return_attention_scores):
+  def test_serialization(self, use_scale):
     # Test serialization with use_scale
-    layer = dense_attention.Attention(
-        use_scale=use_scale, return_attention_scores=return_attention_scores)
+    layer = dense_attention.Attention(use_scale=use_scale)
 
     config = keras.layers.serialize(layer)
     new_layer = keras.layers.deserialize(config)
     self.assertEqual(new_layer.use_scale, use_scale)
-    self.assertEqual(new_layer.return_attention_scores, return_attention_scores)
 
     config = layer.get_config()
     new_layer = dense_attention.Attention.from_config(config)
     self.assertEqual(new_layer.use_scale, use_scale)
-    self.assertEqual(new_layer.return_attention_scores, return_attention_scores)
 
 
   @combinations.generate(combinations.combine(mode=['graph', 'eager']))
@@ -542,8 +538,7 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
 
   def test_calculate_scores_multi_dim(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Key tensor of shape [1, 3, 4]
     k = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -562,10 +557,9 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
    # expected011 = 0.5*tanh(2.+2.5) + 0.6*tanh(2.1+2.6) + 0.7*tanh(2.2+2.7) + 0.8*tanh(2.3+2.8) = 2.59964024652
    # expected012 = 0.5*tanh(2.+3.5) + 0.6*tanh(2.1+3.6) + 0.7*tanh(2.2+3.7) + 0.8*tanh(2.3+3.8) = 2.59995130916
    # pylint:enable=line-too-long
-    expected = np.array(
-        [[[2.58044532581, 2.59734317449, 2.59964024652],
-          [2.59734317449, 2.59964024652, 2.59995130916]]],
-        dtype=np.float32)
+    expected = np.array([[[2.58044532581, 2.59734317449, 2.59964024652],
+                          [2.59734317449, 2.59964024652, 2.59995130916]]],
+                        dtype=np.float32)
     self.assertAllClose(expected, actual)
 
   def test_calculate_scores_one_dim_batch_size_two(self):
@@ -582,14 +576,13 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
     # Expected tensor of shape [2, 1, 1].
     # expected000 = 0.5 * tanh(1.1 + 1.6) = 0.49550372683
     # expected100 = 0.5 * tanh(2.1 + 2.6) = 0.49991728277
-    expected = np.array(
-        [[[0.49550372683]], [[0.49991728277]]], dtype=np.float32)
+    expected = np.array([[[0.49550372683]], [[0.49991728277]]],
+                        dtype=np.float32)
     self.assertAllClose(expected, actual)
 
   def test_shape(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Value tensor of shape [1, 3, 4]
     v = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -604,8 +597,7 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
 
   def test_shape_no_scale(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Value tensor of shape [1, 3, 4]
     v = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -620,8 +612,7 @@ class AdditiveAttentionTest(test.TestCase, parameterized.TestCase):
 
   def test_shape_with_key(self):
     # Query tensor of shape [1, 2, 4]
-    q = np.array(
-        [[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
+    q = np.array([[[1., 1.1, 1.2, 1.3], [2., 2.1, 2.2, 2.3]]], dtype=np.float32)
     # Value tensor of shape [1, 3, 4]
     v = np.array(
         [[[1.5, 1.6, 1.7, 1.8], [2.5, 2.6, 2.7, 2.8], [3.5, 3.6, 3.7, 3.8]]],
@@ -779,8 +770,8 @@ class LowerTriangularMaskTest(test.TestCase, parameterized.TestCase):
 
   def test_orthogonal_shape(self):
     actual = dense_attention._lower_triangular_mask([3, 2])
-    expected = np.array(
-        [[True, False], [True, True], [True, True]], dtype=np.bool_)
+    expected = np.array([[True, False], [True, True], [True, True]],
+                        dtype=np.bool_)
     self.assertAllEqual(expected, actual)
 
   def test_three_dim(self):
@@ -150,7 +150,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "compute_mask"
@@ -150,7 +150,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "compute_mask"
@@ -150,7 +150,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
  }
  member_method {
    name: "compute_mask"
@@ -150,7 +150,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
   }
   member_method {
     name: "compute_mask"
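The golden argspec updates above can be cross-checked against an installed build (assumes TF 2.4 or later):

import inspect
import tensorflow as tf

sig = inspect.signature(tf.keras.layers.Attention.call)
print(list(sig.parameters))
# ['self', 'inputs', 'mask', 'training', 'return_attention_scores']
print(sig.parameters['return_attention_scores'].default)  # False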