Supports return_attention_scores option in tf.keras.layers.Attention.
PiperOrigin-RevId: 324249853
Change-Id: I09d251722bb82b01965e161f3e33f3e570e1d5fe
parent d0e0b226d4
commit 4f2eefab89
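
As a quick orientation before the diff: the docstring text below describes a layer that, when constructed with `return_attention_scores=True`, returns the attention scores (after masking and softmax) as a second output. The following is only an illustrative sketch of that contract, assuming a build in which `tf.keras.layers.Attention` accepts the constructor argument exactly as the docstrings in this diff describe; shapes follow the `[batch_size, Tq, dim]` and `[batch_size, Tq, Tv]` conventions used there.

    # Illustrative sketch only; assumes the constructor flag described in the
    # docstrings of this diff, not the API of any particular released version.
    import tensorflow as tf

    query = tf.random.normal((4, 8, 16))   # [batch_size, Tq, dim]
    value = tf.random.normal((4, 10, 16))  # [batch_size, Tv, dim]

    layer = tf.keras.layers.Attention(return_attention_scores=True)
    outputs, scores = layer([query, value])
    # outputs: shape [4, 8, 16], the attention outputs ([batch_size, Tq, dim])
    # scores:  shape [4, 8, 10], the softmax attention weights ([batch_size, Tq, Tv])
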
@@ -49,8 +49,6 @@ class BaseDenseAttention(Layer):
 flow of information from the future towards the past.
 dropout: Float between 0 and 1. Fraction of the units to drop for the
 attention scores.
-return_attention_scores: bool, if `True`, returns the attention scores
-(after masking and softmax) as an additional output argument.

 Call Arguments:

@@ -70,19 +68,15 @@ class BaseDenseAttention(Layer):
 training: Python boolean indicating whether the layer should behave in
 training mode (adding dropout) or in inference mode (no dropout).

-Output:
+Output shape:

 Attention outputs of shape `[batch_size, Tq, dim]`.
-[Optional] Attention scores after masking and softmax with shape
-`[batch_size, Tq, Tv]`.
 """

-def __init__(self, causal=False, dropout=0.0, return_attention_scores=False,
-**kwargs):
+def __init__(self, causal=False, dropout=0.0, **kwargs):
 super(BaseDenseAttention, self).__init__(**kwargs)
 self.causal = causal
 self.dropout = dropout
-self.return_attention_scores = return_attention_scores
 self.supports_masking = True

 def _calculate_scores(self, query, key):
@@ -121,8 +115,6 @@ class BaseDenseAttention(Layer):

 Returns:
 Tensor of shape `[batch_size, Tq, dim]`.
-Attention scores after masking and softmax with shape
-`[batch_size, Tq, Tv]`.
 """
 if scores_mask is not None:
 padding_mask = math_ops.logical_not(scores_mask)
@@ -137,7 +129,7 @@ class BaseDenseAttention(Layer):

 weights = control_flow_util.smart_cond(training, dropped_weights,
 lambda: array_ops.identity(weights))
-return math_ops.matmul(weights, value), weights
+return math_ops.matmul(weights, value)

 # TODO(b/125916026): Consider exposing a __call__ method with named args.
 def call(self, inputs, mask=None, training=None):
@@ -164,14 +156,12 @@ class BaseDenseAttention(Layer):
 else:
 causal_mask = None
 scores_mask = _merge_masks(v_mask, causal_mask)
-result, attention_scores = self._apply_scores(
+result = self._apply_scores(
 scores=scores, value=v, scores_mask=scores_mask, training=training)
 if q_mask is not None:
 # Mask of shape [batch_size, Tq, 1].
 q_mask = array_ops.expand_dims(q_mask, axis=-1)
 result *= math_ops.cast(q_mask, dtype=result.dtype)
-if self.return_attention_scores:
-return result, attention_scores
 return result

 def compute_mask(self, inputs, mask=None):
@@ -209,7 +199,6 @@ class BaseDenseAttention(Layer):
 config = {
 'causal': self.causal,
 'dropout': self.dropout,
-'return_attention_scores': self.return_attention_scores,
 }
 base_config = super(BaseDenseAttention, self).get_config()
 return dict(list(base_config.items()) + list(config.items()))
@@ -239,8 +228,6 @@ class Attention(BaseDenseAttention):
 flow of information from the future towards the past.
 dropout: Float between 0 and 1. Fraction of the units to drop for the
 attention scores.
-return_attention_scores: bool, if `True`, returns the attention scores
-(after masking and softmax) as an additional output argument.

 Call Arguments:

@@ -260,11 +247,9 @@ class Attention(BaseDenseAttention):
 training: Python boolean indicating whether the layer should behave in
 training mode (adding dropout) or in inference mode (no dropout).

-Output:
+Output shape:

 Attention outputs of shape `[batch_size, Tq, dim]`.
-[Optional] Attention scores after masking and softmax with shape
-`[batch_size, Tq, Tv]`.

 The meaning of `query`, `value` and `key` depends on the application. In the
 case of text similarity, for example, `query` is the sequence embeddings of
@@ -378,8 +363,6 @@ class AdditiveAttention(BaseDenseAttention):
 flow of information from the future towards the past.
 dropout: Float between 0 and 1. Fraction of the units to drop for the
 attention scores.
-return_attention_scores: bool, if `True`, returns the attention scores
-(after masking and softmax) as an additional output argument.

 Call Arguments:

@@ -399,11 +382,9 @@ class AdditiveAttention(BaseDenseAttention):
 training: Python boolean indicating whether the layer should behave in
 training mode (adding dropout) or in inference mode (no dropout).

-Output:
+Output shape:

 Attention outputs of shape `[batch_size, Tq, dim]`.
-[Optional] Attention scores after masking and softmax with shape
-`[batch_size, Tq, Tv]`.

 The meaning of `query`, `value` and `key` depends on the application. In the
 case of text similarity, for example, `query` is the sequence embeddings of
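
Before the test-file hunks, a minimal NumPy sketch (an illustration under stated assumptions, not the layer implementation) of the relationship the layer code above encodes: the attention output is the softmax-weighted sum of the values, and the optional second output is exactly those softmax weights. The function name is hypothetical.

    import numpy as np

    def apply_scores_sketch(scores, value):
      # scores: [batch_size, Tq, Tv]; value: [batch_size, Tv, dim]
      weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
      weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the Tv axis
      output = np.matmul(weights, value)              # [batch_size, Tq, dim]
      return output, weights                          # weights: [batch_size, Tq, Tv]

For scores `[[[1., 0., 1.]]]` and values `[[[1.6], [0.7], [-0.8]]]` this reproduces the 0.42231879825 / 0.15536240349 weights that the test comments below work out by hand.
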
@@ -40,14 +40,11 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
 v = np.array([[[1.6]]], dtype=np.float32)
 # Scores mask tensor of shape [1, 1, 1]
 scores_mask = np.array([[[True]]], dtype=np.bool_)
-actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
+actual = dense_attention.BaseDenseAttention()._apply_scores(
 scores=scores, value=v, scores_mask=scores_mask)

-# Expected softmax_scores = [[[1]]]
-expected_scores = np.array([[[1.]]], dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
 # Expected tensor of shape [1, 1, 1].
-# expected000 = softmax_scores[0, 0] * 1.6 = 1.6
+# expected000 = softmax(scores)[0, 0] * 1.6 = 1.6
 expected = np.array([[[1.6]]], dtype=np.float32)
 self.assertAllClose(expected, actual)

@@ -56,14 +53,11 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
 scores = np.array([[[1.1]]], dtype=np.float32)
 # Value tensor of shape [1, 1, 1]
 v = np.array([[[1.6]]], dtype=np.float32)
-actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
+actual = dense_attention.BaseDenseAttention()._apply_scores(
 scores=scores, value=v)

-# Expected softmax_scores = [[[1]]]
-expected_scores = np.array([[[1.]]], dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
 # Expected tensor of shape [1, 1, 1].
-# expected000 = softmax_scores[0, 0] * 1.6 = 1.6
+# expected000 = softmax(scores)[0, 0] * 1.6 = 1.6
 expected = np.array([[[1.6]]], dtype=np.float32)
 self.assertAllClose(expected, actual)

@@ -74,17 +68,15 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
 v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
 # Scores mask tensor of shape [1, 1, 3]
 scores_mask = np.array([[[True, True, False]]], dtype=np.bool_)
-actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
+actual = dense_attention.BaseDenseAttention()._apply_scores(
 scores=scores, value=v, scores_mask=scores_mask)

-# Expected softmax scores = softmax(scores) with zeros in positions where
-# v_mask == False.
-# => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
-# softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
-# softmax_scores002 = 0
-expected_scores = np.array(
-[[[0.73105857863, 0.26894142137, 0.]]], dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
+# Expected attention distribution = softmax(scores) with zeros in
+# positions where v_mask == False.
+# => attention_distribution000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
+# attention_distribution001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
+# attention_distribution002 = 0
+#
 # Expected tensor of shape [1, 1, 1].
 # expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8
 # = 1.35795272077
@@ -96,19 +88,17 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
 scores = np.array([[[1., 0., 1.]]], dtype=np.float32)
 # Value tensor of shape [1, 3, 1]
 v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
-actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
+actual = dense_attention.BaseDenseAttention()._apply_scores(
 scores=scores, value=v)

-# Expected softmax_scores = softmax(scores).
-# => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
-# = 0.42231879825
-# softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
-# = 0.15536240349
-# softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
-# = 0.42231879825
-expected_scores = np.array(
-[[[0.42231879825, 0.15536240349, 0.42231879825]]], dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
+# Expected attention distribution = softmax(scores).
+# => attention_distribution000 = exp(1)/(exp(1) + exp(0) + exp(1))
+# = 0.42231879825
+# attention_distribution001 = exp(0)/(exp(1) + exp(0) + exp(1))
+# = 0.15536240349
+# attention_distribution002 = exp(1)/(exp(1) + exp(0) + exp(1))
+# = 0.42231879825
+#
 # Expected tensor of shape [1, 1, 1].
 # expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
 # - 0.42231879825 * 0.8
@@ -123,15 +113,12 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
 v = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
 # Scores mask tensor of shape [2, 1, 1]
 scores_mask = np.array([[[True]], [[True]]], dtype=np.bool_)
-actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
+actual = dense_attention.BaseDenseAttention()._apply_scores(
 scores=scores, value=v, scores_mask=scores_mask)

-# Expected softmax_scores = [[[1]], [[1]]]
-expected_scores = np.array([[[1.]], [[1.]]], dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
 # Expected tensor of shape [2, 1, 1].
-# expected000 = softmax_scores[0, 0] * 1.6 = 1.6
-# expected100 = softmax_scores[1, 0] * 2.6 = 2.6
+# expected000 = softmax(scores)[0, 0] * 1.6 = 1.6
+# expected100 = softmax(scores)[1, 0] * 2.6 = 2.6
 expected = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
 self.assertAllClose(expected, actual)

@@ -144,13 +131,9 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
 dim = 7
 scores = np.ones((batch_size, tq, tv))
 value = np.ones((batch_size, tv, dim))
-actual, actual_scores = dense_attention.BaseDenseAttention(
-dropout=0.1)._apply_scores(
-scores=scores, value=value, training=False)
+actual = dense_attention.BaseDenseAttention(dropout=0.1)._apply_scores(
+scores=scores, value=value, training=False)

-# Expected Tensor of shape `[batch_size, tq, tv]`.
-expected_scores_shape = [batch_size, tq, tv]
-self.assertAllEqual(expected_scores_shape, array_ops.shape(actual_scores))
 # Expected Tensor of shape `[batch_size, tq, dim]`.
 expected_shape = [batch_size, tq, dim]
 self.assertAllEqual(expected_shape, array_ops.shape(actual))
@@ -329,11 +312,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 expected = np.array([[[0.58127362329]]], dtype=np.float32)
 self.assertAllClose(expected, actual)

-@parameterized.named_parameters(
-('', False),
-('return_attention_scores', True),
-)
-def test_multi_dim_with_query_mask(self, return_attention_scores):
+def test_multi_dim_with_query_mask(self):
 # Query tensor of shape [1, 2, 1]
 q = np.array([[[1.1], [-0.5]]], dtype=np.float32)
 # Value tensor of shape [1, 3, 1]
@@ -342,12 +321,8 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 q_mask = np.array([[True, False]], dtype=np.bool_)
 # Value mask tensor of shape [1, 3]
 v_mask = np.array([[True, True, False]], dtype=np.bool_)
-attention_layer = dense_attention.Attention(
-return_attention_scores=return_attention_scores)
-if return_attention_scores:
-actual, actual_scores = attention_layer([q, v], mask=[q_mask, v_mask])
-else:
-actual = attention_layer([q, v], mask=[q_mask, v_mask])
+attention_layer = dense_attention.Attention()
+actual = attention_layer([q, v], mask=[q_mask, v_mask])

 # Expected scores of shape [1, 2, 3]
 # scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8], [-0.5*1.6, -0.5*0.7, 0.5*0.8]]]
@@ -364,12 +339,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 # attention_distribution011 = exp(-0.35)/(exp(-0.8) + exp(-0.35))
 # = 0.61063923394
 # attention_distribution012 = 0
-if return_attention_scores:
-expected_scores = np.array(
-[[[0.72908792234, 0.27091207765, 0.],
-[0.38936076605, 0.61063923394, 0.]]],
-dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
+#
 # Expected tensor of shape [1, 2, 1] with zeros where q_mask == False.
 # expected000 = 0.72908792234 * 1.6 + 0.27091207765 * 0.7 - 0 * 0.8
 # = 1.3561791301
@@ -398,19 +368,11 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 sess.run(attention_layer.scale.initializer)
 self.assertAllClose(1., attention_layer.scale.value())

-@parameterized.named_parameters(
-('', False),
-('return_attention_scores', True),
-)
-def test_self_attention_causal(self, return_attention_scores):
+def test_self_attention_causal(self):
 # Query-value tensor of shape [1, 3, 1]
 q = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32)
-attention_layer = dense_attention.Attention(
-causal=True, return_attention_scores=return_attention_scores)
-if return_attention_scores:
-actual, actual_scores = attention_layer([q, q])
-else:
-actual = attention_layer([q, q])
+attention_layer = dense_attention.Attention(causal=True)
+actual = attention_layer([q, q])

 # Expected scores of shape [1, 3, 3]
 # scores = [[0.25, 0.4, -0.15], [0.4, 0.64, -0.24], [-0.15, -0.24, 0.09]]
@@ -423,13 +385,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 # = [exp(-0.15), exp(-0.24), exp(0.09)]
 # / (exp(-0.15) + exp(-0.24) + exp(0.09))
 # = [0.31395396638, 0.28693232061, 0.399113713]
-if return_attention_scores:
-expected_scores = np.array(
-[[[1., 0., 0.],
-[0.44028635073, 0.55971364926, 0.],
-[0.31395396638, 0.28693232061, 0.399113713]]],
-dtype=np.float32)
-self.assertAllClose(expected_scores, actual_scores)
+#
 # Expected tensor of shape [1, 3, 1].
 # expected000 = 0.5
 # expected010 = 0.44028635073 * 0.5 + 0.55971364926 * 0.8
@@ -499,25 +455,17 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
 actual = attention_layer([q, v])
 self.assertAllClose([[[0], [1]]], actual)

-@parameterized.named_parameters(
-('', False, False),
-('use_scale', True, False),
-('return_attention_scores', False, True),
-)
-def test_serialization(self, use_scale, return_attention_scores):
+def test_serialization(self):
 # Test serialization with use_scale
-layer = dense_attention.Attention(
-use_scale=use_scale, return_attention_scores=return_attention_scores)
+layer = dense_attention.Attention(use_scale=True)

 config = keras.layers.serialize(layer)
 new_layer = keras.layers.deserialize(config)
-self.assertEqual(new_layer.use_scale, use_scale)
-self.assertEqual(new_layer.return_attention_scores, return_attention_scores)
+self.assertEqual(new_layer.use_scale, True)

 config = layer.get_config()
 new_layer = dense_attention.Attention.from_config(config)
-self.assertEqual(new_layer.use_scale, use_scale)
-self.assertEqual(new_layer.return_attention_scores, return_attention_scores)
+self.assertEqual(new_layer.use_scale, True)


 @combinations.generate(combinations.combine(mode=['graph', 'eager']))
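
For completeness, a hypothetical single-configuration version of the serialization round trip that the removed parameterized test covered. It is a sketch only: it assumes a build whose `Attention` layer accepts the `return_attention_scores` constructor argument shown in the removed lines, and it uses the public `tf.keras.layers.Attention` symbol instead of the test module's `dense_attention` import.

    import tensorflow as tf

    # Mirrors the removed test_serialization case (use_scale=True,
    # return_attention_scores=True), assuming the constructor flag exists.
    layer = tf.keras.layers.Attention(use_scale=True, return_attention_scores=True)

    config = layer.get_config()
    new_layer = tf.keras.layers.Attention.from_config(config)
    assert new_layer.use_scale is True
    assert new_layer.return_attention_scores is True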