From f32c80b3ed0d64eb0363f4196171467de79390d1 Mon Sep 17 00:00:00 2001 From: Zhenyu Tan Date: Wed, 12 Aug 2020 18:44:28 -0700 Subject: [PATCH] Add MultiHeadAttention Layer for Keras. PiperOrigin-RevId: 326359296 Change-Id: Iacdc310f66aa1848b068fa3f0fc8784ea7b80ef5 --- tensorflow/python/keras/layers/BUILD | 29 ++ tensorflow/python/keras/layers/__init__.py | 3 + .../keras/layers/advanced_activations.py | 55 ++- .../keras/layers/multi_head_attention.py | 460 ++++++++++++++++++ .../keras/layers/multi_head_attention_test.py | 230 +++++++++ .../python/keras/layers/serialization.py | 4 +- ...w.keras.layers.-multi-head-attention.pbtxt | 222 +++++++++ .../v1/tensorflow.keras.layers.-softmax.pbtxt | 2 +- .../golden/v1/tensorflow.keras.layers.pbtxt | 4 + ...w.keras.layers.-multi-head-attention.pbtxt | 222 +++++++++ .../v2/tensorflow.keras.layers.-softmax.pbtxt | 2 +- .../golden/v2/tensorflow.keras.layers.pbtxt | 4 + 12 files changed, 1232 insertions(+), 5 deletions(-) create mode 100644 tensorflow/python/keras/layers/multi_head_attention.py create mode 100644 tensorflow/python/keras/layers/multi_head_attention_test.py create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index fe46f580162..e3497c59061 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -39,6 +39,7 @@ py_library( ":kernelized", ":local", ":merge", + ":multi_head_attention", ":noise", ":normalization", ":normalization_v2", @@ -207,6 +208,22 @@ py_library( ], ) +py_library( + name = "multi_head_attention", + srcs = ["multi_head_attention.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:special_math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/keras:activations", + "//tensorflow/python/keras:base_layer", + "//tensorflow/python/keras:constraints", + "//tensorflow/python/keras:initializers", + "//tensorflow/python/keras:regularizers", + ], +) + py_library( name = "embeddings", srcs = ["embeddings.py"], @@ -590,6 +607,18 @@ tf_py_test( ], ) +tf_py_test( + name = "multi_head_attention_test", + srcs = ["multi_head_attention_test.py"], + python_version = "PY3", + deps = [ + ":multi_head_attention", + "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", + "@absl_py//absl/testing:parameterized", + ], +) + cuda_py_test( name = "embeddings_test", size = "medium", diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index 8ce1c7d8224..b07773ae03a 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -143,6 +143,9 @@ from tensorflow.python.keras.layers.embeddings import Embedding # Einsum-based dense layer/ from tensorflow.python.keras.layers.einsum_dense import EinsumDense +# Multi-head Attention layer. +from tensorflow.python.keras.layers.multi_head_attention import MultiHeadAttention + # Locally-connected layers. from tensorflow.python.keras.layers.local import LocallyConnected1D from tensorflow.python.keras.layers.local import LocallyConnected2D diff --git a/tensorflow/python/keras/layers/advanced_activations.py b/tensorflow/python/keras/layers/advanced_activations.py index 7cb40c172b7..e4323b45dc4 100644 --- a/tensorflow/python/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/layers/advanced_activations.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import dtypes from tensorflow.python.keras import backend as K from tensorflow.python.keras import constraints from tensorflow.python.keras import initializers @@ -259,10 +260,37 @@ class ThresholdedReLU(Layer): return input_shape +def _large_compatible_negative(tensor_type): + """Large negative number as Tensor. + + This function is necessary because the standard value for epsilon + in this module (-1e9) cannot be represented using tf.float16 + + Args: + tensor_type: a dtype to determine the type. + + Returns: + a large negative number. + """ + if tensor_type == dtypes.float16: + return dtypes.float16.min + return -1e9 + + @keras_export('keras.layers.Softmax') class Softmax(Layer): """Softmax activation function. + Example without mask: + + >>> inp = np.asarray([1., 2., 1.]) + >>> layer = tf.keras.layers.Softmax() + >>> layer(inp).numpy() + array([0.21194157, 0.5761169 , 0.21194157], dtype=float32) + >>> mask = np.asarray([True, False, True], dtype=bool) + >>> layer(inp, mask).numpy() + array([0.5, 0. , 0.5], dtype=float32) + Input shape: Arbitrary. Use the keyword argument `input_shape` (tuple of integers, does not include the samples axis) @@ -272,7 +300,14 @@ class Softmax(Layer): Same shape as the input. Arguments: - axis: Integer, axis along which the softmax normalization is applied. + axis: Integer, or list of Integers, axis along which the softmax + normalization is applied. + Call arguments: + inputs: The inputs, or logits to the softmax layer. + mask: A boolean mask of the same shape as `inputs`. Defaults to `None`. + + Returns: + softmaxed output with the same shape as `inputs`. """ def __init__(self, axis=-1, **kwargs): @@ -280,7 +315,23 @@ class Softmax(Layer): self.supports_masking = True self.axis = axis - def call(self, inputs): + def call(self, inputs, mask=None): + if mask is not None: + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -1e.9 for masked positions. + adder = (1.0 - math_ops.cast(mask, inputs.dtype)) * ( + _large_compatible_negative(inputs.dtype)) + + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + inputs += adder + if isinstance(self.axis, (tuple, list)): + if len(self.axis) > 1: + return math_ops.exp(inputs - math_ops.reduce_logsumexp( + inputs, axis=self.axis, keepdims=True)) + else: + return K.softmax(inputs, axis=self.axis[0]) return K.softmax(inputs, axis=self.axis) def get_config(self): diff --git a/tensorflow/python/keras/layers/multi_head_attention.py b/tensorflow/python/keras/layers/multi_head_attention.py new file mode 100644 index 00000000000..210d6133d58 --- /dev/null +++ b/tensorflow/python/keras/layers/multi_head_attention.py @@ -0,0 +1,460 @@ +# Lint as: python3 +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Keras-based attention layer.""" +# pylint: disable=g-classes-have-attributes +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import math +import string + +import numpy as np + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.keras.layers import advanced_activations +from tensorflow.python.keras.layers import core +from tensorflow.python.keras.layers import einsum_dense +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.util.tf_export import keras_export + + +_CHR_IDX = string.ascii_lowercase + + +def _build_attention_equation(rank, attn_axes): + """Builds einsum equations for the attention computation. + + Query, key, value inputs after projection are expected to have the shape as: + (bs, , , num_heads, channels). + bs and are treated as . + The attention operations can be generalized: + (1) Query-key dot product: + (, , num_heads, channels), (, + , num_heads, channels) -> (, + num_heads, , ) + (2) Combination: + (, num_heads, , ), + (, , num_heads, channels) -> (, + , num_heads, channels) + + Args: + rank: the rank of query, key, value tensors. + attn_axes: a list/tuple of axes, [1, rank), that will do attention. + + Returns: + Einsum equations. + """ + target_notation = _CHR_IDX[:rank] + # `batch_dims` includes the head dim. + batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,))) + letter_offset = rank + source_notation = "" + for i in range(rank): + if i in batch_dims or i == rank - 1: + source_notation += target_notation[i] + else: + source_notation += _CHR_IDX[letter_offset] + letter_offset += 1 + + product_notation = "".join([target_notation[i] for i in batch_dims] + + [target_notation[i] for i in attn_axes] + + [source_notation[i] for i in attn_axes]) + dot_product_equation = "%s,%s->%s" % (source_notation, target_notation, + product_notation) + attn_scores_rank = len(product_notation) + combine_equation = "%s,%s->%s" % (product_notation, source_notation, + target_notation) + return dot_product_equation, combine_equation, attn_scores_rank + + +def _build_proj_equation(free_dims, bound_dims, output_dims): + """Builds an einsum equation for projections inside multi-head attention.""" + input_str = "" + kernel_str = "" + output_str = "" + bias_axes = "" + letter_offset = 0 + for i in range(free_dims): + char = _CHR_IDX[i + letter_offset] + input_str += char + output_str += char + + letter_offset += free_dims + for i in range(bound_dims): + char = _CHR_IDX[i + letter_offset] + input_str += char + kernel_str += char + + letter_offset += bound_dims + for i in range(output_dims): + char = _CHR_IDX[i + letter_offset] + kernel_str += char + output_str += char + bias_axes += char + equation = "%s,%s->%s" % (input_str, kernel_str, output_str) + + return equation, bias_axes, len(output_str) + + +def _get_output_shape(output_rank, known_last_dims): + return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) + + +@keras_export("keras.layers.MultiHeadAttention") +class MultiHeadAttention(Layer): + """MultiHeadAttention layer. + + This is an implementation of multi-headed attention based on "Attention + is all you Need". If `query`, `key,` `value` are the same, then + this is self-attention. Each timestep in `query` attends to the + corresponding sequence in `key`, and returns a fixed-width vector. + + This layer first projects `query`, `key` and `value`. These are + (effectively) a list of tensors of length `num_attention_heads`, where the + corresponding shapes are [batch_size, , key_dim], + [batch_size, , key_dim], + [batch_size, , value_dim]. + + Then, the query and key tensors are dot-producted and scaled. These are + softmaxed to obtain attention probabilities. The value tensors are then + interpolated by these probabilities, then concatenated back to a single + tensor. + + Finally, the result tensor with the last dimension as value_dim can take an + linear projection and return. + + Examples: + + Performs 1D cross-attention over two sequence inputs with an attention mask. + Returns the additional attention weights over heads. + + >>> layer = MultiHeadAttention(num_heads=2, key_dim=2) + >>> target = tf.keras.Input(shape=[8, 16]) + >>> source = tf.keras.Input(shape=[4, 16]) + >>> output_tensor, weights = layer(target, source, + ... return_attention_scores=True) + >>> print(output_tensor.shape) + (None, 8, 16) + >>> print(weights.shape) + (None, 2, 8, 4) + + Performs 2D self-attention over a 5D input tensor on axes 2 and 3. + + >>> layer = MultiHeadAttention(num_heads=2, key_dim=2, attention_axes=(2, 3)) + >>> input_tensor = tf.keras.Input(shape=[5, 3, 4, 16]) + >>> output_tensor = layer(input_tensor, input_tensor) + >>> print(output_tensor.shape) + (None, 5, 3, 4, 16) + + Arguments: + num_heads: Number of attention heads. + key_dim: Size of each attention head for query and key. + value_dim: Size of each attention head for value. + dropout: Dropout probability. + use_bias: Boolean, whether the dense layers use bias vectors/matrices. + output_shape: The expected shape of an output tensor, besides the batch and + sequence dims. If not specified, projects back to the key feature dim. + attention_axes: axes over which the attention is applied. `None` means + attention over all axes, but batch, heads, and features. + kernel_initializer: Initializer for dense layer kernels. + bias_initializer: Initializer for dense layer biases. + kernel_regularizer: Regularizer for dense layer kernels. + bias_regularizer: Regularizer for dense layer biases. + activity_regularizer: Regularizer for dense layer activity. + kernel_constraint: Constraint for dense layer kernels. + bias_constraint: Constraint for dense layer kernels. + + Call arguments: + query: Query `Tensor` of shape `[B, T, dim]`. + value: Value `Tensor` of shape `[B, S, dim]`. + key: Optional key `Tensor` of shape `[B, S, dim]`. If not given, will use + `value` for both `key` and `value`, which is the most common case. + attention_mask: a boolean mask of shape `[B, T, S]`, that prevents attention + to certain positions. + return_attention_scores: A boolean to indicate whether the output should + be attention output if True, or (attention_output, attention_scores) if + False. Defaults to False. + + Returns: + attention_output: The result of the computation, of shape [B, T, E], + where `T` is for target sequence shapes and `E` is the query input last + dimension if `output_shape` is `None`. Otherwise, the multi-head outputs + are project to the shape specified by `output_shape`. + attention_scores: [Optional] multi-head attention coeffients over + attention axes. + """ + + def __init__(self, + num_heads, + key_dim, + value_dim=None, + dropout=0.0, + use_bias=True, + output_shape=None, + attention_axes=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + super(MultiHeadAttention, self).__init__(**kwargs) + self._num_heads = num_heads + self._key_dim = key_dim + self._value_dim = value_dim if value_dim else key_dim + self._dropout = dropout + self._use_bias = use_bias + self._output_shape = output_shape + self._kernel_initializer = initializers.get(kernel_initializer) + self._bias_initializer = initializers.get(bias_initializer) + self._kernel_regularizer = regularizers.get(kernel_regularizer) + self._bias_regularizer = regularizers.get(bias_regularizer) + self._kernel_constraint = constraints.get(kernel_constraint) + self._bias_constraint = constraints.get(bias_constraint) + if attention_axes is not None and not isinstance(attention_axes, + collections.abc.Sized): + self._attention_axes = (attention_axes,) + else: + self._attention_axes = attention_axes + self._built_from_signature = False + + def get_config(self): + config = { + "num_heads": + self._num_heads, + "key_dim": + self._key_dim, + "value_dim": + self._value_dim, + "dropout": + self._dropout, + "use_bias": + self._use_bias, + "output_shape": + self._output_shape, + "attention_axes": + self._attention_axes, + "kernel_initializer": + initializers.serialize(self._kernel_initializer), + "bias_initializer": + initializers.serialize(self._bias_initializer), + "kernel_regularizer": + regularizers.serialize(self._kernel_regularizer), + "bias_regularizer": + regularizers.serialize(self._bias_regularizer), + "activity_regularizer": + regularizers.serialize(self._activity_regularizer), + "kernel_constraint": + constraints.serialize(self._kernel_constraint), + "bias_constraint": + constraints.serialize(self._bias_constraint) + } + base_config = super(MultiHeadAttention, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + def _build_from_signature(self, query, value, key=None): + """Builds layers and variables. + + Once the method is called, self._built_from_signature will be set to True. + + Args: + query: query tensor or TensorShape. + value: value tensor or TensorShape. + key: key tensor or TensorShape. + """ + self._built_from_signature = True + if hasattr(query, "shape"): + query_shape = tensor_shape.TensorShape(query.shape) + else: + query_shape = query + if hasattr(value, "shape"): + value_shape = tensor_shape.TensorShape(value.shape) + else: + value_shape = value + if key is None: + key_shape = value_shape + elif hasattr(key, "shape"): + key_shape = tensor_shape.TensorShape(key.shape) + else: + key_shape = key + + common_kwargs = dict( + kernel_initializer=self._kernel_initializer, + bias_initializer=self._bias_initializer, + kernel_regularizer=self._kernel_regularizer, + bias_regularizer=self._bias_regularizer, + activity_regularizer=self._activity_regularizer, + kernel_constraint=self._kernel_constraint, + bias_constraint=self._bias_constraint) + # Any setup work performed only once should happen in an `init_scope` + # to avoid creating symbolic Tensors that will later pollute any eager + # operations. + with tf_utils.maybe_init_scope(self): + free_dims = query_shape.rank - 1 + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims, bound_dims=1, output_dims=2) + self._query_dense = einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, + [self._num_heads, self._key_dim]), + bias_axes=bias_axes if self._use_bias else None, + name="query", + **common_kwargs) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + key_shape.rank - 1, bound_dims=1, output_dims=2) + self._key_dense = einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, + [self._num_heads, self._key_dim]), + bias_axes=bias_axes if self._use_bias else None, + name="key", + **common_kwargs) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + value_shape.rank - 1, bound_dims=1, output_dims=2) + self._value_dense = einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, + [self._num_heads, self._value_dim]), + bias_axes=bias_axes if self._use_bias else None, + name="value", + **common_kwargs) + + # Builds the attention computations for multi-head dot product attention. + # These computations could be wrapped into the keras attention layer once + # it support mult-head einsum computations. + self._build_attention(output_rank) + if self._output_shape: + if not isinstance(self._output_shape, collections.abc.Sized): + output_shape = [self._output_shape] + else: + output_shape = self._output_shape + else: + output_shape = [query_shape[-1]] + einsum_equation, bias_axes, output_rank = _build_proj_equation( + free_dims, bound_dims=2, output_dims=len(output_shape)) + self._output_dense = einsum_dense.EinsumDense( + einsum_equation, + output_shape=_get_output_shape(output_rank - 1, output_shape), + bias_axes=bias_axes if self._use_bias else None, + name="attention_output", + **common_kwargs) + + def _build_attention(self, rank): + """Builds multi-head dot-product attention computations. + + This function builds attributes necessary for `_compute_attention` to + costomize attention computation to replace the default dot-product + attention. + + Args: + rank: the rank of query, key, value tensors. + """ + if self._attention_axes is None: + self._attention_axes = tuple(range(1, rank - 2)) + else: + self._attention_axes = tuple(self._attention_axes) + self._dot_product_equation, self._combine_equation, attn_scores_rank = ( + _build_attention_equation(rank, attn_axes=self._attention_axes)) + norm_axes = tuple( + range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) + self._masked_softmax = advanced_activations.Softmax(axis=norm_axes) + self._dropout_layer = core.Dropout(rate=self._dropout) + + def _compute_attention(self, query, key, value, attention_mask=None): + """Applies Dot-product attention with query, key, value tensors. + + This function defines the computation inside `call` with projected + multi-head Q, K, V inputs. Users can override this function for customized + attention implementation. + + Args: + query: Projected query `Tensor` of shape `[B, T, N, key_dim]`. + key: Projected key `Tensor` of shape `[B, T, N, key_dim]`. + value: Projected value `Tensor` of shape `[B, T, N, value_dim]`. + attention_mask: a boolean mask of shape `[B, T, S]`, that prevents + attention to certain positions. + + Returns: + attention_output: Multi-headed outputs of attention computation. + attention_scores: Multi-headed attention weights. + """ + # Note: Applying scalar multiply at the smaller end of einsum improves + # XLA performance, but may introduce slight numeric differences in + # the Transformer attention head. + query = math_ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim))) + + # Take the dot product between "query" and "key" to get the raw + # attention scores. + attention_scores = special_math_ops.einsum(self._dot_product_equation, key, + query) + + # Normalize the attention scores to probabilities. + # `attention_scores` = [B, N, T, S] + if attention_mask is not None: + # The expand dim happens starting from the `num_heads` dimension, + # (, num_heads, ) + mask_expansion_axes = [-len(self._attention_axes) * 2 - 1] + for _ in range(len(attention_scores.shape) - len(attention_mask.shape)): + attention_mask = array_ops.expand_dims( + attention_mask, axis=mask_expansion_axes) + attention_scores = self._masked_softmax(attention_scores, attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_scores_dropout = self._dropout_layer(attention_scores) + + # `context_layer` = [B, T, N, H] + attention_output = special_math_ops.einsum(self._combine_equation, + attention_scores_dropout, value) + return attention_output, attention_scores + + def call(self, query, value, key=None, attention_mask=None, + return_attention_scores=False): + if not self._built_from_signature: + self._build_from_signature(query=query, value=value, key=key) + if key is None: + key = value + + # N = `num_attention_heads` + # H = `size_per_head` + # `query` = [B, T, N ,H] + query = self._query_dense(query) + + # `key` = [B, S, N, H] + key = self._key_dense(key) + + # `value` = [B, S, N, H] + value = self._value_dense(value) + + attention_output, attention_scores = self._compute_attention( + query, key, value, attention_mask) + attention_output = self._output_dense(attention_output) + + if return_attention_scores: + return attention_output, attention_scores + return attention_output + diff --git a/tensorflow/python/keras/layers/multi_head_attention_test.py b/tensorflow/python/keras/layers/multi_head_attention_test.py new file mode 100644 index 00000000000..7702a2898c4 --- /dev/null +++ b/tensorflow/python/keras/layers/multi_head_attention_test.py @@ -0,0 +1,230 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the attention layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import numpy as np + +from tensorflow.python import keras +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras.layers import multi_head_attention +from tensorflow.python.platform import test + + +# This decorator runs the test in V1, V2-Eager, and V2-Functional mode. It +# guarantees forward compatibility of this code for the V2 switchover. +@keras_parameterized.run_all_keras_modes +class MultiHeadAttentionTest(keras_parameterized.TestCase): + + @parameterized.named_parameters( + ("key_value_same_proj", None, None, [40, 80]), + ("key_value_different_proj", 32, 60, [40, 60]), + ) + def test_non_masked_attention(self, value_dim, output_shape, output_dims): + """Test that the attention layer can be created without a mask tensor.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=12, + key_dim=64, + value_dim=value_dim, + output_shape=output_shape) + # Create a 3-dimensional input (the first dimension is implicit). + query = keras.Input(shape=(40, 80)) + value = keras.Input(shape=(20, 80)) + output = test_layer(query=query, value=value) + self.assertEqual(output.shape.as_list(), [None] + output_dims) + + def test_non_masked_self_attention(self): + """Test with one input (self-attenntion) and no mask tensor.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=12, key_dim=64) + # Create a 3-dimensional input (the first dimension is implicit). + query = keras.Input(shape=(40, 80)) + output = test_layer(query, query) + self.assertEqual(output.shape.as_list(), [None, 40, 80]) + + def test_attention_scores(self): + """Test attention outputs with coefficients.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=12, key_dim=64) + # Create a 3-dimensional input (the first dimension is implicit). + query = keras.Input(shape=(40, 80)) + output, coef = test_layer(query, query, return_attention_scores=True) + self.assertEqual(output.shape.as_list(), [None, 40, 80]) + self.assertEqual(coef.shape.as_list(), [None, 12, 40, 40]) + + def test_attention_scores_with_values(self): + """Test attention outputs with coefficients.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=12, key_dim=64) + # Create a 3-dimensional input (the first dimension is implicit). + query = keras.Input(shape=(40, 80)) + value = keras.Input(shape=(60, 80)) + output, coef = test_layer(query, value, return_attention_scores=True) + self.assertEqual(output.shape.as_list(), [None, 40, 80]) + self.assertEqual(coef.shape.as_list(), [None, 12, 40, 60]) + + @parameterized.named_parameters(("with_bias", True), ("no_bias", False)) + def test_masked_attention(self, use_bias): + """Test with a mask tensor.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=2, key_dim=2, use_bias=use_bias) + # Create a 3-dimensional input (the first dimension is implicit). + batch_size = 3 + query = keras.Input(shape=(4, 8)) + value = keras.Input(shape=(2, 8)) + mask_tensor = keras.Input(shape=(4, 2)) + output = test_layer(query=query, value=value, attention_mask=mask_tensor) + + # Create a model containing the test layer. + model = keras.Model([query, value, mask_tensor], output) + + # Generate data for the input (non-mask) tensors. + from_data = 10 * np.random.random_sample((batch_size, 4, 8)) + to_data = 10 * np.random.random_sample((batch_size, 2, 8)) + + # Invoke the data with a random set of mask data. This should mask at least + # one element. + mask_data = np.random.randint(2, size=(batch_size, 4, 2)) + masked_output_data = model.predict([from_data, to_data, mask_data]) + + # Invoke the same data, but with a null mask (where no elements are masked). + null_mask_data = np.ones((batch_size, 4, 2)) + unmasked_output_data = model.predict([from_data, to_data, null_mask_data]) + + # Because one data is masked and one is not, the outputs should not be the + # same. + self.assertNotAllClose(masked_output_data, unmasked_output_data) + + # Tests the layer with three inputs: Q, K, V. + key = keras.Input(shape=(2, 8)) + output = test_layer(query, value=value, key=key, attention_mask=mask_tensor) + model = keras.Model([query, value, key, mask_tensor], output) + + masked_output_data = model.predict([from_data, to_data, to_data, mask_data]) + unmasked_output_data = model.predict( + [from_data, to_data, to_data, null_mask_data]) + # Because one data is masked and one is not, the outputs should not be the + # same. + self.assertNotAllClose(masked_output_data, unmasked_output_data) + + if use_bias: + self.assertLen(test_layer._query_dense.trainable_variables, 2) + self.assertLen(test_layer._output_dense.trainable_variables, 2) + else: + self.assertLen(test_layer._query_dense.trainable_variables, 1) + self.assertLen(test_layer._output_dense.trainable_variables, 1) + + def test_initializer(self): + """Test with a specified initializer.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=12, + key_dim=64, + kernel_initializer=keras.initializers.TruncatedNormal(stddev=0.02)) + # Create a 3-dimensional input (the first dimension is implicit). + query = keras.Input(shape=(40, 80)) + output = test_layer(query, query) + self.assertEqual(output.shape.as_list(), [None, 40, 80]) + + def test_masked_attention_with_scores(self): + """Test with a mask tensor.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=2, key_dim=2) + # Create a 3-dimensional input (the first dimension is implicit). + batch_size = 3 + query = keras.Input(shape=(4, 8)) + value = keras.Input(shape=(2, 8)) + mask_tensor = keras.Input(shape=(4, 2)) + output = test_layer(query=query, value=value, attention_mask=mask_tensor) + + # Create a model containing the test layer. + model = keras.Model([query, value, mask_tensor], output) + + # Generate data for the input (non-mask) tensors. + from_data = 10 * np.random.random_sample((batch_size, 4, 8)) + to_data = 10 * np.random.random_sample((batch_size, 2, 8)) + + # Invoke the data with a random set of mask data. This should mask at least + # one element. + mask_data = np.random.randint(2, size=(batch_size, 4, 2)) + masked_output_data = model.predict([from_data, to_data, mask_data]) + + # Invoke the same data, but with a null mask (where no elements are masked). + null_mask_data = np.ones((batch_size, 4, 2)) + unmasked_output_data = model.predict([from_data, to_data, null_mask_data]) + + # Because one data is masked and one is not, the outputs should not be the + # same. + self.assertNotAllClose(masked_output_data, unmasked_output_data) + + # Create a model containing attention scores. + output, scores = test_layer( + query=query, value=value, attention_mask=mask_tensor, + return_attention_scores=True) + model = keras.Model([query, value, mask_tensor], [output, scores]) + masked_output_data_score, masked_score = model.predict( + [from_data, to_data, mask_data]) + unmasked_output_data_score, unmasked_score = model.predict( + [from_data, to_data, null_mask_data]) + self.assertNotAllClose(masked_output_data_score, unmasked_output_data_score) + self.assertAllClose(masked_output_data, masked_output_data_score) + self.assertAllClose(unmasked_output_data, unmasked_output_data_score) + self.assertNotAllClose(masked_score, unmasked_score) + + @parameterized.named_parameters( + ("4d_inputs_1freebatch_mask2", [3, 4], [3, 2], [4, 2], + (2,)), ("4d_inputs_1freebatch_mask3", [3, 4], [3, 2], [3, 4, 2], (2,)), + ("4d_inputs_1freebatch_mask4", [3, 4], [3, 2], [3, 2, 4, 2], + (2,)), ("4D_inputs_2D_attention", [3, 4], [3, 2], [3, 4, 3, 2], (1, 2)), + ("5D_inputs_2D_attention", [5, 3, 4], [5, 3, 2], [3, 4, 3, 2], (2, 3)), + ("5D_inputs_2D_attention_fullmask", [5, 3, 4], [5, 3, 2], [5, 3, 4, 3, 2], + (2, 3))) + def test_high_dim_attention(self, q_dims, v_dims, mask_dims, attention_axes): + """Test with a mask tensor.""" + test_layer = multi_head_attention.MultiHeadAttention( + num_heads=2, key_dim=2, attention_axes=attention_axes) + batch_size, hidden_size = 3, 8 + # Generate data for the input (non-mask) tensors. + query_shape = [batch_size] + q_dims + [hidden_size] + value_shape = [batch_size] + v_dims + [hidden_size] + mask_shape = [batch_size] + mask_dims + query = 10 * np.random.random_sample(query_shape) + value = 10 * np.random.random_sample(value_shape) + + # Invoke the data with a random set of mask data. This should mask at least + # one element. + mask_data = np.random.randint(2, size=mask_shape).astype("bool") + # Invoke the same data, but with a null mask (where no elements are masked). + null_mask_data = np.ones(mask_shape) + # Because one data is masked and one is not, the outputs should not be the + # same. + query_tensor = keras.Input(query_shape[1:], name="query") + value_tensor = keras.Input(value_shape[1:], name="value") + mask_tensor = keras.Input(mask_shape[1:], name="mask") + output = test_layer(query=query_tensor, value=value_tensor, + attention_mask=mask_tensor) + model = keras.Model([query_tensor, value_tensor, mask_tensor], output) + + self.assertNotAllClose( + model.predict([query, value, mask_data]), + model.predict([query, value, null_mask_data])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index d990f2075c8..d1fa4c19e92 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -37,6 +37,7 @@ from tensorflow.python.keras.layers import einsum_dense from tensorflow.python.keras.layers import embeddings from tensorflow.python.keras.layers import local from tensorflow.python.keras.layers import merge +from tensorflow.python.keras.layers import multi_head_attention from tensorflow.python.keras.layers import noise from tensorflow.python.keras.layers import normalization from tensorflow.python.keras.layers import normalization_v2 @@ -70,7 +71,8 @@ ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional, pooling, image_preprocessing, preprocessing_integer_lookup_v1, preprocessing_normalization_v1, preprocessing_string_lookup_v1, preprocessing_text_vectorization_v1, recurrent, wrappers, - hashing, category_crossing, category_encoding_v1, discretization) + hashing, category_crossing, category_encoding_v1, discretization, + multi_head_attention) ALL_V2_MODULES = (rnn_cell_wrapper_v2, normalization_v2, recurrent_v2, preprocessing_integer_lookup, preprocessing_normalization, preprocessing_string_lookup, preprocessing_text_vectorization, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt new file mode 100644 index 00000000000..070ee20ab30 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multi-head-attention.pbtxt @@ -0,0 +1,222 @@ +path: "tensorflow.keras.layers.MultiHeadAttention" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "supports_masking" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt index db272bdf782..97e4b91bfa3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt @@ -149,7 +149,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt index ea139297807..35714912b04 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt @@ -312,6 +312,10 @@ tf_module { name: "Minimum" mtype: "" } + member { + name: "MultiHeadAttention" + mtype: "" + } member { name: "Multiply" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt new file mode 100644 index 00000000000..070ee20ab30 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multi-head-attention.pbtxt @@ -0,0 +1,222 @@ +path: "tensorflow.keras.layers.MultiHeadAttention" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "supports_masking" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_heads\', \'key_dim\', \'value_dim\', \'dropout\', \'use_bias\', \'output_shape\', \'attention_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'0.0\', \'True\', \'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'query\', \'value\', \'key\', \'attention_mask\', \'return_attention_scores\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt index db272bdf782..97e4b91bfa3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt @@ -149,7 +149,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt index 3706919341d..078c7ec8a67 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.pbtxt @@ -304,6 +304,10 @@ tf_module { name: "Minimum" mtype: "" } + member { + name: "MultiHeadAttention" + mtype: "" + } member { name: "Multiply" mtype: ""