Add training call argument for MultiHeadAttention

Remove trailing space
This commit is contained in:
Tzu-Wei Sung 2020-08-24 09:20:51 -07:00
parent ee31098e59
commit ba2198fc73
4 changed files with 26 additions and 7 deletions

View File

@ -396,7 +396,8 @@ class MultiHeadAttention(Layer):
attention_mask, axis=mask_expansion_axes)
return self._softmax(attention_scores, attention_mask)
def _compute_attention(self, query, key, value, attention_mask=None):
def _compute_attention(self, query, key, value, attention_mask=None,
training=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
@ -428,7 +429,8 @@ class MultiHeadAttention(Layer):
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_scores_dropout = self._dropout_layer(attention_scores)
attention_scores_dropout = self._dropout_layer(attention_scores,
training=training)
# `context_layer` = [B, T, N, H]
attention_output = special_math_ops.einsum(self._combine_equation,
@ -436,7 +438,7 @@ class MultiHeadAttention(Layer):
return attention_output, attention_scores
def call(self, query, value, key=None, attention_mask=None,
return_attention_scores=False):
return_attention_scores=False, training=None):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
@ -454,7 +456,7 @@ class MultiHeadAttention(Layer):
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask)
query, key, value, attention_mask, training)
attention_output = self._output_dense(attention_output)
if return_attention_scores:

View File

@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
model.predict([query, value, mask_data]),
model.predict([query, value, null_mask_data]))
def test_dropout(self):
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=2, key_dim=2, dropout=0.5)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((32, 4, 8))
to_data = 10 * np.random.random_sample((32, 2, 8))
# Output should be close when not in training mode,
# and should not be close when enabling dropout in training mode.
self.assertAllClose(test_layer(from_data, to_data, training=False),
test_layer(from_data, to_data, training=False))
self.assertNotAllClose(test_layer(from_data, to_data, training=True),
test_layer(from_data, to_data, training=False))
class SubclassAttention(multi_head_attention.MultiHeadAttention):
@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention):
query_tensor,
key_tensor,
value_tensor,
attention_mask=None):
attention_mask=None,
training=None):
return value_tensor, None

View File

@ -149,7 +149,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "compute_mask"

View File

@ -149,7 +149,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "compute_mask"