Add MultiHeadAttention Layer for Keras.
PiperOrigin-RevId: 326359296 Change-Id: Iacdc310f66aa1848b068fa3f0fc8784ea7b80ef5
This commit is contained in:
parent
2dc8f029bf
commit
f32c80b3ed
tensorflow
python/keras/layers
BUILD__init__.pyadvanced_activations.pymulti_head_attention.pymulti_head_attention_test.pyserialization.py
tools/api/golden
@ -39,6 +39,7 @@ py_library(
|
||||
":kernelized",
|
||||
":local",
|
||||
":merge",
|
||||
":multi_head_attention",
|
||||
":noise",
|
||||
":normalization",
|
||||
":normalization_v2",
|
||||
@ -207,6 +208,22 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "multi_head_attention",
|
||||
srcs = ["multi_head_attention.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:special_math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/keras:activations",
|
||||
"//tensorflow/python/keras:base_layer",
|
||||
"//tensorflow/python/keras:constraints",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras:regularizers",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "embeddings",
|
||||
srcs = ["embeddings.py"],
|
||||
@ -590,6 +607,18 @@ tf_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "multi_head_attention_test",
|
||||
srcs = ["multi_head_attention_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":multi_head_attention",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "embeddings_test",
|
||||
size = "medium",
|
||||
|
@ -143,6 +143,9 @@ from tensorflow.python.keras.layers.embeddings import Embedding
|
||||
# Einsum-based dense layer/
|
||||
from tensorflow.python.keras.layers.einsum_dense import EinsumDense
|
||||
|
||||
# Multi-head Attention layer.
|
||||
from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention
|
||||
|
||||
# Locally-connected layers.
|
||||
from tensorflow.python.keras.layers.local import LocallyConnected1D
|
||||
from tensorflow.python.keras.layers.local import LocallyConnected2D
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import constraints
|
||||
from tensorflow.python.keras import initializers
|
||||
@ -259,10 +260,37 @@ class ThresholdedReLU(Layer):
|
||||
return input_shape
|
||||
|
||||
|
||||
def _large_compatible_negative(tensor_type):
|
||||
"""Large negative number as Tensor.
|
||||
|
||||
This function is necessary because the standard value for epsilon
|
||||
in this module (-1e9) cannot be represented using tf.float16
|
||||
|
||||
Args:
|
||||
tensor_type: a dtype to determine the type.
|
||||
|
||||
Returns:
|
||||
a large negative number.
|
||||
"""
|
||||
if tensor_type == dtypes.float16:
|
||||
return dtypes.float16.min
|
||||
return -1e9
|
||||
|
||||
|
||||
@keras_export('keras.layers.Softmax')
|
||||
class Softmax(Layer):
|
||||
"""Softmax activation function.
|
||||
|
||||
Example without mask:
|
||||
|
||||
>>> inp = np.asarray([1., 2., 1.])
|
||||
>>> layer = tf.keras.layers.Softmax()
|
||||
>>> layer(inp).numpy()
|
||||
array([0.21194157, 0.5761169 , 0.21194157], dtype=float32)
|
||||
>>> mask = np.asarray([True, False, True], dtype=bool)
|
||||
>>> layer(inp, mask).numpy()
|
||||
array([0.5, 0. , 0.5], dtype=float32)
|
||||
|
||||
Input shape:
|
||||
Arbitrary. Use the keyword argument `input_shape`
|
||||
(tuple of integers, does not include the samples axis)
|
||||
@ -272,7 +300,14 @@ class Softmax(Layer):
|
||||
Same shape as the input.
|
||||
|
||||
Arguments:
|
||||
axis: Integer, axis along which the softmax normalization is applied.
|
||||
axis: Integer, or list of Integers, axis along which the softmax
|
||||
normalization is applied.
|
||||
Call arguments:
|
||||
inputs: The inputs, or logits to the softmax layer.
|
||||
mask: A boolean mask of the same shape as `inputs`. Defaults to `None`.
|
||||
|
||||
Returns:
|
||||
softmaxed output with the same shape as `inputs`.
|
||||
"""
|
||||
|
||||
def __init__(self, axis=-1, **kwargs):
|
||||
@ -280,7 +315,23 @@ class Softmax(Layer):
|
||||
self.supports_masking = True
|
||||
self.axis = axis
|
||||
|
||||
def call(self, inputs):
|
||||
def call(self, inputs, mask=None):
|
||||
if mask is not None:
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -1e.9 for masked positions.
|
||||
adder = (1.0 - math_ops.cast(mask, inputs.dtype)) * (
|
||||
_large_compatible_negative(inputs.dtype))
|
||||
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
inputs += adder
|
||||
if isinstance(self.axis, (tuple, list)):
|
||||
if len(self.axis) > 1:
|
||||
return math_ops.exp(inputs - math_ops.reduce_logsumexp(
|
||||
inputs, axis=self.axis, keepdims=True))
|
||||
else:
|
||||
return K.softmax(inputs, axis=self.axis[0])
|
||||
return K.softmax(inputs, axis=self.axis)
|
||||
|
||||
def get_config(self):
|
||||
|
460
tensorflow/python/keras/layers/multi_head_attention.py
Normal file
460
tensorflow/python/keras/layers/multi_head_attention.py
Normal file
@ -0,0 +1,460 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras-based attention layer."""
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import math
|
||||
import string
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras import constraints
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.keras.layers import advanced_activations
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers import einsum_dense
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import special_math_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
_CHR_IDX = string.ascii_lowercase
|
||||
|
||||
|
||||
def _build_attention_equation(rank, attn_axes):
|
||||
"""Builds einsum equations for the attention computation.
|
||||
|
||||
Query, key, value inputs after projection are expected to have the shape as:
|
||||
(bs, <non-attention dims>, <attention dims>, num_heads, channels).
|
||||
bs and <non-attention dims> are treated as <batch dims>.
|
||||
The attention operations can be generalized:
|
||||
(1) Query-key dot product:
|
||||
(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
|
||||
<key attention dims>, num_heads, channels) -> (<batch dims>,
|
||||
num_heads, <query attention dims>, <key attention dims>)
|
||||
(2) Combination:
|
||||
(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
|
||||
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
|
||||
<query attention dims>, num_heads, channels)
|
||||
|
||||
Args:
|
||||
rank: the rank of query, key, value tensors.
|
||||
attn_axes: a list/tuple of axes, [1, rank), that will do attention.
|
||||
|
||||
Returns:
|
||||
Einsum equations.
|
||||
"""
|
||||
target_notation = _CHR_IDX[:rank]
|
||||
# `batch_dims` includes the head dim.
|
||||
batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))
|
||||
letter_offset = rank
|
||||
source_notation = ""
|
||||
for i in range(rank):
|
||||
if i in batch_dims or i == rank - 1:
|
||||
source_notation += target_notation[i]
|
||||
else:
|
||||
source_notation += _CHR_IDX[letter_offset]
|
||||
letter_offset += 1
|
||||
|
||||
product_notation = "".join([target_notation[i] for i in batch_dims] +
|
||||
[target_notation[i] for i in attn_axes] +
|
||||
[source_notation[i] for i in attn_axes])
|
||||
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
|
||||
product_notation)
|
||||
attn_scores_rank = len(product_notation)
|
||||
combine_equation = "%s,%s->%s" % (product_notation, source_notation,
|
||||
target_notation)
|
||||
return dot_product_equation, combine_equation, attn_scores_rank
|
||||
|
||||
|
||||
def _build_proj_equation(free_dims, bound_dims, output_dims):
|
||||
"""Builds an einsum equation for projections inside multi-head attention."""
|
||||
input_str = ""
|
||||
kernel_str = ""
|
||||
output_str = ""
|
||||
bias_axes = ""
|
||||
letter_offset = 0
|
||||
for i in range(free_dims):
|
||||
char = _CHR_IDX[i + letter_offset]
|
||||
input_str += char
|
||||
output_str += char
|
||||
|
||||
letter_offset += free_dims
|
||||
for i in range(bound_dims):
|
||||
char = _CHR_IDX[i + letter_offset]
|
||||
input_str += char
|
||||
kernel_str += char
|
||||
|
||||
letter_offset += bound_dims
|
||||
for i in range(output_dims):
|
||||
char = _CHR_IDX[i + letter_offset]
|
||||
kernel_str += char
|
||||
output_str += char
|
||||
bias_axes += char
|
||||
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
|
||||
|
||||
return equation, bias_axes, len(output_str)
|
||||
|
||||
|
||||
def _get_output_shape(output_rank, known_last_dims):
|
||||
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
|
||||
|
||||
|
||||
@keras_export("keras.layers.MultiHeadAttention")
|
||||
class MultiHeadAttention(Layer):
|
||||
"""MultiHeadAttention layer.
|
||||
|
||||
This is an implementation of multi-headed attention based on "Attention
|
||||
is all you Need". If `query`, `key,` `value` are the same, then
|
||||
this is self-attention. Each timestep in `query` attends to the
|
||||
corresponding sequence in `key`, and returns a fixed-width vector.
|
||||
|
||||
This layer first projects `query`, `key` and `value`. These are
|
||||
(effectively) a list of tensors of length `num_attention_heads`, where the
|
||||
corresponding shapes are [batch_size, <query dimensions>, key_dim],
|
||||
[batch_size, <key/value dimensions>, key_dim],
|
||||
[batch_size, <key/value dimensions>, value_dim].
|
||||
|
||||
Then, the query and key tensors are dot-producted and scaled. These are
|
||||
softmaxed to obtain attention probabilities. The value tensors are then
|
||||
interpolated by these probabilities, then concatenated back to a single
|
||||
tensor.
|
||||
|
||||
Finally, the result tensor with the last dimension as value_dim can take an
|
||||
linear projection and return.
|
||||
|
||||
Examples:
|
||||
|
||||
Performs 1D cross-attention over two sequence inputs with an attention mask.
|
||||
Returns the additional attention weights over heads.
|
||||
|
||||
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2)
|
||||
>>> target = tf.keras.Input(shape=[8, 16])
|
||||
>>> source = tf.keras.Input(shape=[4, 16])
|
||||
>>> output_tensor, weights = layer(target, source,
|
||||
... return_attention_scores=True)
|
||||
>>> print(output_tensor.shape)
|
||||
(None, 8, 16)
|
||||
>>> print(weights.shape)
|
||||
(None, 2, 8, 4)
|
||||
|
||||
Performs 2D self-attention over a 5D input tensor on axes 2 and 3.
|
||||
|
||||
>>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3))
|
||||
>>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16])
|
||||
>>> output_tensor = layer(input_tensor, input_tensor)
|
||||
>>> print(output_tensor.shape)
|
||||
(None, 5, 3, 4, 16)
|
||||
|
||||
Arguments:
|
||||
num_heads: Number of attention heads.
|
||||
key_dim: Size of each attention head for query and key.
|
||||
value_dim: Size of each attention head for value.
|
||||
dropout: Dropout probability.
|
||||
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
|
||||
output_shape: The expected shape of an output tensor, besides the batch and
|
||||
sequence dims. If not specified, projects back to the key feature dim.
|
||||
attention_axes: axes over which the attention is applied. `None` means
|
||||
attention over all axes, but batch, heads, and features.
|
||||
kernel_initializer: Initializer for dense layer kernels.
|
||||
bias_initializer: Initializer for dense layer biases.
|
||||
kernel_regularizer: Regularizer for dense layer kernels.
|
||||
bias_regularizer: Regularizer for dense layer biases.
|
||||
activity_regularizer: Regularizer for dense layer activity.
|
||||
kernel_constraint: Constraint for dense layer kernels.
|
||||
bias_constraint: Constraint for dense layer kernels.
|
||||
|
||||
Call arguments:
|
||||
query: Query `Tensor` of shape `[B, T, dim]`.
|
||||
value: Value `Tensor` of shape `[B, S, dim]`.
|
||||
key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use
|
||||
`value` for both `key` and `value`, which is the most common case.
|
||||
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention
|
||||
to certain positions.
|
||||
return_attention_scores: A boolean to indicate whether the output should
|
||||
be attention output if True, or (attention_output, attention_scores) if
|
||||
False. Defaults to False.
|
||||
|
||||
Returns:
|
||||
attention_output: The result of the computation, of shape [B, T, E],
|
||||
where `T` is for target sequence shapes and `E` is the query input last
|
||||
dimension if `output_shape` is `None`. Otherwise, the multi-head outputs
|
||||
are project to the shape specified by `output_shape`.
|
||||
attention_scores: [Optional] multi-head attention coeffients over
|
||||
attention axes.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_heads,
|
||||
key_dim,
|
||||
value_dim=None,
|
||||
dropout=0.0,
|
||||
use_bias=True,
|
||||
output_shape=None,
|
||||
attention_axes=None,
|
||||
kernel_initializer="glorot_uniform",
|
||||
bias_initializer="zeros",
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
super(MultiHeadAttention, self).__init__(**kwargs)
|
||||
self._num_heads = num_heads
|
||||
self._key_dim = key_dim
|
||||
self._value_dim = value_dim if value_dim else key_dim
|
||||
self._dropout = dropout
|
||||
self._use_bias = use_bias
|
||||
self._output_shape = output_shape
|
||||
self._kernel_initializer = initializers.get(kernel_initializer)
|
||||
self._bias_initializer = initializers.get(bias_initializer)
|
||||
self._kernel_regularizer = regularizers.get(kernel_regularizer)
|
||||
self._bias_regularizer = regularizers.get(bias_regularizer)
|
||||
self._kernel_constraint = constraints.get(kernel_constraint)
|
||||
self._bias_constraint = constraints.get(bias_constraint)
|
||||
if attention_axes is not None and not isinstance(attention_axes,
|
||||
collections.abc.Sized):
|
||||
self._attention_axes = (attention_axes,)
|
||||
else:
|
||||
self._attention_axes = attention_axes
|
||||
self._built_from_signature = False
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"num_heads":
|
||||
self._num_heads,
|
||||
"key_dim":
|
||||
self._key_dim,
|
||||
"value_dim":
|
||||
self._value_dim,
|
||||
"dropout":
|
||||
self._dropout,
|
||||
"use_bias":
|
||||
self._use_bias,
|
||||
"output_shape":
|
||||
self._output_shape,
|
||||
"attention_axes":
|
||||
self._attention_axes,
|
||||
"kernel_initializer":
|
||||
initializers.serialize(self._kernel_initializer),
|
||||
"bias_initializer":
|
||||
initializers.serialize(self._bias_initializer),
|
||||
"kernel_regularizer":
|
||||
regularizers.serialize(self._kernel_regularizer),
|
||||
"bias_regularizer":
|
||||
regularizers.serialize(self._bias_regularizer),
|
||||
"activity_regularizer":
|
||||
regularizers.serialize(self._activity_regularizer),
|
||||
"kernel_constraint":
|
||||
constraints.serialize(self._kernel_constraint),
|
||||
"bias_constraint":
|
||||
constraints.serialize(self._bias_constraint)
|
||||
}
|
||||
base_config = super(MultiHeadAttention, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def _build_from_signature(self, query, value, key=None):
|
||||
"""Builds layers and variables.
|
||||
|
||||
Once the method is called, self._built_from_signature will be set to True.
|
||||
|
||||
Args:
|
||||
query: query tensor or TensorShape.
|
||||
value: value tensor or TensorShape.
|
||||
key: key tensor or TensorShape.
|
||||
"""
|
||||
self._built_from_signature = True
|
||||
if hasattr(query, "shape"):
|
||||
query_shape = tensor_shape.TensorShape(query.shape)
|
||||
else:
|
||||
query_shape = query
|
||||
if hasattr(value, "shape"):
|
||||
value_shape = tensor_shape.TensorShape(value.shape)
|
||||
else:
|
||||
value_shape = value
|
||||
if key is None:
|
||||
key_shape = value_shape
|
||||
elif hasattr(key, "shape"):
|
||||
key_shape = tensor_shape.TensorShape(key.shape)
|
||||
else:
|
||||
key_shape = key
|
||||
|
||||
common_kwargs = dict(
|
||||
kernel_initializer=self._kernel_initializer,
|
||||
bias_initializer=self._bias_initializer,
|
||||
kernel_regularizer=self._kernel_regularizer,
|
||||
bias_regularizer=self._bias_regularizer,
|
||||
activity_regularizer=self._activity_regularizer,
|
||||
kernel_constraint=self._kernel_constraint,
|
||||
bias_constraint=self._bias_constraint)
|
||||
# Any setup work performed only once should happen in an `init_scope`
|
||||
# to avoid creating symbolic Tensors that will later pollute any eager
|
||||
# operations.
|
||||
with tf_utils.maybe_init_scope(self):
|
||||
free_dims = query_shape.rank - 1
|
||||
einsum_equation, bias_axes, output_rank = _build_proj_equation(
|
||||
free_dims, bound_dims=1, output_dims=2)
|
||||
self._query_dense = einsum_dense.EinsumDense(
|
||||
einsum_equation,
|
||||
output_shape=_get_output_shape(output_rank - 1,
|
||||
[self._num_heads, self._key_dim]),
|
||||
bias_axes=bias_axes if self._use_bias else None,
|
||||
name="query",
|
||||
**common_kwargs)
|
||||
einsum_equation, bias_axes, output_rank = _build_proj_equation(
|
||||
key_shape.rank - 1, bound_dims=1, output_dims=2)
|
||||
self._key_dense = einsum_dense.EinsumDense(
|
||||
einsum_equation,
|
||||
output_shape=_get_output_shape(output_rank - 1,
|
||||
[self._num_heads, self._key_dim]),
|
||||
bias_axes=bias_axes if self._use_bias else None,
|
||||
name="key",
|
||||
**common_kwargs)
|
||||
einsum_equation, bias_axes, output_rank = _build_proj_equation(
|
||||
value_shape.rank - 1, bound_dims=1, output_dims=2)
|
||||
self._value_dense = einsum_dense.EinsumDense(
|
||||
einsum_equation,
|
||||
output_shape=_get_output_shape(output_rank - 1,
|
||||
[self._num_heads, self._value_dim]),
|
||||
bias_axes=bias_axes if self._use_bias else None,
|
||||
name="value",
|
||||
**common_kwargs)
|
||||
|
||||
# Builds the attention computations for multi-head dot product attention.
|
||||
# These computations could be wrapped into the keras attention layer once
|
||||
# it support mult-head einsum computations.
|
||||
self._build_attention(output_rank)
|
||||
if self._output_shape:
|
||||
if not isinstance(self._output_shape, collections.abc.Sized):
|
||||
output_shape = [self._output_shape]
|
||||
else:
|
||||
output_shape = self._output_shape
|
||||
else:
|
||||
output_shape = [query_shape[-1]]
|
||||
einsum_equation, bias_axes, output_rank = _build_proj_equation(
|
||||
free_dims, bound_dims=2, output_dims=len(output_shape))
|
||||
self._output_dense = einsum_dense.EinsumDense(
|
||||
einsum_equation,
|
||||
output_shape=_get_output_shape(output_rank - 1, output_shape),
|
||||
bias_axes=bias_axes if self._use_bias else None,
|
||||
name="attention_output",
|
||||
**common_kwargs)
|
||||
|
||||
def _build_attention(self, rank):
|
||||
"""Builds multi-head dot-product attention computations.
|
||||
|
||||
This function builds attributes necessary for `_compute_attention` to
|
||||
costomize attention computation to replace the default dot-product
|
||||
attention.
|
||||
|
||||
Args:
|
||||
rank: the rank of query, key, value tensors.
|
||||
"""
|
||||
if self._attention_axes is None:
|
||||
self._attention_axes = tuple(range(1, rank - 2))
|
||||
else:
|
||||
self._attention_axes = tuple(self._attention_axes)
|
||||
self._dot_product_equation, self._combine_equation, attn_scores_rank = (
|
||||
_build_attention_equation(rank, attn_axes=self._attention_axes))
|
||||
norm_axes = tuple(
|
||||
range(attn_scores_rank - len(self._attention_axes), attn_scores_rank))
|
||||
self._masked_softmax = advanced_activations.Softmax(axis=norm_axes)
|
||||
self._dropout_layer = core.Dropout(rate=self._dropout)
|
||||
|
||||
def _compute_attention(self, query, key, value, attention_mask=None):
|
||||
"""Applies Dot-product attention with query, key, value tensors.
|
||||
|
||||
This function defines the computation inside `call` with projected
|
||||
multi-head Q, K, V inputs. Users can override this function for customized
|
||||
attention implementation.
|
||||
|
||||
Args:
|
||||
query: Projected query `Tensor` of shape `[B, T, N, key_dim]`.
|
||||
key: Projected key `Tensor` of shape `[B, T, N, key_dim]`.
|
||||
value: Projected value `Tensor` of shape `[B, T, N, value_dim]`.
|
||||
attention_mask: a boolean mask of shape `[B, T, S]`, that prevents
|
||||
attention to certain positions.
|
||||
|
||||
Returns:
|
||||
attention_output: Multi-headed outputs of attention computation.
|
||||
attention_scores: Multi-headed attention weights.
|
||||
"""
|
||||
# Note: Applying scalar multiply at the smaller end of einsum improves
|
||||
# XLA performance, but may introduce slight numeric differences in
|
||||
# the Transformer attention head.
|
||||
query = math_ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw
|
||||
# attention scores.
|
||||
attention_scores = special_math_ops.einsum(self._dot_product_equation, key,
|
||||
query)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
# `attention_scores` = [B, N, T, S]
|
||||
if attention_mask is not None:
|
||||
# The expand dim happens starting from the `num_heads` dimension,
|
||||
# (<batch_dims>, num_heads, <query_attention_dims, key_attention_dims>)
|
||||
mask_expansion_axes = [-len(self._attention_axes) * 2 - 1]
|
||||
for _ in range(len(attention_scores.shape) - len(attention_mask.shape)):
|
||||
attention_mask = array_ops.expand_dims(
|
||||
attention_mask, axis=mask_expansion_axes)
|
||||
attention_scores = self._masked_softmax(attention_scores, attention_mask)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
attention_scores_dropout = self._dropout_layer(attention_scores)
|
||||
|
||||
# `context_layer` = [B, T, N, H]
|
||||
attention_output = special_math_ops.einsum(self._combine_equation,
|
||||
attention_scores_dropout, value)
|
||||
return attention_output, attention_scores
|
||||
|
||||
def call(self, query, value, key=None, attention_mask=None,
|
||||
return_attention_scores=False):
|
||||
if not self._built_from_signature:
|
||||
self._build_from_signature(query=query, value=value, key=key)
|
||||
if key is None:
|
||||
key = value
|
||||
|
||||
# N = `num_attention_heads`
|
||||
# H = `size_per_head`
|
||||
# `query` = [B, T, N ,H]
|
||||
query = self._query_dense(query)
|
||||
|
||||
# `key` = [B, S, N, H]
|
||||
key = self._key_dense(key)
|
||||
|
||||
# `value` = [B, S, N, H]
|
||||
value = self._value_dense(value)
|
||||
|
||||
attention_output, attention_scores = self._compute_attention(
|
||||
query, key, value, attention_mask)
|
||||
attention_output = self._output_dense(attention_output)
|
||||
|
||||
if return_attention_scores:
|
||||
return attention_output, attention_scores
|
||||
return attention_output
|
||||
|
230
tensorflow/python/keras/layers/multi_head_attention_test.py
Normal file
230
tensorflow/python/keras/layers/multi_head_attention_test.py
Normal file
@ -0,0 +1,230 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for the attention layer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras.layers import multi_head_attention
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It
|
||||
# guarantees forward compatibility of this code for the V2 switchover.
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class MultiHeadAttentionTest(keras_parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("key_value_same_proj", None, None, [40, 80]),
|
||||
("key_value_different_proj", 32, 60, [40, 60]),
|
||||
)
|
||||
def test_non_masked_attention(self, value_dim, output_shape, output_dims):
|
||||
"""Test that the attention layer can be created without a mask tensor."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=12,
|
||||
key_dim=64,
|
||||
value_dim=value_dim,
|
||||
output_shape=output_shape)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
query = keras.Input(shape=(40, 80))
|
||||
value = keras.Input(shape=(20, 80))
|
||||
output = test_layer(query=query, value=value)
|
||||
self.assertEqual(output.shape.as_list(), [None] + output_dims)
|
||||
|
||||
def test_non_masked_self_attention(self):
|
||||
"""Test with one input (self-attenntion) and no mask tensor."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=12, key_dim=64)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
query = keras.Input(shape=(40, 80))
|
||||
output = test_layer(query, query)
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 80])
|
||||
|
||||
def test_attention_scores(self):
|
||||
"""Test attention outputs with coefficients."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=12, key_dim=64)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
query = keras.Input(shape=(40, 80))
|
||||
output, coef = test_layer(query, query, return_attention_scores=True)
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 80])
|
||||
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40])
|
||||
|
||||
def test_attention_scores_with_values(self):
|
||||
"""Test attention outputs with coefficients."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=12, key_dim=64)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
query = keras.Input(shape=(40, 80))
|
||||
value = keras.Input(shape=(60, 80))
|
||||
output, coef = test_layer(query, value, return_attention_scores=True)
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 80])
|
||||
self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60])
|
||||
|
||||
@parameterized.named_parameters(("with_bias", True), ("no_bias", False))
|
||||
def test_masked_attention(self, use_bias):
|
||||
"""Test with a mask tensor."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=2, key_dim=2, use_bias=use_bias)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
batch_size = 3
|
||||
query = keras.Input(shape=(4, 8))
|
||||
value = keras.Input(shape=(2, 8))
|
||||
mask_tensor = keras.Input(shape=(4, 2))
|
||||
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
|
||||
|
||||
# Create a model containing the test layer.
|
||||
model = keras.Model([query, value, mask_tensor], output)
|
||||
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
|
||||
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
|
||||
|
||||
# Invoke the data with a random set of mask data. This should mask at least
|
||||
# one element.
|
||||
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
|
||||
masked_output_data = model.predict([from_data, to_data, mask_data])
|
||||
|
||||
# Invoke the same data, but with a null mask (where no elements are masked).
|
||||
null_mask_data = np.ones((batch_size, 4, 2))
|
||||
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
|
||||
|
||||
# Because one data is masked and one is not, the outputs should not be the
|
||||
# same.
|
||||
self.assertNotAllClose(masked_output_data, unmasked_output_data)
|
||||
|
||||
# Tests the layer with three inputs: Q, K, V.
|
||||
key = keras.Input(shape=(2, 8))
|
||||
output = test_layer(query, value=value, key=key, attention_mask=mask_tensor)
|
||||
model = keras.Model([query, value, key, mask_tensor], output)
|
||||
|
||||
masked_output_data = model.predict([from_data, to_data, to_data, mask_data])
|
||||
unmasked_output_data = model.predict(
|
||||
[from_data, to_data, to_data, null_mask_data])
|
||||
# Because one data is masked and one is not, the outputs should not be the
|
||||
# same.
|
||||
self.assertNotAllClose(masked_output_data, unmasked_output_data)
|
||||
|
||||
if use_bias:
|
||||
self.assertLen(test_layer._query_dense.trainable_variables, 2)
|
||||
self.assertLen(test_layer._output_dense.trainable_variables, 2)
|
||||
else:
|
||||
self.assertLen(test_layer._query_dense.trainable_variables, 1)
|
||||
self.assertLen(test_layer._output_dense.trainable_variables, 1)
|
||||
|
||||
def test_initializer(self):
|
||||
"""Test with a specified initializer."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=12,
|
||||
key_dim=64,
|
||||
kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02))
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
query = keras.Input(shape=(40, 80))
|
||||
output = test_layer(query, query)
|
||||
self.assertEqual(output.shape.as_list(), [None, 40, 80])
|
||||
|
||||
def test_masked_attention_with_scores(self):
|
||||
"""Test with a mask tensor."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=2, key_dim=2)
|
||||
# Create a 3-dimensional input (the first dimension is implicit).
|
||||
batch_size = 3
|
||||
query = keras.Input(shape=(4, 8))
|
||||
value = keras.Input(shape=(2, 8))
|
||||
mask_tensor = keras.Input(shape=(4, 2))
|
||||
output = test_layer(query=query, value=value, attention_mask=mask_tensor)
|
||||
|
||||
# Create a model containing the test layer.
|
||||
model = keras.Model([query, value, mask_tensor], output)
|
||||
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
from_data = 10 * np.random.random_sample((batch_size, 4, 8))
|
||||
to_data = 10 * np.random.random_sample((batch_size, 2, 8))
|
||||
|
||||
# Invoke the data with a random set of mask data. This should mask at least
|
||||
# one element.
|
||||
mask_data = np.random.randint(2, size=(batch_size, 4, 2))
|
||||
masked_output_data = model.predict([from_data, to_data, mask_data])
|
||||
|
||||
# Invoke the same data, but with a null mask (where no elements are masked).
|
||||
null_mask_data = np.ones((batch_size, 4, 2))
|
||||
unmasked_output_data = model.predict([from_data, to_data, null_mask_data])
|
||||
|
||||
# Because one data is masked and one is not, the outputs should not be the
|
||||
# same.
|
||||
self.assertNotAllClose(masked_output_data, unmasked_output_data)
|
||||
|
||||
# Create a model containing attention scores.
|
||||
output, scores = test_layer(
|
||||
query=query, value=value, attention_mask=mask_tensor,
|
||||
return_attention_scores=True)
|
||||
model = keras.Model([query, value, mask_tensor], [output, scores])
|
||||
masked_output_data_score, masked_score = model.predict(
|
||||
[from_data, to_data, mask_data])
|
||||
unmasked_output_data_score, unmasked_score = model.predict(
|
||||
[from_data, to_data, null_mask_data])
|
||||
self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score)
|
||||
self.assertAllClose(masked_output_data, masked_output_data_score)
|
||||
self.assertAllClose(unmasked_output_data, unmasked_output_data_score)
|
||||
self.assertNotAllClose(masked_score, unmasked_score)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2],
|
||||
(2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)),
|
||||
("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2],
|
||||
(2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)),
|
||||
("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)),
|
||||
("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2],
|
||||
(2, 3)))
|
||||
def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes):
|
||||
"""Test with a mask tensor."""
|
||||
test_layer = multi_head_attention.MultiHeadAttention(
|
||||
num_heads=2, key_dim=2, attention_axes=attention_axes)
|
||||
batch_size, hidden_size = 3, 8
|
||||
# Generate data for the input (non-mask) tensors.
|
||||
query_shape = [batch_size] + q_dims + [hidden_size]
|
||||
value_shape = [batch_size] + v_dims + [hidden_size]
|
||||
mask_shape = [batch_size] + mask_dims
|
||||
query = 10 * np.random.random_sample(query_shape)
|
||||
value = 10 * np.random.random_sample(value_shape)
|
||||
|
||||
# Invoke the data with a random set of mask data. This should mask at least
|
||||
# one element.
|
||||
mask_data = np.random.randint(2, size=mask_shape).astype("bool")
|
||||
# Invoke the same data, but with a null mask (where no elements are masked).
|
||||
null_mask_data = np.ones(mask_shape)
|
||||
# Because one data is masked and one is not, the outputs should not be the
|
||||
# same.
|
||||
query_tensor = keras.Input(query_shape[1:], name="query")
|
||||
value_tensor = keras.Input(value_shape[1:], name="value")
|
||||
mask_tensor = keras.Input(mask_shape[1:], name="mask")
|
||||
output = test_layer(query=query_tensor, value=value_tensor,
|
||||
attention_mask=mask_tensor)
|
||||
model = keras.Model([query_tensor, value_tensor, mask_tensor], output)
|
||||
|
||||
self.assertNotAllClose(
|
||||
model.predict([query, value, mask_data]),
|
||||
model.predict([query, value, null_mask_data]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -37,6 +37,7 @@ from tensorflow.python.keras.layers import einsum_dense
|
||||
from tensorflow.python.keras.layers import embeddings
|
||||
from tensorflow.python.keras.layers import local
|
||||
from tensorflow.python.keras.layers import merge
|
||||
from tensorflow.python.keras.layers import multi_head_attention
|
||||
from tensorflow.python.keras.layers import noise
|
||||
from tensorflow.python.keras.layers import normalization
|
||||
from tensorflow.python.keras.layers import normalization_v2
|
||||
@ -70,7 +71,8 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
|
||||
pooling, image_preprocessing, preprocessing_integer_lookup_v1,
|
||||
preprocessing_normalization_v1, preprocessing_string_lookup_v1,
|
||||
preprocessing_text_vectorization_v1, recurrent, wrappers,
|
||||
hashing, category_crossing, category_encoding_v1, discretization)
|
||||
hashing, category_crossing, category_encoding_v1, discretization,
|
||||
multi_head_attention)
|
||||
ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2,
|
||||
preprocessing_integer_lookup, preprocessing_normalization,
|
||||
preprocessing_string_lookup, preprocessing_text_vectorization,
|
||||
|
@ -0,0 +1,222 @@
|
||||
path: "tensorflow.keras.layers.MultiHeadAttention"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.multi_head_attention.MultiHeadAttention\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "activity_regularizer"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dynamic"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "inbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "metrics"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "outbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "submodules"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "supports_masking"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_metric"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_shape"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_signature"
|
||||
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_weights"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_name_scope"
|
||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -149,7 +149,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
@ -312,6 +312,10 @@ tf_module {
|
||||
name: "Minimum"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "MultiHeadAttention"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Multiply"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,222 @@
|
||||
path: "tensorflow.keras.layers.MultiHeadAttention"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.multi_head_attention.MultiHeadAttention\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "activity_regularizer"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dynamic"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "inbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "metrics"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "outbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "submodules"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "supports_masking"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_metric"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_shape"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_signature"
|
||||
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_weights"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_name_scope"
|
||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -149,7 +149,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
@ -304,6 +304,10 @@ tf_module {
|
||||
name: "Minimum"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "MultiHeadAttention"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Multiply"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user