diff --git a/tensorflow/python/keras/layers/BUILD b/tensorflow/python/keras/layers/BUILD index d48cb71cb5c..fde014d5834 100644 --- a/tensorflow/python/keras/layers/BUILD +++ b/tensorflow/python/keras/layers/BUILD @@ -34,6 +34,7 @@ py_library( ":core", ":cudnn_recurrent", ":dense_attention", + ":einsum_dense", ":embeddings", ":kernelized", ":local", @@ -187,6 +188,22 @@ py_library( ], ) +py_library( + name = "einsum_dense", + srcs = ["einsum_dense.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:special_math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", + "//tensorflow/python/keras:activations", + "//tensorflow/python/keras:base_layer", + "//tensorflow/python/keras:constraints", + "//tensorflow/python/keras:initializers", + "//tensorflow/python/keras:regularizers", + ], +) + py_library( name = "embeddings", srcs = ["embeddings.py"], @@ -581,6 +598,18 @@ cuda_py_test( ], ) +tf_py_test( + name = "einsum_dense_test", + srcs = ["einsum_dense_test.py"], + python_version = "PY3", + deps = [ + ":einsum_dense", + "//tensorflow/python:client_testlib", + "//tensorflow/python/keras", + "@absl_py//absl/testing:parameterized", + ], +) + tf_py_test( name = "local_test", size = "medium", diff --git a/tensorflow/python/keras/layers/__init__.py b/tensorflow/python/keras/layers/__init__.py index c4388ec94fe..192c6a4afc8 100644 --- a/tensorflow/python/keras/layers/__init__.py +++ b/tensorflow/python/keras/layers/__init__.py @@ -119,6 +119,9 @@ from tensorflow.python.keras.layers.dense_attention import Attention # Embedding layers. from tensorflow.python.keras.layers.embeddings import Embedding +# Einsum-based dense layer. +from tensorflow.python.keras.layers.einsum_dense import EinsumDense + # Locally-connected layers. from tensorflow.python.keras.layers.local import LocallyConnected1D from tensorflow.python.keras.layers.local import LocallyConnected2D diff --git a/tensorflow/python/keras/layers/einsum_dense.py b/tensorflow/python/keras/layers/einsum_dense.py new file mode 100644 index 00000000000..7b5bd085703 --- /dev/null +++ b/tensorflow/python/keras/layers/einsum_dense.py @@ -0,0 +1,337 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Keras-based einsum dense layer.""" +# pylint: disable=g-classes-have-attributes +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import activations +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.base_layer import Layer +from tensorflow.python.ops import special_math_ops +from tensorflow.python.util.tf_export import keras_export + + +@keras_export("keras.layers.experimental.EinsumDense") +class EinsumDense(Layer): + """A layer that uses tf.einsum as the backing computation. + + This layer can perform einsum calculations of arbitrary dimensionality. + + Arguments: + equation: An equation describing the einsum to perform. This equation must + be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or + `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis + expression sequence. + output_shape: The expected shape of the output tensor (excluding the batch + dimension and any dimensions represented by ellipses). You can specify + None for any dimension that is unknown or can be inferred from the input + shape. + activation: Activation function to use. If you don't specify anything, no + activation is applied (that is, a "linear" activation: `a(x) = x`). + bias_axes: A string containing the output dimension(s) to apply a bias to. + Each character in the `bias_axes` string should correspond to a character + in the output portion of the `equation` string. + kernel_initializer: Initializer for the `kernel` weights matrix. + bias_initializer: Initializer for the bias vector. + kernel_regularizer: Regularizer function applied to the `kernel` weights + matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to the output of the + layer (its "activation"). + kernel_constraint: Constraint function applied to the `kernel` weights + matrix. + bias_constraint: Constraint function applied to the bias vector. + + Examples: + + **Biased dense layer with einsums** + + This example shows how to instantiate a standard Keras dense layer using + einsum operations. This example is equivalent to + `tf.keras.layers.Dense(64, use_bias=True)`. + + >>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c") + >>> input_tensor = tf.keras.Input(shape=[32]) + >>> output_tensor = layer(input_tensor) + >>> output_tensor + + + **Applying a dense layer to a sequence** + + This example shows how to instantiate a layer that applies the same dense + operation to every element in a sequence. Here, the `output_shape` has two + values (since there are two non-batch dimensions in the output); the first + dimension in the `output_shape` is `None`, because the sequence dimension `b` + has an unknown shape. + + >>> layer = EinsumDense("abc,cd->abd", + ... output_shape=(None, 64), + ... 
bias_axes="d") + >>> input_tensor = tf.keras.Input(shape=[32, 128]) + >>> output_tensor = layer(input_tensor) + >>> output_tensor + + + **Applying a dense layer to a sequence using ellipses** + + This example shows how to instantiate a layer that applies the same dense + operation to every element in a sequence, but uses the ellipsis notation + instead of specifying the batch and sequence dimensions. + + Because we are using ellipsis notation and have specified only one axis, the + output_shape arg is a single value. When instantiated in this way, the layer + can handle any number of sequence dimensions - including the case where no + sequence dimension exists. + + >>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y") + >>> input_tensor = tf.keras.Input(shape=[32, 128]) + >>> output_tensor = layer(input_tensor) + >>> output_tensor + + """ + + def __init__(self, + equation, + output_shape, + activation=None, + bias_axes=None, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs): + super(EinsumDense, self).__init__(**kwargs) + self.equation = equation + if isinstance(output_shape, int): + self.partial_output_shape = [output_shape] + else: + self.partial_output_shape = list(output_shape) + self.bias_axes = bias_axes + self.activation = activations.get(activation) + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.kernel_constraint = constraints.get(kernel_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + shape_data = _analyze_einsum_string(self.equation, + self.bias_axes, + input_shape, + self.partial_output_shape) + kernel_shape, bias_shape, self.full_output_shape = shape_data + self.kernel = self.add_weight( + "kernel", + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + dtype=self.dtype, + trainable=True) + + if bias_shape is not None: + self.bias = self.add_weight( + "bias", + shape=bias_shape, + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint, + dtype=self.dtype, + trainable=True) + else: + self.bias = None + super(EinsumDense, self).build(input_shape) + + def compute_output_shape(self, _): + return tensor_shape.TensorShape(self.full_output_shape) + + def get_config(self): + config = { + "output_shape": + self.partial_output_shape, + "equation": + self.equation, + "activation": + activations.serialize(self.activation), + "bias_axes": + self.bias_axes, + "kernel_initializer": + initializers.serialize(self.kernel_initializer), + "bias_initializer": + initializers.serialize(self.bias_initializer), + "kernel_regularizer": + regularizers.serialize(self.kernel_regularizer), + "bias_regularizer": + regularizers.serialize(self.bias_regularizer), + "activity_regularizer": + regularizers.serialize(self.activity_regularizer), + "kernel_constraint": + constraints.serialize(self.kernel_constraint), + "bias_constraint": + constraints.serialize(self.bias_constraint), + } + base_config = super(EinsumDense, self).get_config() + return dict(list(base_config.items()) + 
list(config.items())) + + def call(self, inputs): + ret = special_math_ops.einsum(self.equation, inputs, self.kernel) + if self.bias is not None: + ret += self.bias + if self.activation is not None: + ret = self.activation(ret) + return ret + + +def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape): + """Analyzes an einsum string to determine the required weight shape.""" + + dot_replaced_string = re.sub(r"\.\.\.", "0", equation) + + # This is the case where no ellipses are present in the string. + split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", + dot_replaced_string) + if split_string: + return _analyze_split_string(split_string, bias_axes, input_shape, + output_shape) + + # This is the case where ellipses are present on the left. + split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", + dot_replaced_string) + if split_string: + return _analyze_split_string( + split_string, bias_axes, input_shape, output_shape, left_elided=True) + + # This is the case where ellipses are present on the right. + split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", + dot_replaced_string) + if split_string: + return _analyze_split_string(split_string, bias_axes, input_shape, + output_shape) + + raise ValueError( + "Invalid einsum equation '%s'. Equations must be in the form " + "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation) + + +def _analyze_split_string(split_string, + bias_axes, + input_shape, + output_shape, + left_elided=False): + """Analyzes a pre-split einsum string to find the weight shape.""" + input_spec = split_string.group(1) + weight_spec = split_string.group(2) + output_spec = split_string.group(3) + elided = len(input_shape) - len(input_spec) + + if isinstance(output_shape, int): + output_shape = [output_shape] + else: + output_shape = list(output_shape) + + output_shape.insert(0, input_shape[0]) + + if elided > 0 and left_elided: + for i in range(1, elided): + # We already inserted the 0th input dimension at dim 0, so we need to + # start at location 1 here. + output_shape.insert(1, input_shape[i]) + elif elided > 0 and not left_elided: + for i in range(len(input_shape) - elided, len(input_shape)): + output_shape.append(input_shape[i]) + + if left_elided: + # If we have beginning dimensions elided, we need to use negative indexing + # to determine where in the input dimension our values are. + input_dim_map = { + dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec) + } + # Because we've constructed the full output shape already, we don't need + # to do negative indexing. + output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)} + else: + input_dim_map = {dim: i for i, dim in enumerate(input_spec)} + output_dim_map = {dim: i for i, dim in enumerate(output_spec)} + + for i, dim in enumerate(input_spec): + input_shape_at_dim = input_shape[i] + if dim in output_dim_map: + output_shape_at_dim = output_shape[output_dim_map[dim]] + if (output_shape_at_dim is not None and + output_shape_at_dim != input_shape_at_dim): + raise ValueError( + "Input shape and output shape do not match at shared " + "dimension '%s'. Input shape is %s, and output shape " + "is %s." 
% + (dim, input_shape_at_dim, output_shape[output_dim_map[dim]])) + + for dim in output_spec: + if dim not in input_spec and dim not in weight_spec: + raise ValueError("Dimension '%s' was specified in the output '%s' but " + "has no corresponding dim in the input spec '%s' or " + "weight spec '%s'." % (dim, output_spec, input_spec, + weight_spec)) + + weight_shape = [] + for dim in weight_spec: + if dim in input_dim_map: + weight_shape.append(input_shape[input_dim_map[dim]]) + elif dim in output_dim_map: + weight_shape.append(output_shape[output_dim_map[dim]]) + else: + raise ValueError("Weight dimension '%s' did not have a match in either " + "the input spec '%s' or the output spec '%s'. For this " + "layer, the weight must be fully specified." % + (dim, input_spec, output_spec)) + + if bias_axes is not None: + num_left_elided = elided if left_elided else 0 + idx_map = { + char: output_shape[i + num_left_elided] + for i, char in enumerate(output_spec) + } + + for char in bias_axes: + if char not in output_spec: + raise ValueError("Bias dimension '%s' was requested, but is not a part " + "of the output specification '%s'" % + (char, output_spec)) + + first_bias_location = min([output_spec.find(char) for char in bias_axes]) + bias_output_spec = output_spec[first_bias_location:] + + bias_shape = [ + idx_map[char] if char in bias_axes else 1 for char in bias_output_spec + ] + + if not left_elided: + for _ in range(elided): + bias_shape.append(1) + else: + bias_shape = None + + return weight_shape, bias_shape, output_shape diff --git a/tensorflow/python/keras/layers/einsum_dense_test.py b/tensorflow/python/keras/layers/einsum_dense_test.py new file mode 100644 index 00000000000..e9ae7271130 --- /dev/null +++ b/tensorflow/python/keras/layers/einsum_dense_test.py @@ -0,0 +1,315 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for Keras-based einsum dense layer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +import numpy as np + +from tensorflow.python import keras + +from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.layers import einsum_dense +from tensorflow.python.platform import test + + +@keras_parameterized.run_all_keras_modes +@parameterized.named_parameters( + { + "testcase_name": "_1d_end_weight", + "equation": "ab,b->a", + "bias_axes": None, + "input_shape": (None, 32), + "output_shape": [], + "expected_weight_shape": [32], + "expected_bias_shape": None, + "expected_output_shape": (None,) + }, { + "testcase_name": "_2d_middle_weight", + "equation": "ab,bc->ac", + "bias_axes": None, + "input_shape": (None, 32), + "output_shape": (64), + "expected_weight_shape": [32, 64], + "expected_bias_shape": None, + "expected_output_shape": (None, 64) + }, { + "testcase_name": "_3d_bert", + "equation": "abc,cde->abde", + "bias_axes": None, + "input_shape": (None, 1, 2), + "output_shape": (1, 3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": None, + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_3d_3_bias", + "equation": "abc,cde->abde", + "bias_axes": "e", + "input_shape": (None, 1, 2), + "output_shape": (1, 3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": [4], + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_3d_2_bias", + "equation": "abc,cde->abde", + "bias_axes": "d", + "input_shape": (None, 1, 2), + "output_shape": (1, 3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": [3, 1], + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_3d_1_3_bias", + "equation": "abc,cde->abde", + "bias_axes": "be", + "input_shape": (None, 7, 2), + "output_shape": (7, 3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": [7, 1, 4], + "expected_output_shape": (None, 7, 3, 4) + }, { + "testcase_name": "_3d_bert_projection", + "equation": "BFNH,NHD->BFD", + "bias_axes": None, + "input_shape": (None, 1, 2, 3), + "output_shape": (1, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": None, + "expected_output_shape": (None, 1, 4) + }, { + "testcase_name": "_2d_bert", + "equation": "abc,cd->abd", + "bias_axes": None, + "input_shape": (None, 1, 2), + "output_shape": (1, 4), + "expected_weight_shape": [2, 4], + "expected_bias_shape": None, + "expected_output_shape": (None, 1, 4) + }, { + "testcase_name": "_embedding_1d", + "equation": "i,d->id", + "bias_axes": None, + "input_shape": (None,), + "output_shape": (2), + "expected_weight_shape": [2], + "expected_bias_shape": None, + "expected_output_shape": (None, 2) + }, { + "testcase_name": "_xlnet_lm", + "equation": "ibd,nd->ibn", + "bias_axes": None, + "input_shape": (None, None, 1), + "output_shape": (None, 2), + "expected_weight_shape": [2, 1], + "expected_bias_shape": None, + "expected_output_shape": (None, None, 2) + }, { + "testcase_name": "_2d_precast", + "equation": "...b,bc->...c", + "bias_axes": None, + "input_shape": (None, 32), + "output_shape": (64), + "expected_weight_shape": [32, 64], + "expected_bias_shape": None, + "expected_output_shape": (None, 64) + }, { + "testcase_name": 
"_2d_precast_multiple_elided_dims", + "equation": "...b,bc->...c", + "bias_axes": None, + "input_shape": (None, None, 32), + "output_shape": (64), + "expected_weight_shape": [32, 64], + "expected_bias_shape": None, + "expected_output_shape": (None, None, 64) + }, { + "testcase_name": "_3d_precast", + "equation": "...c,cde->...de", + "bias_axes": None, + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": None, + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_3d_precast_3_bias", + "equation": "...c,cde->...de", + "bias_axes": "e", + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": [4], + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_3d_precast_2_bias", + "equation": "...c,cde->...de", + "bias_axes": "d", + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": [3, 1], + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_3d_precast_2_3_bias", + "equation": "...c,cde->...de", + "bias_axes": "de", + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [2, 3, 4], + "expected_bias_shape": [3, 4], + "expected_output_shape": (None, 1, 3, 4) + }, { + "testcase_name": "_2d_postcast", + "equation": "bc...,cd->bd...", + "bias_axes": None, + "input_shape": (None, 1, 2, 3), + "output_shape": (4), + "expected_weight_shape": [1, 4], + "expected_bias_shape": None, + "expected_output_shape": (None, 4, 2, 3) + }, { + "testcase_name": "_3d_postcast", + "equation": "bc...,cde->bde...", + "bias_axes": None, + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [1, 3, 4], + "expected_bias_shape": None, + "expected_output_shape": (None, 3, 4, 2) + }, { + "testcase_name": "_3d_postcast_1_bias", + "equation": "bc...,cde->bde...", + "bias_axes": "d", + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [1, 3, 4], + "expected_bias_shape": [3, 1, 1], + "expected_output_shape": (None, 3, 4, 2) + }, { + "testcase_name": "_3d_postcast_2_bias", + "equation": "bc...,cde->bde...", + "bias_axes": "e", + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [1, 3, 4], + "expected_bias_shape": [4, 1], + "expected_output_shape": (None, 3, 4, 2) + }, { + "testcase_name": "_3d_postcast_1_2_bias", + "equation": "bc...,cde->bde...", + "bias_axes": "de", + "input_shape": (None, 1, 2), + "output_shape": (3, 4), + "expected_weight_shape": [1, 3, 4], + "expected_bias_shape": [3, 4, 1], + "expected_output_shape": (None, 3, 4, 2) + }) +class TestEinsumDenseLayer(keras_parameterized.TestCase): + + def test_weight_shapes(self, equation, bias_axes, input_shape, output_shape, + expected_weight_shape, expected_bias_shape, + expected_output_shape): + del expected_output_shape # Not used in this test. + + weight_shape, bias_shape, _ = einsum_dense._analyze_einsum_string( + equation, bias_axes, input_shape, output_shape) + + self.assertAllEqual(expected_weight_shape, weight_shape) + self.assertAllEqual(expected_bias_shape, bias_shape) + + def test_layer_creation(self, equation, bias_axes, input_shape, output_shape, + expected_weight_shape, expected_bias_shape, + expected_output_shape): + # Keras elides the 0-dimension of the input shape when constructing inputs. 
+ non_batch_input_shape = list(input_shape)[1:] + + input_tensor = keras.Input(shape=non_batch_input_shape) + layer = einsum_dense.EinsumDense( + equation=equation, output_shape=output_shape, bias_axes=bias_axes) + output_tensor = layer(input_tensor) + + self.assertAllEqual(expected_weight_shape, layer.kernel.shape.as_list()) + if expected_bias_shape is None: + self.assertIsNone(layer.bias) + else: + self.assertAllEqual(expected_bias_shape, layer.bias.shape.as_list()) + self.assertAllEqual(expected_output_shape, output_tensor.shape.as_list()) + + +@keras_parameterized.run_all_keras_modes +class TestEinsumLayerAPI(keras_parameterized.TestCase): + + def test_layer_api(self): + input_data = np.array([[1.0, 2.0], [3.0, 4.0]]) + kwargs = { + "equation": "...b,bc->...c", + "bias_axes": "c", + "output_shape": 4, + "bias_initializer": keras.initializers.constant(0.03), + "kernel_initializer": keras.initializers.constant(0.5), + "dtype": input_data.dtype + } + expected_output = np.array([[1.53, 1.53, 1.53, 1.53], + [3.53, 3.53, 3.53, 3.53]]) + + output_data = testing_utils.layer_test( + einsum_dense.EinsumDense, + kwargs=kwargs, + input_shape=(None, 2), + input_data=input_data) + + self.assertAllClose(expected_output, output_data) + + def test_unspecified_bias_dim_fails(self): + input_tensor = keras.Input(shape=(32,)) + layer = einsum_dense.EinsumDense( + equation="ab,bc->ac", output_shape=64, bias_axes="y") + with self.assertRaisesRegexp( + ValueError, ".*is not a part of the output specification.*"): + _ = layer(input_tensor) + + def test_incompatible_input_output_shape_fails(self): + input_tensor = keras.Input(shape=(32, 64)) + layer = einsum_dense.EinsumDense( + equation="abc,cd->abd", output_shape=(10, 96)) + with self.assertRaisesRegexp( + ValueError, ".*Input shape and output shape do not match at shared " + "dimension 'b'.*"): + _ = layer(input_tensor) + + def test_unspecified_output_dim_fails(self): + input_tensor = keras.Input(shape=(32,)) + layer = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64) + with self.assertRaisesRegexp( + ValueError, ".*Dimension 'd' was specified in the output 'cd' but has " + "no corresponding dim.*"): + _ = layer(input_tensor) + + def test_unspecified_weight_dim_fails(self): + input_tensor = keras.Input(shape=(32,)) + layer = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64) + with self.assertRaisesRegexp( + ValueError, ".*Weight dimension 'z' did not have a match "): + _ = layer(input_tensor) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/keras/layers/serialization.py b/tensorflow/python/keras/layers/serialization.py index f74bf51aae0..a0056d82ab9 100644 --- a/tensorflow/python/keras/layers/serialization.py +++ b/tensorflow/python/keras/layers/serialization.py @@ -33,6 +33,7 @@ from tensorflow.python.keras.layers import convolutional_recurrent from tensorflow.python.keras.layers import core from tensorflow.python.keras.layers import cudnn_recurrent from tensorflow.python.keras.layers import dense_attention +from tensorflow.python.keras.layers import einsum_dense from tensorflow.python.keras.layers import embeddings from tensorflow.python.keras.layers import local from tensorflow.python.keras.layers import merge @@ -52,26 +53,11 @@ from tensorflow.python.util import tf_inspect as inspect from tensorflow.python.util.tf_export import keras_export -ALL_MODULES = ( - base_layer, - input_layer, - advanced_activations, - convolutional, - convolutional_recurrent, - core, - cudnn_recurrent, - dense_attention, 
- embeddings, - local, - merge, - noise, - normalization, - pooling, - image_preprocessing, - preprocessing_normalization_v1, - recurrent, - wrappers -) +ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional, + convolutional_recurrent, core, cudnn_recurrent, dense_attention, + embeddings, einsum_dense, local, merge, noise, normalization, + pooling, image_preprocessing, preprocessing_normalization_v1, + recurrent, wrappers) ALL_V2_MODULES = ( rnn_cell_wrapper_v2, normalization_v2, diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt new file mode 100644 index 00000000000..8a782f6666f --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.EinsumDense" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'equation\', \'output_shape\', \'activation\', \'bias_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', 
\'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.pbtxt index 67d9ef6bbcc..81d2acbd71f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.layers.experimental" tf_module { + member { + name: "EinsumDense" + mtype: "" + } member { name: "RandomFourierFeatures" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt new file mode 100644 index 00000000000..8a782f6666f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-einsum-dense.pbtxt @@ -0,0 +1,218 @@ +path: "tensorflow.keras.layers.experimental.EinsumDense" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "activity_regularizer" + mtype: "" + } + member { + name: "dtype" + mtype: "" + } + member { + name: "dynamic" + mtype: "" + } + member { + name: "inbound_nodes" + mtype: "" + } + member { + name: "input" + mtype: "" + } + member { + name: "input_mask" + mtype: "" + } + member { + name: "input_shape" + mtype: "" + } + member { + name: "input_spec" + mtype: "" + } + member { + name: "losses" + mtype: "" + } + member { + name: "metrics" + mtype: "" + } + member { + name: "name" + mtype: "" + } + member { + name: "name_scope" + mtype: "" + } + member { + name: "non_trainable_variables" + mtype: "" + } + member { + name: "non_trainable_weights" + mtype: "" + } + member { + name: "outbound_nodes" + mtype: "" + } + member { + name: "output" + mtype: "" + } + member { + name: "output_mask" + mtype: "" + } + member { + name: "output_shape" + mtype: "" + } + member { + name: "stateful" + mtype: "" + } + member { + name: "submodules" + mtype: "" + } + member { + name: "trainable" + mtype: "" + } + member { + name: "trainable_variables" + mtype: "" + } + member { + name: "trainable_weights" + mtype: "" + } + member { + name: "updates" + mtype: "" + } + member { + name: "variables" + mtype: "" + } + member { + name: "weights" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'equation\', \'output_shape\', \'activation\', \'bias_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "add_loss" + argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None" + } + member_method { + name: "add_metric" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " + } + member_method { + name: "add_update" + argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "add_variable" + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], " + } + member_method { + name: "apply" + argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "build" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + 
name: "compute_mask" + argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "compute_output_shape" + argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "compute_output_signature" + argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "count_params" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_config" + argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_config" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_input_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_losses_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_mask_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_output_shape_at" + argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_updates_for" + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_weights" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "set_weights" + argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "with_name_scope" + argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt index 75c73ca2018..53d4adbed30 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.layers.experimental" tf_module { + member { + name: "EinsumDense" + mtype: "" + } member { name: "RandomFourierFeatures" mtype: ""