Add MultiHeadAttention Layer for Keras.

PiperOrigin-RevId: 326359296
Change-Id: Iacdc310f66aa1848b068fa3f0fc8784ea7b80ef5
This commit is contained in:
Zhenyu Tan 2020-08-12 18:44:28 -07:00 committed by TensorFlower Gardener
parent 2dc8f029bf
commit f32c80b3ed
12 changed files with 1232 additions and 5 deletions

View File

@ -39,6 +39,7 @@ py_library(
":kernelized",
":local",
":merge",
":multi_head_attention",
":noise",
":normalization",
":normalization_v2",
@ -207,6 +208,22 @@ py_library(
],
)
py_library(
name = "multi_head_attention",
srcs = ["multi_head_attention.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:special_math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/keras:activations",
"//tensorflow/python/keras:base_layer",
"//tensorflow/python/keras:constraints",
"//tensorflow/python/keras:initializers",
"//tensorflow/python/keras:regularizers",
],
)
py_library(
name = "embeddings",
srcs = ["embeddings.py"],
@ -590,6 +607,18 @@ tf_py_test(
],
)
tf_py_test(
name = "multi_head_attention_test",
srcs = ["multi_head_attention_test.py"],
python_version = "PY3",
deps = [
":multi_head_attention",
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "embeddings_test",
size = "medium",

View File

@ -143,6 +143,9 @@ from tensorflow.python.keras.layers.embeddings import Embedding
# Einsum-based dense layer/
from tensorflow.python.keras.layers.einsum_dense import EinsumDense
# Multi-head Attention layer.
from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention
# Locally-connected layers.
from tensorflow.python.keras.layers.local import LocallyConnected1D
from tensorflow.python.keras.layers.local import LocallyConnected2D

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
@ -259,10 +260,37 @@ class ThresholdedReLU(Layer):
return input_shape
def _large_compatible_negative(tensor_type):
"""Large negative number as Tensor.
This function is necessary because the standard value for epsilon
in this module (-1e9) cannot be represented using tf.float16
Args:
tensor_type: a dtype to determine the type.
Returns:
a large negative number.
"""
if tensor_type == dtypes.float16:
return dtypes.float16.min
return -1e9
@keras_export('keras.layers.Softmax')
class Softmax(Layer):
"""Softmax activation function.
Example without mask:
>>> inp = np.asarray([1., 2., 1.])
>>> layer = tf.keras.layers.Softmax()
>>> layer(inp).numpy()
array([0.21194157, 0.5761169 , 0.21194157], dtype=float32)
>>> mask = np.asarray([True, False, True], dtype=bool)
>>> layer(inp, mask).numpy()
array([0.5, 0. , 0.5], dtype=float32)
Input shape:
Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
@ -272,7 +300,14 @@ class Softmax(Layer):
Same shape as the input.
Arguments:
axis: Integer, axis along which the softmax normalization is applied.
axis: Integer, or list of Integers, axis along which the softmax
normalization is applied.
Call arguments:
inputs: The inputs, or logits to the softmax layer.
mask: A boolean mask of the same shape as `inputs`. Defaults to `None`.
Returns:
softmaxed output with the same shape as `inputs`.
"""
def __init__(self, axis=-1, **kwargs):
@ -280,7 +315,23 @@ class Softmax(Layer):
self.supports_masking = True
self.axis = axis
def call(self, inputs):
def call(self, inputs, mask=None):
if mask is not None:
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -1e.9 for masked positions.
adder = (1.0 - math_ops.cast(mask, inputs.dtype)) * (
_large_compatible_negative(inputs.dtype))
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
inputs += adder
if isinstance(self.axis, (tuple, list)):
if len(self.axis) > 1:
return math_ops.exp(inputs - math_ops.reduce_logsumexp(
inputs, axis=self.axis, keepdims=True))
else:
return K.softmax(inputs, axis=self.axis[0])
return K.softmax(inputs, axis=self.axis)
def get_config(self):

View File

@ -0,0 +1,460 @@
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based attention layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import string
import numpy as np
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.layers import advanced_activations
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import einsum_dense
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util.tf_export import keras_export
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, value inputs after projection are expected to have the shape as:
(bs, <non-attention dims>, <attention dims>, num_heads, channels).
bs and <non-attention dims> are treated as <batch dims>.
The attention operations can be generalized:
(1) Query-key dot product:
(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
<key attention dims>, num_heads, channels) -> (<batch dims>,
num_heads, <query attention dims>, <key attention dims>)
(2) Combination:
(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
<query attention dims>, num_heads, channels)
Args:
rank: the rank of query, key, value tensors.
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
Returns:
Einsum equations.
"""
target_notation = _CHR_IDX[:rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
letter_offset = rank
source_notation = ""
for i in range(rank):
if i in batch_dims or i == rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
product_notation)
attn_scores_rank = len(product_notation)
combine_equation = "%s,%s->%s" % (product_notation, source_notation,
target_notation)
return dot_product_equation, combine_equation, attn_scores_rank
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
return equation, bias_axes, len(output_str)
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
@keras_export("keras.layers.MultiHeadAttention")
class MultiHeadAttention(Layer):
"""MultiHeadAttention layer.
This is an implementation of multi-headed attention based on "Attention
is all you Need". If `query`, `key,` `value` are the same, then
this is self-attention. Each timestep in `query` attends to the
corresponding sequence in `key`, and returns a fixed-width vector.
This layer first projects `query`, `key` and `value`. These are
(effectively) a list of tensors of length `num_attention_heads`, where the
corresponding shapes are [batch_size, <query dimensions>, key_dim],
[batch_size, <key/value dimensions>, key_dim],
[batch_size, <key/value dimensions>, value_dim].
Then, the query and key tensors are dot-producted and scaled. These are
softmaxed to obtain attention probabilities. The value tensors are then
interpolated by these probabilities, then concatenated back to a single
tensor.
Finally, the result tensor with the last dimension as value_dim can take an
linear projection and return.
Examples:
Performs 1D cross-attention over two sequence inputs with an attention mask.
Returns the additional attention weights over heads.
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
>>> target = tf.keras.Input(shape=[8, 16])
>>> source = tf.keras.Input(shape=[4, 16])
>>> output_tensor, weights = layer(target, source,
... return_attention_scores=True)
>>> print(output_tensor.shape)
(None, 8, 16)
>>> print(weights.shape)
(None, 2, 8, 4)
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
>>> output_tensor = layer(input_tensor, input_tensor)
>>> print(output_tensor.shape)
(None, 5, 3, 4, 16)
Arguments:
num_heads: Number of attention heads.
key_dim: Size of each attention head for query and key.
value_dim: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
Call arguments:
query: Query `Tensor` of shape `[B, T, dim]`.
value: Value `Tensor` of shape `[B, S, dim]`.
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
`value` for both `key` and `value`, which is the most common case.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
to certain positions.
return_attention_scores: A boolean to indicate whether the output should
be attention output if True, or (attention_output, attention_scores) if
False. Defaults to False.
Returns:
attention_output: The result of the computation, of shape [B, T, E],
where `T` is for target sequence shapes and `E` is the query input last
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
are project to the shape specified by `output_shape`.
attention_scores: [Optional] multi-head attention coeffients over
attention axes.
"""
def __init__(self,
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self._num_heads = num_heads
self._key_dim = key_dim
self._value_dim = value_dim if value_dim else key_dim
self._dropout = dropout
self._use_bias = use_bias
self._output_shape = output_shape
self._kernel_initializer = initializers.get(kernel_initializer)
self._bias_initializer = initializers.get(bias_initializer)
self._kernel_regularizer = regularizers.get(kernel_regularizer)
self._bias_regularizer = regularizers.get(bias_regularizer)
self._kernel_constraint = constraints.get(kernel_constraint)
self._bias_constraint = constraints.get(bias_constraint)
if attention_axes is not None and not isinstance(attention_axes,
collections.abc.Sized):
self._attention_axes = (attention_axes,)
else:
self._attention_axes = attention_axes
self._built_from_signature = False
def get_config(self):
config = {
"num_heads":
self._num_heads,
"key_dim":
self._key_dim,
"value_dim":
self._value_dim,
"dropout":
self._dropout,
"use_bias":
self._use_bias,
"output_shape":
self._output_shape,
"attention_axes":
self._attention_axes,
"kernel_initializer":
initializers.serialize(self._kernel_initializer),
"bias_initializer":
initializers.serialize(self._bias_initializer),
"kernel_regularizer":
regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
constraints.serialize(self._kernel_constraint),
"bias_constraint":
constraints.serialize(self._bias_constraint)
}
base_config = super(MultiHeadAttention, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def _build_from_signature(self, query, value, key=None):
"""Builds layers and variables.
Once the method is called, self._built_from_signature will be set to True.
Args:
query: query tensor or TensorShape.
value: value tensor or TensorShape.
key: key tensor or TensorShape.
"""
self._built_from_signature = True
if hasattr(query, "shape"):
query_shape = tensor_shape.TensorShape(query.shape)
else:
query_shape = query
if hasattr(value, "shape"):
value_shape = tensor_shape.TensorShape(value.shape)
else:
value_shape = value
if key is None:
key_shape = value_shape
elif hasattr(key, "shape"):
key_shape = tensor_shape.TensorShape(key.shape)
else:
key_shape = key
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
# Any setup work performed only once should happen in an `init_scope`
# to avoid creating symbolic Tensors that will later pollute any eager
# operations.
with tf_utils.maybe_init_scope(self):
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = einsum_dense.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = einsum_dense.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = einsum_dense.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1,
[self._num_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
# Builds the attention computations for multi-head dot product attention.
# These computations could be wrapped into the keras attention layer once
# it support mult-head einsum computations.
self._build_attention(output_rank)
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else:
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
self._output_dense = einsum_dense.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
def _build_attention(self, rank):
"""Builds multi-head dot-product attention computations.
This function builds attributes necessary for `_compute_attention` to
costomize attention computation to replace the default dot-product
attention.
Args:
rank: the rank of query, key, value tensors.
"""
if self._attention_axes is None:
self._attention_axes = tuple(range(1, rank - 2))
else:
self._attention_axes = tuple(self._attention_axes)
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
_build_attention_equation(rank, attn_axes=self._attention_axes))
norm_axes = tuple(
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
self._masked_softmax = advanced_activations.Softmax(axis=norm_axes)
self._dropout_layer = core.Dropout(rate=self._dropout)
def _compute_attention(self, query, key, value, attention_mask=None):
"""Applies Dot-product attention with query, key, value tensors.
This function defines the computation inside `call` with projected
multi-head Q, K, V inputs. Users can override this function for customized
attention implementation.
Args:
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
key: Projected key `Tensor` of shape `[B, T, N, key_dim]`.
value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
attention to certain positions.
Returns:
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = math_ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = special_math_ops.einsum(self._dot_product_equation, key,
query)
# Normalize the attention scores to probabilities.
# `attention_scores` = [B, N, T, S]
if attention_mask is not None:
# The expand dim happens starting from the `num_heads` dimension,
# (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
attention_mask = array_ops.expand_dims(
attention_mask, axis=mask_expansion_axes)
attention_scores = self._masked_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_scores_dropout = self._dropout_layer(attention_scores)
# `context_layer` = [B, T, N, H]
attention_output = special_math_ops.einsum(self._combine_equation,
attention_scores_dropout, value)
return attention_output, attention_scores
def call(self, query, value, key=None, attention_mask=None,
return_attention_scores=False):
if not self._built_from_signature:
self._build_from_signature(query=query, value=value, key=key)
if key is None:
key = value
# N = `num_attention_heads`
# H = `size_per_head`
# `query` = [B, T, N ,H]
query = self._query_dense(query)
# `key` = [B, S, N, H]
key = self._key_dense(key)
# `value` = [B, S, N, H]
value = self._value_dense(value)
attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask)
attention_output = self._output_dense(attention_output)
if return_attention_scores:
return attention_output, attention_scores
return attention_output

View File

@ -0,0 +1,230 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the attention layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.layers import multi_head_attention
from tensorflow.python.platform import test
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
# guarantees forward compatibility of this code for the V2 switchover.
@keras_parameterized.run_all_keras_modes
class MultiHeadAttentionTest(keras_parameterized.TestCase):
@parameterized.named_parameters(
("key_value_same_proj", None, None, [40, 80]),
("key_value_different_proj", 32, 60, [40, 60]),
)
def test_non_masked_attention(self, value_dim, output_shape, output_dims):
"""Test that the attention layer can be created without a mask tensor."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=12,
key_dim=64,
value_dim=value_dim,
output_shape=output_shape)
# Create a 3-dimensional input (the first dimension is implicit).
query = keras.Input(shape=(40, 80))
value = keras.Input(shape=(20, 80))
output = test_layer(query=query, value=value)
self.assertEqual(output.shape.as_list(), [None] + output_dims)
def test_non_masked_self_attention(self):
"""Test with one input (self-attenntion) and no mask tensor."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_attention_scores(self):
"""Test attention outputs with coefficients."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = keras.Input(shape=(40, 80))
output, coef = test_layer(query, query, return_attention_scores=True)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
def test_attention_scores_with_values(self):
"""Test attention outputs with coefficients."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=12, key_dim=64)
# Create a 3-dimensional input (the first dimension is implicit).
query = keras.Input(shape=(40, 80))
value = keras.Input(shape=(60, 80))
output, coef = test_layer(query, value, return_attention_scores=True)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60])
@parameterized.named_parameters(("with_bias", True), ("no_bias", False))
def test_masked_attention(self, use_bias):
"""Test with a mask tensor."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=2, key_dim=2, use_bias=use_bias)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = keras.Input(shape=(4, 8))
value = keras.Input(shape=(2, 8))
mask_tensor = keras.Input(shape=(4, 2))
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Tests the layer with three inputs: Q, K, V.
key = keras.Input(shape=(2, 8))
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
model = keras.Model([query, value, key, mask_tensor], output)
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
unmasked_output_data = model.predict(
[from_data, to_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
if use_bias:
self.assertLen(test_layer._query_dense.trainable_variables, 2)
self.assertLen(test_layer._output_dense.trainable_variables, 2)
else:
self.assertLen(test_layer._query_dense.trainable_variables, 1)
self.assertLen(test_layer._output_dense.trainable_variables, 1)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=12,
key_dim=64,
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02))
# Create a 3-dimensional input (the first dimension is implicit).
query = keras.Input(shape=(40, 80))
output = test_layer(query, query)
self.assertEqual(output.shape.as_list(), [None, 40, 80])
def test_masked_attention_with_scores(self):
"""Test with a mask tensor."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=2, key_dim=2)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
query = keras.Input(shape=(4, 8))
value = keras.Input(shape=(2, 8))
mask_tensor = keras.Input(shape=(4, 2))
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
# Create a model containing the test layer.
model = keras.Model([query, value, mask_tensor], output)
# Generate data for the input (non-mask) tensors.
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
masked_output_data = model.predict([from_data, to_data, mask_data])
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones((batch_size, 4, 2))
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
# Because one data is masked and one is not, the outputs should not be the
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
# Create a model containing attention scores.
output, scores = test_layer(
query=query, value=value, attention_mask=mask_tensor,
return_attention_scores=True)
model = keras.Model([query, value, mask_tensor], [output, scores])
masked_output_data_score, masked_score = model.predict(
[from_data, to_data, mask_data])
unmasked_output_data_score, unmasked_score = model.predict(
[from_data, to_data, null_mask_data])
self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
self.assertAllClose(masked_output_data, masked_output_data_score)
self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
self.assertNotAllClose(masked_score, unmasked_score)
@parameterized.named_parameters(
("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2],
(2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2],
(2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2],
(2, 3)))
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
"""Test with a mask tensor."""
test_layer = multi_head_attention.MultiHeadAttention(
num_heads=2, key_dim=2, attention_axes=attention_axes)
batch_size, hidden_size = 3, 8
# Generate data for the input (non-mask) tensors.
query_shape = [batch_size] + q_dims + [hidden_size]
value_shape = [batch_size] + v_dims + [hidden_size]
mask_shape = [batch_size] + mask_dims
query = 10 * np.random.random_sample(query_shape)
value = 10 * np.random.random_sample(value_shape)
# Invoke the data with a random set of mask data. This should mask at least
# one element.
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
# Invoke the same data, but with a null mask (where no elements are masked).
null_mask_data = np.ones(mask_shape)
# Because one data is masked and one is not, the outputs should not be the
# same.
query_tensor = keras.Input(query_shape[1:], name="query")
value_tensor = keras.Input(value_shape[1:], name="value")
mask_tensor = keras.Input(mask_shape[1:], name="mask")
output = test_layer(query=query_tensor, value=value_tensor,
attention_mask=mask_tensor)
model = keras.Model([query_tensor, value_tensor, mask_tensor], output)
self.assertNotAllClose(
model.predict([query, value, mask_data]),
model.predict([query, value, null_mask_data]))
if __name__ == "__main__":
test.main()

View File

@ -37,6 +37,7 @@ from tensorflow.python.keras.layers import einsum_dense
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.layers import local
from tensorflow.python.keras.layers import merge
from tensorflow.python.keras.layers import multi_head_attention
from tensorflow.python.keras.layers import noise
from tensorflow.python.keras.layers import normalization
from tensorflow.python.keras.layers import normalization_v2
@ -70,7 +71,8 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
pooling, image_preprocessing, preprocessing_integer_lookup_v1,
preprocessing_normalization_v1, preprocessing_string_lookup_v1,
preprocessing_text_vectorization_v1, recurrent, wrappers,
hashing, category_crossing, category_encoding_v1, discretization)
hashing, category_crossing, category_encoding_v1, discretization,
multi_head_attention)
ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2,
preprocessing_integer_lookup, preprocessing_normalization,
preprocessing_string_lookup, preprocessing_text_vectorization,

View File

@ -0,0 +1,222 @@
path: "tensorflow.keras.layers.MultiHeadAttention"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.multi_head_attention.MultiHeadAttention\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "supports_masking"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -149,7 +149,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"

View File

@ -312,6 +312,10 @@ tf_module {
name: "Minimum"
mtype: "<type \'type\'>"
}
member {
name: "MultiHeadAttention"
mtype: "<type \'type\'>"
}
member {
name: "Multiply"
mtype: "<type \'type\'>"

View File

@ -0,0 +1,222 @@
path: "tensorflow.keras.layers.MultiHeadAttention"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.multi_head_attention.MultiHeadAttention\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "supports_masking"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -149,7 +149,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"

View File

@ -304,6 +304,10 @@ tf_module {
name: "Minimum"
mtype: "<type \'type\'>"
}
member {
name: "MultiHeadAttention"
mtype: "<type \'type\'>"
}
member {
name: "Multiply"
mtype: "<type \'type\'>"