Add training call argument for MultiHeadAttention
Remove trailing space
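With this change, `MultiHeadAttention.call` accepts a `training` argument and forwards it to the internal dropout layer, so dropout on the attention scores is active only in training mode. A minimal usage sketch (assumes a TF build that includes this commit; the shapes and 0.5 rate are illustrative):

    import numpy as np
    import tensorflow as tf

    layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2,
                                               dropout=0.5)
    query = np.random.random_sample((32, 4, 8)).astype("float32")
    value = np.random.random_sample((32, 2, 8)).astype("float32")

    # training=False disables dropout, so repeated calls match exactly.
    out_a = layer(query, value, training=False)
    out_b = layer(query, value, training=False)
    # training=True enables dropout, so this output differs from out_a.
    out_c = layer(query, value, training=True)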
parent ee31098e59
commit ba2198fc73
@@ -396,7 +396,8 @@ class MultiHeadAttention(Layer):
             attention_mask, axis=mask_expansion_axes)
     return self._softmax(attention_scores, attention_mask)
 
-  def _compute_attention(self, query, key, value, attention_mask=None):
+  def _compute_attention(self, query, key, value, attention_mask=None,
+                         training=None):
     """Applies Dot-product attention with query, key, value tensors.
 
     This function defines the computation inside `call` with projected
@@ -428,7 +429,8 @@ class MultiHeadAttention(Layer):
 
     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_scores_dropout = self._dropout_layer(attention_scores)
+    attention_scores_dropout = self._dropout_layer(attention_scores,
+                                                   training=training)
 
     # `context_layer` = [B, T, N, H]
     attention_output = special_math_ops.einsum(self._combine_equation,
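The comment in the hunk above refers to dropout on the post-softmax attention scores: zeroing a score removes that token from the attention average entirely for that step. A standalone sketch of the effect (shapes and rate are illustrative, not taken from this diff):

    import tensorflow as tf

    # `scores` = [batch, heads, query_len, key_len]; rows sum to 1.
    scores = tf.nn.softmax(tf.random.uniform((2, 2, 4, 4)), axis=-1)
    dropout = tf.keras.layers.Dropout(0.5)
    # With training=True, roughly half the attention weights are zeroed
    # (and survivors rescaled by 1/(1 - 0.5)), i.e. whole tokens are
    # dropped from each query's attention.
    dropped = dropout(scores, training=True)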
@@ -436,7 +438,7 @@ class MultiHeadAttention(Layer):
     return attention_output, attention_scores
 
   def call(self, query, value, key=None, attention_mask=None,
-           return_attention_scores=False):
+           return_attention_scores=False, training=None):
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
@@ -454,7 +456,7 @@ class MultiHeadAttention(Layer):
     value = self._value_dense(value)
 
     attention_output, attention_scores = self._compute_attention(
-        query, key, value, attention_mask)
+        query, key, value, attention_mask, training)
     attention_output = self._output_dense(attention_output)
 
     if return_attention_scores:
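Leaving `training=None` lets Keras resolve the flag from context: `fit` runs layers with training=True, while `predict` and `evaluate` use training=False, so the flag usually does not need to be passed by hand. A sketch of the functional-API wiring (model shape is illustrative):

    import tensorflow as tf

    inputs = tf.keras.Input(shape=(4, 8))
    outputs = tf.keras.layers.MultiHeadAttention(
        num_heads=2, key_dim=2, dropout=0.5)(inputs, inputs)
    model = tf.keras.Model(inputs, outputs)
    # model.fit(...) will invoke the layer with training=True (dropout on);
    # model.predict(...) will invoke it with training=False (dropout off).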
@@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         model.predict([query, value, mask_data]),
         model.predict([query, value, null_mask_data]))
 
+  def test_dropout(self):
+    test_layer = multi_head_attention.MultiHeadAttention(
+        num_heads=2, key_dim=2, dropout=0.5)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = 10 * np.random.random_sample((32, 4, 8))
+    to_data = 10 * np.random.random_sample((32, 2, 8))
+
+    # Output should be close when not in training mode,
+    # and should not be close when enabling dropout in training mode.
+    self.assertAllClose(test_layer(from_data, to_data, training=False),
+                        test_layer(from_data, to_data, training=False))
+
+    self.assertNotAllClose(test_layer(from_data, to_data, training=True),
+                           test_layer(from_data, to_data, training=False))
+
 
 class SubclassAttention(multi_head_attention.MultiHeadAttention):
 
@@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention):
                          query_tensor,
                          key_tensor,
                          value_tensor,
-                         attention_mask=None):
+                         attention_mask=None,
+                         training=None):
     return value_tensor, None
 
 
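Since `_compute_attention` gained a parameter, subclasses that override it (like the SubclassAttention test stub above) must accept `training` too. A sketch of an override that forwards the flag to the base implementation (hypothetical subclass name, using the same `multi_head_attention` import as the test file):

    class ForwardingAttention(multi_head_attention.MultiHeadAttention):

      def _compute_attention(self, query, key, value, attention_mask=None,
                             training=None):
        # Forward `training` so the base implementation's dropout layer
        # respects the caller's mode.
        return super()._compute_attention(
            query, key, value, attention_mask=attention_mask,
            training=training)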
@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"