diff --git a/tensorflow/python/keras/layers/multi_head_attention.py b/tensorflow/python/keras/layers/multi_head_attention.py
index 7ddce8caceb..bda0056fe7e 100644
--- a/tensorflow/python/keras/layers/multi_head_attention.py
+++ b/tensorflow/python/keras/layers/multi_head_attention.py
@@ -191,11 +191,15 @@ class MultiHeadAttention(Layer):
     value: Value `Tensor` of shape `[B, S, dim]`.
     key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
       `value` for both `key` and `value`, which is the most common case.
-    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
-      to certain positions.
+    attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
+      attention to certain positions.
     return_attention_scores: A boolean to indicate whether the output should
       be attention output if True, or (attention_output, attention_scores) if
       False. Defaults to False.
+    training: Python boolean indicating whether the layer should behave in
+      training mode (adding dropout) or in inference mode (no dropout).
+      Defaults to either using the training mode of the parent layer/model,
+      or False (inference) if there is no parent layer.

   Returns:
     attention_output: The result of the computation, of shape [B, T, E],
@@ -396,7 +400,12 @@ class MultiHeadAttention(Layer):
             attention_mask, axis=mask_expansion_axes)
     return self._softmax(attention_scores, attention_mask)

-  def _compute_attention(self, query, key, value, attention_mask=None):
+  def _compute_attention(self,
+                         query,
+                         key,
+                         value,
+                         attention_mask=None,
+                         training=None):
     """Applies Dot-product attention with query, key, value tensors.

     This function defines the computation inside `call` with projected
@@ -409,6 +418,8 @@ class MultiHeadAttention(Layer):
       value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
       attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
         attention to certain positions.
+      training: Python boolean indicating whether the layer should behave in
+        training mode (adding dropout) or in inference mode (doing nothing).

     Returns:
       attention_output: Multi-headed outputs of attention computation.
@@ -428,15 +439,21 @@ class MultiHeadAttention(Layer):

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
-    attention_scores_dropout = self._dropout_layer(attention_scores)
+    attention_scores_dropout = self._dropout_layer(
+        attention_scores, training=training)

     # `context_layer` = [B, T, N, H]
     attention_output = special_math_ops.einsum(self._combine_equation,
                                                attention_scores_dropout, value)
     return attention_output, attention_scores

-  def call(self, query, value, key=None, attention_mask=None,
-           return_attention_scores=False):
+  def call(self,
+           query,
+           value,
+           key=None,
+           attention_mask=None,
+           return_attention_scores=False,
+           training=None):
     if not self._built_from_signature:
       self._build_from_signature(query=query, value=value, key=key)
     if key is None:
@@ -454,10 +471,9 @@ class MultiHeadAttention(Layer):
     value = self._value_dense(value)

     attention_output, attention_scores = self._compute_attention(
-        query, key, value, attention_mask)
+        query, key, value, attention_mask, training)
     attention_output = self._output_dense(attention_output)

     if return_attention_scores:
       return attention_output, attention_scores
     return attention_output
-
diff --git a/tensorflow/python/keras/layers/multi_head_attention_test.py b/tensorflow/python/keras/layers/multi_head_attention_test.py
index a50fefd05ba..4c957b8973b 100644
--- a/tensorflow/python/keras/layers/multi_head_attention_test.py
+++ b/tensorflow/python/keras/layers/multi_head_attention_test.py
@@ -225,6 +225,22 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
         model.predict([query, value, mask_data]),
         model.predict([query, value, null_mask_data]))

+  def test_dropout(self):
+    test_layer = multi_head_attention.MultiHeadAttention(
+        num_heads=2, key_dim=2, dropout=0.5)
+
+    # Generate data for the input (non-mask) tensors.
+    from_data = keras.backend.ones(shape=(32, 4, 8))
+    to_data = keras.backend.ones(shape=(32, 2, 8))
+    train_out = test_layer(from_data, to_data, None, None, None, True)
+    test_out = test_layer(from_data, to_data, None, None, None, False)
+
+    # Output should be close when not in training mode,
+    # and should not be close when enabling dropout in training mode.
+    self.assertNotAllClose(
+        keras.backend.eval(train_out),
+        keras.backend.eval(test_out))
+

 class SubclassAttention(multi_head_attention.MultiHeadAttention):

@@ -235,7 +251,8 @@ class SubclassAttention(multi_head_attention.MultiHeadAttention):
                           query_tensor,
                           key_tensor,
                           value_tensor,
-                          attention_mask=None):
+                          attention_mask=None,
+                          training=None):
     return value_tensor, None

diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt
index 070ee20ab30..89fbb32194a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt
@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt
index 070ee20ab30..89fbb32194a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt
@@ -149,7 +149,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
+    argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
  }
  member_method {
    name: "compute_mask"
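
For reference, a minimal usage sketch of the new `training` argument on the public `tf.keras.layers.MultiHeadAttention` layer. This is not part of the patch; it assumes a TensorFlow build that includes this change, and the shapes and hyperparameters simply mirror the new test_dropout test for illustration.

import tensorflow as tf

# Same configuration as test_dropout: 2 heads, key_dim=2, 50% attention dropout.
mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2, dropout=0.5)
query = tf.ones((32, 4, 8))
value = tf.ones((32, 2, 8))

# training=True activates the dropout applied to the attention probabilities,
# so the result generally differs from an inference-mode call.
train_out = mha(query, value, training=True)

# training=False (or leaving Keras to infer inference mode) disables dropout,
# making the call deterministic.
infer_out = mha(query, value, training=False)

Passing `training` through `call` lets the parent layer/model propagate its own training flag, which is why the default is `None` rather than `False`.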