Merge pull request #42625 from WindQAQ:add-training-call-argument
PiperOrigin-RevId: 335507127 Change-Id: I106e0f2f8e2a78762309fc02ea883ad6f5c43d39
This commit is contained in:
commit
8ba1672f8c
@ -191,11 +191,15 @@ class MultiHeadAttention(Layer):
|
||||
value: Value `Tensor` of shape `[B, S, dim]`.
|
||||
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
|
||||
`value` for both `key` and `value`, which is the most common case.
|
||||
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
|
||||
to certain positions.
|
||||
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
|
||||
attention to certain positions.
|
||||
return_attention_scores: A boolean to indicate whether the output should
|
||||
be attention output if True, or (attention_output, attention_scores) if
|
||||
False. Defaults to False.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
training mode (adding dropout) or in inference mode (no dropout).
|
||||
Defaults to either using the training mode of the parent layer/model,
|
||||
or False (inference) if there is no parent layer.
|
||||
|
||||
Returns:
|
||||
attention_output: The result of the computation, of shape [B, T, E],
|
||||
@ -396,7 +400,12 @@ class MultiHeadAttention(Layer):
|
||||
attention_mask, axis=mask_expansion_axes)
|
||||
return self._softmax(attention_scores, attention_mask)
|
||||
|
||||
def _compute_attention(self, query, key, value, attention_mask=None):
|
||||
def _compute_attention(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask=None,
|
||||
training=None):
|
||||
"""Applies Dot-product attention with query, key, value tensors.
|
||||
|
||||
This function defines the computation inside `call` with projected
|
||||
@ -409,6 +418,8 @@ class MultiHeadAttention(Layer):
|
||||
value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
|
||||
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
|
||||
attention to certain positions.
|
||||
training: Python boolean indicating whether the layer should behave in
|
||||
training mode (adding dropout) or in inference mode (doing nothing).
|
||||
|
||||
Returns:
|
||||
attention_output: Multi-headed outputs of attention computation.
|
||||
@ -428,15 +439,21 @@ class MultiHeadAttention(Layer):
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_scores_dropout = self._dropout_layer(attention_scores)
|
||||
attention_scores_dropout = self._dropout_layer(
|
||||
attention_scores, training=training)
|
||||
|
||||
# `context_layer` = [B, T, N, H]
|
||||
attention_output = special_math_ops.einsum(self._combine_equation,
|
||||
attention_scores_dropout, value)
|
||||
return attention_output, attention_scores
|
||||
|
||||
def call(self, query, value, key=None, attention_mask=None,
|
||||
return_attention_scores=False):
|
||||
def call(self,
|
||||
query,
|
||||
value,
|
||||
key=None,
|
||||
attention_mask=None,
|
||||
return_attention_scores=False,
|
||||
training=None):
|
||||
if not self._built_from_signature:
|
||||
self._build_from_signature(query=query, value=value, key=key)
|
||||
if key is None:
|
||||
@ -454,10 +471,9 @@ class MultiHeadAttention(Layer):
|
||||
value = self._value_dense(value)
|
||||
|
||||
attention_output, attention_scores = self._compute_attention(
|
||||
query, key, value, attention_mask)
|
||||
query, key, value, attention_mask, training)
|
||||
attention_output = self._output_dense(attention_output)
|
||||
|
||||
if return_attention_scores:
|
||||
return attention_output, attention_scores
|
||||
return attention_output
|
||||
|
||||
|
@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
|
||||
model.predict([query, value, mask_data]),
|
||||
model.predict([query, value, null_mask_data]))
|
||||
|
||||
def test_dropout(self):
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=2, key_dim=2, dropout=0.5)
|
||||
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
from_data = keras.backend.ones(shape=(32, 4, 8))
|
||||
to_data = keras.backend.ones(shape=(32, 2, 8))
|
||||
train_out = test_layer(from_data, to_data, None, None, None, True)
|
||||
test_out = test_layer(from_data, to_data, None, None, None, False)
|
||||
|
||||
# Output should be close when not in training mode,
|
||||
# and should not be close when enabling dropout in training mode.
|
||||
self.assertNotAllClose(
|
||||
keras.backend.eval(train_out),
|
||||
keras.backend.eval(test_out))
|
||||
|
||||
|
||||
class SubclassAttention(multi_head_attention.MultiHeadAttention):
|
||||
|
||||
@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention):
|
||||
query_tensor,
|
||||
key_tensor,
|
||||
value_tensor,
|
||||
attention_mask=None):
|
||||
attention_mask=None,
|
||||
training=None):
|
||||
return value_tensor, None
|
||||
|
||||
|
||||
|
@ -149,7 +149,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
@ -149,7 +149,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
|
||||
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
Loading…
Reference in New Issue
Block a user