From 5e90f547a295efeaff92a8c1ef2b4da568485124 Mon Sep 17 00:00:00 2001
From: Hongkun Yu
Date: Mon, 11 Jan 2021 16:17:52 -0800
Subject: [PATCH] Update keras MultiHeadAttention attention mask docstring to state 1 is to attend and 0 is to mask.

https://github.com/tensorflow/tensorflow/issues/45854

PiperOrigin-RevId: 351258459
Change-Id: I1b6631cc7297d8a8077e754f45d701a03dccfe3e
---
 tensorflow/python/keras/layers/advanced_activations.py | 5 +++--
 tensorflow/python/keras/layers/multi_head_attention.py | 5 ++++-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py
index f588dad1574..6a0ae73b9b7 100644
--- a/tensorflow/python/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/layers/advanced_activations.py
@@ -308,7 +308,8 @@ class Softmax(Layer):
       normalization is applied.
   Call arguments:
     inputs: The inputs, or logits to the softmax layer.
-    mask: A boolean mask of the same shape as `inputs`. Defaults to `None`.
+    mask: A boolean mask of the same shape as `inputs`. Defaults to `None`. The
+      mask specifies 1 to keep and 0 to mask.
 
   Returns:
     softmaxed output with the same shape as `inputs`.
@@ -321,7 +322,7 @@ class Softmax(Layer):
 
   def call(self, inputs, mask=None):
     if mask is not None:
-      # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+      # Since mask is 1.0 for positions we want to keep and 0.0 for
       # masked positions, this operation will create a tensor which is 0.0 for
       # positions we want to attend and -1e.9 for masked positions.
       adder = (1.0 - math_ops.cast(mask, inputs.dtype)) * (
diff --git a/tensorflow/python/keras/layers/multi_head_attention.py b/tensorflow/python/keras/layers/multi_head_attention.py
index 3f7ff856bc0..d57ce570eb1 100644
--- a/tensorflow/python/keras/layers/multi_head_attention.py
+++ b/tensorflow/python/keras/layers/multi_head_attention.py
@@ -193,7 +193,10 @@ class MultiHeadAttention(Layer):
     key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
       `value` for both `key` and `value`, which is the most common case.
     attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
-      attention to certain positions.
+      attention to certain positions. The boolean mask specifies which query
+      elements can attend to which key elements, 1 indicates attention and 0
+      indicates no attention. Broadcasting can happen for the missing batch
+      dimensions and the head dimension.
     return_attention_scores: A boolean to indicate whether the output should
       be attention output if True, or (attention_output, attention_scores) if
       False. Defaults to False.
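
Illustrative sketch (not part of the patch): a minimal eager-mode example of the mask
semantics the docstrings above now spell out, i.e. 1/True means "attend"/"keep" and
0/False means "mask out". The shapes, values, and variable names are arbitrary and
chosen only for demonstration.

    import numpy as np
    import tensorflow as tf

    # Softmax layer: positions masked with False receive ~zero probability.
    logits = np.array([1., 2., 1.], dtype=np.float32)
    mask = np.array([True, False, True])        # True = keep, False = mask
    print(tf.keras.layers.Softmax()(logits, mask).numpy())  # ~[0.5, 0. , 0.5]

    # MultiHeadAttention: attention_mask is [B, T, S]; 1 = attend, 0 = no attention.
    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
    query = tf.random.normal([1, 3, 8])         # [B, T, dim]
    value = tf.random.normal([1, 5, 8])         # [B, S, dim]
    attention_mask = tf.constant([[[1, 1, 0, 0, 0],
                                   [1, 1, 1, 0, 0],
                                   [1, 1, 1, 1, 1]]])  # [B, T, S]
    output, scores = mha(query, value,
                         attention_mask=attention_mask,
                         return_attention_scores=True)
    print(scores.shape)  # (1, 2, 3, 5); masked key positions get ~0 attention weight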