Supports return_attention_scores option in tf.keras.layers.Attention.

PiperOrigin-RevId: 324249853
Change-Id: I09d251722bb82b01965e161f3e33f3e570e1d5fe
A. Unique TensorFlower 2020-07-31 11:37:20 -07:00 committed by TensorFlower Gardener
parent d0e0b226d4
commit 4f2eefab89
2 changed files with 43 additions and 114 deletions
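For orientation, the option named in the commit title and in the docstring text shown in the diff below is a per-layer flag that makes the attention layer return the post-softmax attention scores alongside its usual output. The sketch below is a minimal, hypothetical usage example following the constructor-argument form described by the docstring wording in this diff; other TensorFlow releases expose the flag differently, so treat the exact signature as an assumption rather than a stable API.

import tensorflow as tf

# Hypothetical sketch of the constructor-argument form of the flag.
query = tf.random.normal([2, 4, 8])   # [batch_size, Tq, dim]
value = tf.random.normal([2, 6, 8])   # [batch_size, Tv, dim]

attention = tf.keras.layers.Attention(return_attention_scores=True)
outputs, scores = attention([query, value])

print(outputs.shape)  # (2, 4, 8) -> attention outputs, [batch_size, Tq, dim]
print(scores.shape)   # (2, 4, 6) -> attention scores,  [batch_size, Tq, Tv]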


@@ -49,8 +49,6 @@ class BaseDenseAttention(Layer):
flow of information from the future towards the past.
dropout: Float between 0 and 1. Fraction of the units to drop for the
attention scores.
return_attention_scores: bool, if `True`, returns the attention scores
(after masking and softmax) as an additional output argument.
Call Arguments:
@@ -70,19 +68,15 @@ class BaseDenseAttention(Layer):
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (no dropout).
Output:
Output shape:
Attention outputs of shape `[batch_size, Tq, dim]`.
[Optional] Attention scores after masking and softmax with shape
`[batch_size, Tq, Tv]`.
"""
def __init__(self, causal=False, dropout=0.0, return_attention_scores=False,
**kwargs):
def __init__(self, causal=False, dropout=0.0, **kwargs):
super(BaseDenseAttention, self).__init__(**kwargs)
self.causal = causal
self.dropout = dropout
self.return_attention_scores = return_attention_scores
self.supports_masking = True
def _calculate_scores(self, query, key):
@@ -121,8 +115,6 @@ class BaseDenseAttention(Layer):
Returns:
Tensor of shape `[batch_size, Tq, dim]`.
Attention scores after masking and softmax with shape
`[batch_size, Tq, Tv]`.
"""
if scores_mask is not None:
padding_mask = math_ops.logical_not(scores_mask)
@@ -137,7 +129,7 @@ class BaseDenseAttention(Layer):
weights = control_flow_util.smart_cond(training, dropped_weights,
lambda: array_ops.identity(weights))
return math_ops.matmul(weights, value), weights
return math_ops.matmul(weights, value)
# TODO(b/125916026): Consider exposing a __call__ method with named args.
def call(self, inputs, mask=None, training=None):
@@ -164,14 +156,12 @@ class BaseDenseAttention(Layer):
else:
causal_mask = None
scores_mask = _merge_masks(v_mask, causal_mask)
result, attention_scores = self._apply_scores(
result = self._apply_scores(
scores=scores, value=v, scores_mask=scores_mask, training=training)
if q_mask is not None:
# Mask of shape [batch_size, Tq, 1].
q_mask = array_ops.expand_dims(q_mask, axis=-1)
result *= math_ops.cast(q_mask, dtype=result.dtype)
if self.return_attention_scores:
return result, attention_scores
return result
def compute_mask(self, inputs, mask=None):
@@ -209,7 +199,6 @@ class BaseDenseAttention(Layer):
config = {
'causal': self.causal,
'dropout': self.dropout,
'return_attention_scores': self.return_attention_scores,
}
base_config = super(BaseDenseAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -239,8 +228,6 @@ class Attention(BaseDenseAttention):
flow of information from the future towards the past.
dropout: Float between 0 and 1. Fraction of the units to drop for the
attention scores.
return_attention_scores: bool, if `True`, returns the attention scores
(after masking and softmax) as an additional output argument.
Call Arguments:
@@ -260,11 +247,9 @@ class Attention(BaseDenseAttention):
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (no dropout).
Output:
Output shape:
Attention outputs of shape `[batch_size, Tq, dim]`.
[Optional] Attention scores after masking and softmax with shape
`[batch_size, Tq, Tv]`.
The meaning of `query`, `value` and `key` depends on the application. In the
case of text similarity, for example, `query` is the sequence embeddings of
@@ -378,8 +363,6 @@ class AdditiveAttention(BaseDenseAttention):
flow of information from the future towards the past.
dropout: Float between 0 and 1. Fraction of the units to drop for the
attention scores.
return_attention_scores: bool, if `True`, returns the attention scores
(after masking and softmax) as an additional output argument.
Call Arguments:
@@ -399,11 +382,9 @@ class AdditiveAttention(BaseDenseAttention):
training: Python boolean indicating whether the layer should behave in
training mode (adding dropout) or in inference mode (no dropout).
Output:
Output shape:
Attention outputs of shape `[batch_size, Tq, dim]`.
[Optional] Attention scores after masking and softmax with shape
`[batch_size, Tq, Tv]`.
The meaning of `query`, `value` and `key` depends on the application. In the
case of text similarity, for example, `query` is the sequence embeddings of

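Before the test-file diff, here is a simplified standalone sketch of the computation that `_apply_scores` performs in the hunks above: mask the score logits, apply softmax, optionally apply dropout during training, then take the weighted sum over the values. It uses public `tf` ops instead of the internal `math_ops`/`nn_ops`/`smart_cond` machinery in the real implementation, and its two-output return mirrors the `(output, weights)` form shown in the diff; the function name is illustrative only, not part of TensorFlow.

import tensorflow as tf

def apply_scores(scores, value, scores_mask=None, dropout=0.0, training=False):
  # scores:      [batch_size, Tq, Tv] similarity logits.
  # value:       [batch_size, Tv, dim] value tensor.
  # scores_mask: optional boolean mask broadcastable to `scores`.
  scores = tf.convert_to_tensor(scores)
  value = tf.convert_to_tensor(value)
  if scores_mask is not None:
    # A large negative bias drives masked positions to ~0 weight after softmax.
    scores = scores - 1.e9 * tf.cast(tf.logical_not(scores_mask), scores.dtype)
  weights = tf.nn.softmax(scores)
  if training and dropout > 0:
    weights = tf.nn.dropout(weights, rate=dropout)
  # Attention output of shape [batch_size, Tq, dim], plus the attention weights.
  return tf.matmul(weights, value), weights

Running this on scores [[[1., 0., 1.]]] and values [[[1.6], [0.7], [-0.8]]] reproduces the 0.42231879825 / 0.15536240349 attention weights worked out in the test comments below.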

@@ -40,14 +40,11 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
v = np.array([[[1.6]]], dtype=np.float32)
# Scores mask tensor of shape [1, 1, 1]
scores_mask = np.array([[[True]]], dtype=np.bool_)
actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
actual = dense_attention.BaseDenseAttention()._apply_scores(
scores=scores, value=v, scores_mask=scores_mask)
# Expected softmax_scores = [[[1]]]
expected_scores = np.array([[[1.]]], dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
# Expected tensor of shape [1, 1, 1].
# expected000 = softmax_scores[0, 0] * 1.6 = 1.6
# expected000 = softmax(scores)[0, 0] * 1.6 = 1.6
expected = np.array([[[1.6]]], dtype=np.float32)
self.assertAllClose(expected, actual)
@@ -56,14 +53,11 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
scores = np.array([[[1.1]]], dtype=np.float32)
# Value tensor of shape [1, 1, 1]
v = np.array([[[1.6]]], dtype=np.float32)
actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
actual = dense_attention.BaseDenseAttention()._apply_scores(
scores=scores, value=v)
# Expected softmax_scores = [[[1]]]
expected_scores = np.array([[[1.]]], dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
# Expected tensor of shape [1, 1, 1].
# expected000 = softmax_scores[0, 0] * 1.6 = 1.6
# expected000 = softmax(scores)[0, 0] * 1.6 = 1.6
expected = np.array([[[1.6]]], dtype=np.float32)
self.assertAllClose(expected, actual)
@@ -74,17 +68,15 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
# Scores mask tensor of shape [1, 1, 3]
scores_mask = np.array([[[True, True, False]]], dtype=np.bool_)
actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
actual = dense_attention.BaseDenseAttention()._apply_scores(
scores=scores, value=v, scores_mask=scores_mask)
# Expected softmax scores = softmax(scores) with zeros in positions where
# v_mask == False.
# => softmax_scores000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
# softmax_scores001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
# softmax_scores002 = 0
expected_scores = np.array(
[[[0.73105857863, 0.26894142137, 0.]]], dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
# Expected attention distribution = softmax(scores) with zeros in
# positions where v_mask == False.
# => attention_distribution000 = exp(1)/(exp(1) + exp(0)) = 0.73105857863
# attention_distribution001 = exp(0)/(exp(1) + exp(0)) = 0.26894142137
# attention_distribution002 = 0
#
# Expected tensor of shape [1, 1, 1].
# expected000 = 0.73105857863 * 1.6 + 0.26894142137 * 0.7 - 0 * 0.8
# = 1.35795272077
@@ -96,19 +88,17 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
scores = np.array([[[1., 0., 1.]]], dtype=np.float32)
# Value tensor of shape [1, 3, 1]
v = np.array([[[1.6], [0.7], [-0.8]]], dtype=np.float32)
actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
actual = dense_attention.BaseDenseAttention()._apply_scores(
scores=scores, value=v)
# Expected softmax_scores = softmax(scores).
# => softmax_scores000 = exp(1)/(exp(1) + exp(0) + exp(1))
# = 0.42231879825
# softmax_scores001 = exp(0)/(exp(1) + exp(0) + exp(1))
# = 0.15536240349
# softmax_scores002 = exp(1)/(exp(1) + exp(0) + exp(1))
# = 0.42231879825
expected_scores = np.array(
[[[0.42231879825, 0.15536240349, 0.42231879825]]], dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
# Expected attention distribution = softmax(scores).
# => attention_distribution000 = exp(1)/(exp(1) + exp(0) + exp(1))
# = 0.42231879825
# attention_distribution001 = exp(0)/(exp(1) + exp(0) + exp(1))
# = 0.15536240349
# attention_distribution002 = exp(1)/(exp(1) + exp(0) + exp(1))
# = 0.42231879825
#
# Expected tensor of shape [1, 1, 1].
# expected000 = 0.42231879825 * 1.6 + 0.15536240349 * 0.7
# - 0.42231879825 * 0.8
@@ -123,15 +113,12 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
v = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
# Scores mask tensor of shape [2, 1, 1]
scores_mask = np.array([[[True]], [[True]]], dtype=np.bool_)
actual, actual_scores = dense_attention.BaseDenseAttention()._apply_scores(
actual = dense_attention.BaseDenseAttention()._apply_scores(
scores=scores, value=v, scores_mask=scores_mask)
# Expected softmax_scores = [[[1]], [[1]]]
expected_scores = np.array([[[1.]], [[1.]]], dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
# Expected tensor of shape [2, 1, 1].
# expected000 = softmax_scores[0, 0] * 1.6 = 1.6
# expected100 = softmax_scores[1, 0] * 2.6 = 2.6
# expected000 = softmax(scores)[0, 0] * 1.6 = 1.6
# expected100 = softmax(scores)[1, 0] * 2.6 = 2.6
expected = np.array([[[1.6]], [[2.6]]], dtype=np.float32)
self.assertAllClose(expected, actual)
@@ -144,13 +131,9 @@ class BaseDenseAttentionTest(test.TestCase, parameterized.TestCase):
dim = 7
scores = np.ones((batch_size, tq, tv))
value = np.ones((batch_size, tv, dim))
actual, actual_scores = dense_attention.BaseDenseAttention(
dropout=0.1)._apply_scores(
scores=scores, value=value, training=False)
actual = dense_attention.BaseDenseAttention(dropout=0.1)._apply_scores(
scores=scores, value=value, training=False)
# Expected Tensor of shape `[batch_size, tq, tv]`.
expected_scores_shape = [batch_size, tq, tv]
self.assertAllEqual(expected_scores_shape, array_ops.shape(actual_scores))
# Expected Tensor of shape `[batch_size, tq, dim]`.
expected_shape = [batch_size, tq, dim]
self.assertAllEqual(expected_shape, array_ops.shape(actual))
@@ -329,11 +312,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
expected = np.array([[[0.58127362329]]], dtype=np.float32)
self.assertAllClose(expected, actual)
@parameterized.named_parameters(
('', False),
('return_attention_scores', True),
)
def test_multi_dim_with_query_mask(self, return_attention_scores):
def test_multi_dim_with_query_mask(self):
# Query tensor of shape [1, 2, 1]
q = np.array([[[1.1], [-0.5]]], dtype=np.float32)
# Value tensor of shape [1, 3, 1]
@@ -342,12 +321,8 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
q_mask = np.array([[True, False]], dtype=np.bool_)
# Value mask tensor of shape [1, 3]
v_mask = np.array([[True, True, False]], dtype=np.bool_)
attention_layer = dense_attention.Attention(
return_attention_scores=return_attention_scores)
if return_attention_scores:
actual, actual_scores = attention_layer([q, v], mask=[q_mask, v_mask])
else:
actual = attention_layer([q, v], mask=[q_mask, v_mask])
attention_layer = dense_attention.Attention()
actual = attention_layer([q, v], mask=[q_mask, v_mask])
# Expected scores of shape [1, 2, 3]
# scores = [[[1.1*1.6, 1.1*0.7, -1.1*0.8], [-0.5*1.6, -0.5*0.7, 0.5*0.8]]]
@@ -364,12 +339,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
# attention_distribution011 = exp(-0.35)/(exp(-0.8) + exp(-0.35))
# = 0.61063923394
# attention_distribution012 = 0
if return_attention_scores:
expected_scores = np.array(
[[[0.72908792234, 0.27091207765, 0.],
[0.38936076605, 0.61063923394, 0.]]],
dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
#
# Expected tensor of shape [1, 2, 1] with zeros where q_mask == False.
# expected000 = 0.72908792234 * 1.6 + 0.27091207765 * 0.7 - 0 * 0.8
# = 1.3561791301
@@ -398,19 +368,11 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
sess.run(attention_layer.scale.initializer)
self.assertAllClose(1., attention_layer.scale.value())
@parameterized.named_parameters(
('', False),
('return_attention_scores', True),
)
def test_self_attention_causal(self, return_attention_scores):
def test_self_attention_causal(self):
# Query-value tensor of shape [1, 3, 1]
q = np.array([[[0.5], [0.8], [-0.3]]], dtype=np.float32)
attention_layer = dense_attention.Attention(
causal=True, return_attention_scores=return_attention_scores)
if return_attention_scores:
actual, actual_scores = attention_layer([q, q])
else:
actual = attention_layer([q, q])
attention_layer = dense_attention.Attention(causal=True)
actual = attention_layer([q, q])
# Expected scores of shape [1, 3, 3]
# scores = [[0.25, 0.4, -0.15], [0.4, 0.64, -0.24], [-0.15, -0.24, 0.09]]
@@ -423,13 +385,7 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
# = [exp(-0.15), exp(-0.24), exp(0.09)]
# / (exp(-0.15) + exp(-0.24) + exp(0.09))
# = [0.31395396638, 0.28693232061, 0.399113713]
if return_attention_scores:
expected_scores = np.array(
[[[1., 0., 0.],
[0.44028635073, 0.55971364926, 0.],
[0.31395396638, 0.28693232061, 0.399113713]]],
dtype=np.float32)
self.assertAllClose(expected_scores, actual_scores)
#
# Expected tensor of shape [1, 3, 1].
# expected000 = 0.5
# expected010 = 0.44028635073 * 0.5 + 0.55971364926 * 0.8
@@ -499,25 +455,17 @@ class AttentionTest(test.TestCase, parameterized.TestCase):
actual = attention_layer([q, v])
self.assertAllClose([[[0], [1]]], actual)
@parameterized.named_parameters(
('', False, False),
('use_scale', True, False),
('return_attention_scores', False, True),
)
def test_serialization(self, use_scale, return_attention_scores):
def test_serialization(self):
# Test serialization with use_scale
layer = dense_attention.Attention(
use_scale=use_scale, return_attention_scores=return_attention_scores)
layer = dense_attention.Attention(use_scale=True)
config = keras.layers.serialize(layer)
new_layer = keras.layers.deserialize(config)
self.assertEqual(new_layer.use_scale, use_scale)
self.assertEqual(new_layer.return_attention_scores, return_attention_scores)
self.assertEqual(new_layer.use_scale, True)
config = layer.get_config()
new_layer = dense_attention.Attention.from_config(config)
self.assertEqual(new_layer.use_scale, use_scale)
self.assertEqual(new_layer.return_attention_scores, return_attention_scores)
self.assertEqual(new_layer.use_scale, True)
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
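The serialization hunk above reduces to the following public-API round trip. This is a minimal sketch using keras.layers.Attention, the public alias for dense_attention.Attention; it only asserts the `use_scale` flag checked in test_serialization, and the same pattern should hold for any value the layer records in get_config.

from tensorflow import keras

# Round trip via the layer registry, as exercised by test_serialization above.
layer = keras.layers.Attention(use_scale=True)

config = keras.layers.serialize(layer)
restored = keras.layers.deserialize(config)
assert restored.use_scale

# Round trip via get_config / from_config.
restored = keras.layers.Attention.from_config(layer.get_config())
assert restored.use_scale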