commit 8ba1672f8c

Merge pull request #42625 from WindQAQ:add-training-call-argument

PiperOrigin-RevId: 335507127
Change-Id: I106e0f2f8e2a78762309fc02ea883ad6f5c43d39
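The change threads a `training` argument through `MultiHeadAttention.call` into `_compute_attention`, so dropout on the attention scores is applied only in training mode. A minimal usage sketch, assuming the layer is exposed as `tf.keras.layers.MultiHeadAttention` (the shapes and dropout rate below are illustrative, not taken from this commit):

    import tensorflow as tf

    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, dropout=0.5)
    query = tf.ones((32, 4, 8))
    value = tf.ones((32, 2, 8))

    # Dropout on the attention scores is active only when training=True;
    # with training=False the layer behaves deterministically.
    train_out = mha(query, value, training=True)
    infer_out = mha(query, value, training=False)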
@@ -191,11 +191,15 @@ class MultiHeadAttention(Layer):
       value: Value `Tensor` of shape `[B, S, dim]`.
       key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
         `value` for both `key` and `value`, which is the most common case.
-      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
-        to certain positions.
+      attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+        attention to certain positions.
       return_attention_scores: A boolean to indicate whether the output should
         be attention output if True, or (attention_output, attention_scores) if
         False. Defaults to False.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (no dropout).
+        Defaults to either using the training mode of the parent layer/model,
+        or False (inference) if there is no parent layer.

     Returns:
       attention_output: The result of the computation, of shape [B, T, E],
@@ -396,7 +400,12 @@ class MultiHeadAttention(Layer):
           attention_mask, axis=mask_expansion_axes)
     return self._softmax(attention_scores, attention_mask)

-  def _compute_attention(self, query, key, value, attention_mask=None):
+  def _compute_attention(self,
+                         query,
+                         key,
+                         value,
+                         attention_mask=None,
+                         training=None):
     """Applies Dot-product attention with query, key, value tensors.

     This function defines the computation inside `call` with projected
@@ -409,6 +418,8 @@ class MultiHeadAttention(Layer):
       value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).

     Returns:
       attention_output: Multi-headed outputs of attention computation.
@@ -428,15 +439,21 @@ class MultiHeadAttention(Layer):

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_scores_dropout = self._dropout_layer(attention_scores)
+    attention_scores_dropout = self._dropout_layer(
+        attention_scores, training=training)

     # `context_layer` = [B, T, N, H]
     attention_output = special_math_ops.einsum(self._combine_equation,
                                                attention_scores_dropout, value)
     return attention_output, attention_scores

-  def call(self, query, value, key=None, attention_mask=None,
-           return_attention_scores=False):
+  def call(self,
+           query,
+           value,
+           key=None,
+           attention_mask=None,
+           return_attention_scores=False,
+           training=None):
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
@@ -454,10 +471,9 @@ class MultiHeadAttention(Layer):
     value = self._value_dense(value)

     attention_output, attention_scores = self._compute_attention(
-        query, key, value, attention_mask)
+        query, key, value, attention_mask, training)
     attention_output = self._output_dense(attention_output)

     if return_attention_scores:
       return attention_output, attention_scores
     return attention_output

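Taken together, the hunks above wire the flag end to end: `call` accepts `training=None`, hands it to `_compute_attention`, which forwards it to the internal dropout layer; when left as `None`, Keras resolves it from the surrounding context (e.g. fit vs. predict), per the docstring added above. A stripped-down sketch of that pattern for a custom layer (not the actual TF source; `MyAttention` and its methods are illustrative):

    import tensorflow as tf

    class MyAttention(tf.keras.layers.Layer):
      """Illustrative layer that forwards `training` to an inner Dropout."""

      def __init__(self, rate=0.5, **kwargs):
        super().__init__(**kwargs)
        self._dropout_layer = tf.keras.layers.Dropout(rate)

      def _compute_attention(self, scores, training=None):
        # The flag decides whether dropout is applied to the scores.
        return self._dropout_layer(scores, training=training)

      def call(self, scores, training=None):
        # If training is None, Keras infers it from the parent model's mode.
        return self._compute_attention(scores, training=training)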
@@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         model.predict([query, value, mask_data]),
         model.predict([query, value, null_mask_data]))

+  def test_dropout(self):
+    test_layer = multi_head_attention.MultiHeadAttention(
+        num_heads=2, key_dim=2, dropout=0.5)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = keras.backend.ones(shape=(32, 4, 8))
+    to_data = keras.backend.ones(shape=(32, 2, 8))
+    train_out = test_layer(from_data, to_data, None, None, None, True)
+    test_out = test_layer(from_data, to_data, None, None, None, False)
+
+    # Output should be close when not in training mode,
+    # and should not be close when enabling dropout in training mode.
+    self.assertNotAllClose(
+        keras.backend.eval(train_out),
+        keras.backend.eval(test_out))
+

 class SubclassAttention(multi_head_attention.MultiHeadAttention):

@@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention):
                          query_tensor,
                          key_tensor,
                          value_tensor,
-                         attention_mask=None):
+                         attention_mask=None,
+                         training=None):
     return value_tensor, None

@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"