Merge pull request #42625 from WindQAQ:add-training-call-argument

PiperOrigin-RevId: 335507127
Change-Id: I106e0f2f8e2a78762309fc02ea883ad6f5c43d39
Commit 8ba1672f8c by TensorFlower Gardener, 2020-10-05 15:05:09 -07:00
4 changed files with 44 additions and 11 deletions


@@ -191,11 +191,15 @@ class MultiHeadAttention(Layer):
       value: Value `Tensor` of shape `[B, S, dim]`.
       key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
         `value` for both `key` and `value`, which is the most common case.
-      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
-        to certain positions.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
       return_attention_scores: A boolean to indicate whether the output should
         be attention output if True, or (attention_output, attention_scores) if
         False. Defaults to False.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (no dropout).
+        Defaults to either using the training mode of the parent layer/model,
+        or False (inference) if there is no parent layer.

     Returns:
       attention_output: The result of the computation, of shape [B, T, E],
@@ -396,7 +400,12 @@ class MultiHeadAttention(Layer):
             attention_mask, axis=mask_expansion_axes)
     return self._softmax(attention_scores, attention_mask)

-  def _compute_attention(self, query, key, value, attention_mask=None):
+  def _compute_attention(self,
+                         query,
+                         key,
+                         value,
+                         attention_mask=None,
+                         training=None):
     """Applies Dot-product attention with query, key, value tensors.

     This function defines the computation inside `call` with projected
@@ -409,6 +418,8 @@ class MultiHeadAttention(Layer):
       value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).

     Returns:
       attention_output: Multi-headed outputs of attention computation.
@@ -428,15 +439,21 @@ class MultiHeadAttention(Layer):

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_scores_dropout = self._dropout_layer(attention_scores)
+    attention_scores_dropout = self._dropout_layer(
+        attention_scores, training=training)

     # `context_layer` = [B, T, N, H]
     attention_output = special_math_ops.einsum(self._combine_equation,
                                                attention_scores_dropout, value)
     return attention_output, attention_scores

-  def call(self, query, value, key=None, attention_mask=None,
-           return_attention_scores=False):
+  def call(self,
+           query,
+           value,
+           key=None,
+           attention_mask=None,
+           return_attention_scores=False,
+           training=None):
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
@@ -454,10 +471,9 @@ class MultiHeadAttention(Layer):
     value = self._value_dense(value)

     attention_output, attention_scores = self._compute_attention(
-        query, key, value, attention_mask)
-
+        query, key, value, attention_mask, training)
     attention_output = self._output_dense(attention_output)

     if return_attention_scores:
       return attention_output, attention_scores
     return attention_output
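
For readers of this change, a minimal usage sketch of the new argument (not part of the commit; it assumes a TensorFlow build that includes this change, roughly TF 2.4+, where the layer is exported as tf.keras.layers.MultiHeadAttention). Passing training=True applies the configured dropout to the attention probabilities, while training=False runs the layer deterministically:

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, dropout=0.5)
query = tf.ones((32, 4, 8))   # [B, T, dim]
value = tf.ones((32, 2, 8))   # [B, S, dim]

# Dropout on the attention probabilities is active only when training=True.
train_out = mha(query, value, training=True)
infer_out = mha(query, value, training=False)

# With dropout=0.5 the two outputs generally differ.
print(tf.reduce_max(tf.abs(train_out - infer_out)).numpy())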


@@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         model.predict([query, value, mask_data]),
         model.predict([query, value, null_mask_data]))
+
+  def test_dropout(self):
+    test_layer = multi_head_attention.MultiHeadAttention(
+        num_heads=2, key_dim=2, dropout=0.5)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = keras.backend.ones(shape=(32, 4, 8))
+    to_data = keras.backend.ones(shape=(32, 2, 8))
+    train_out = test_layer(from_data, to_data, None, None, None, True)
+    test_out = test_layer(from_data, to_data, None, None, None, False)
+
+    # Output should be close when not in training mode,
+    # and should not be close when enabling dropout in training mode.
+    self.assertNotAllClose(
+        keras.backend.eval(train_out),
+        keras.backend.eval(test_out))


 class SubclassAttention(multi_head_attention.MultiHeadAttention):
@@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention):
                          query_tensor,
                          key_tensor,
                          value_tensor,
-                         attention_mask=None):
+                         attention_mask=None,
+                         training=None):
     return value_tensor, None
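
The docstring above notes that training defaults to the training mode of the parent layer/model. A hedged sketch of what that means in practice (illustrative model, not from the commit): when the layer is used inside a Keras model and training is left as None, model.predict resolves it to inference mode while model.fit resolves it to training mode, so the dropout configured here is only active during fitting.

import tensorflow as tf

# Illustrative functional model; the point is only how training mode propagates.
query_in = tf.keras.Input(shape=(4, 8))
value_in = tf.keras.Input(shape=(2, 8))
out = tf.keras.layers.MultiHeadAttention(
    num_heads=2, key_dim=2, dropout=0.5)(query_in, value_in)
model = tf.keras.Model([query_in, value_in], out)
model.compile(optimizer="adam", loss="mse")

q = tf.ones((32, 4, 8))
v = tf.ones((32, 2, 8))

# predict() resolves training to False: no dropout, so repeated calls agree.
p1 = model.predict([q, v])
p2 = model.predict([q, v])
# fit() resolves training to True: dropout is applied in the forward pass.
model.fit([q, v], tf.ones((32, 4, 8)), epochs=1, verbose=0)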


@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"


@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"