Add an einsum layer.

PiperOrigin-RevId: 309231226
Change-Id: Icbadc1f420e06ea2222d519e93ac8fe377d08527
A. Unique TensorFlower 2020-04-30 08:35:16 -07:00 committed by TensorFlower Gardener
parent 63aa02c5f3
commit 16748305e8
9 changed files with 1134 additions and 20 deletions

View File

@@ -34,6 +34,7 @@ py_library(
":core",
":cudnn_recurrent",
":dense_attention",
":einsum_dense",
":embeddings",
":kernelized",
":local",
@@ -187,6 +188,22 @@ py_library(
],
)
py_library(
name = "einsum_dense",
srcs = ["einsum_dense.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:special_math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/keras:activations",
"//tensorflow/python/keras:base_layer",
"//tensorflow/python/keras:constraints",
"//tensorflow/python/keras:initializers",
"//tensorflow/python/keras:regularizers",
],
)
py_library(
name = "embeddings",
srcs = ["embeddings.py"],
@@ -581,6 +598,18 @@ cuda_py_test(
],
)
tf_py_test(
name = "einsum_dense_test",
srcs = ["einsum_dense_test.py"],
python_version = "PY3",
deps = [
":einsum_dense",
"//tensorflow/python:client_testlib",
"//tensorflow/python/keras",
"@absl_py//absl/testing:parameterized",
],
)
tf_py_test(
name = "local_test",
size = "medium",
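With the targets above in place, the new test can presumably be run directly; the label below assumes the tensorflow/python/keras/layers package path implied by the deps:

bazel test //tensorflow/python/keras/layers:einsum_dense_test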

View File

@@ -119,6 +119,9 @@ from tensorflow.python.keras.layers.dense_attention import Attention
# Embedding layers.
from tensorflow.python.keras.layers.embeddings import Embedding
# Einsum-based dense layer.
from tensorflow.python.keras.layers.einsum_dense import EinsumDense
# Locally-connected layers.
from tensorflow.python.keras.layers.local import LocallyConnected1D
from tensorflow.python.keras.layers.local import LocallyConnected2D
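With this export in place, the layer is reachable under the experimental namespace. A minimal usage sketch (assuming a TensorFlow build that includes this change; the public path comes from the keras_export decorator further down):

import tensorflow as tf

layer = tf.keras.layers.experimental.EinsumDense(
    "ab,bc->ac", output_shape=64, bias_axes="c")
y = layer(tf.ones([2, 32]))  # kernel is built as (32, 64), so y has shape (2, 64)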

View File

@@ -0,0 +1,337 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util.tf_export import keras_export
@keras_export("keras.layers.experimental.EinsumDense")
class EinsumDense(Layer):
"""A layer that uses tf.einsum as the backing computation.
This layer can perform einsum calculations of arbitrary dimensionality.
Arguments:
equation: An equation describing the einsum to perform. This equation must
be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or
`ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
expression sequence.
output_shape: The expected shape of the output tensor (excluding the batch
dimension and any dimensions represented by ellipses). You can specify
None for any dimension that is unknown or can be inferred from the input
shape.
activation: Activation function to use. If you don't specify anything, no
activation is applied (that is, a "linear" activation: `a(x) = x`).
bias_axes: A string containing the output dimension(s) to apply a bias to.
Each character in the `bias_axes` string should correspond to a character
in the output portion of the `equation` string.
kernel_initializer: Initializer for the `kernel` weights matrix.
bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to the `kernel` weights
matrix.
bias_regularizer: Regularizer function applied to the bias vector.
activity_regularizer: Regularizer function applied to the output of the
layer (its "activation").
kernel_constraint: Constraint function applied to the `kernel` weights
matrix.
bias_constraint: Constraint function applied to the bias vector.
Examples:
**Biased dense layer with einsums**
This example shows how to instantiate a standard Keras dense layer using
einsum operations. This example is equivalent to
`tf.keras.layers.Dense(64, use_bias=True)`.
>>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
>>> input_tensor = tf.keras.Input(shape=[32])
>>> output_tensor = layer(input_tensor)
>>> output_tensor
<tf.Tensor '...' shape=(None, 64) dtype=...>
**Applying a dense layer to a sequence**
This example shows how to instantiate a layer that applies the same dense
operation to every element in a sequence. Here, the 'output_shape' has two
values (since there are two non-batch dimensions in the output); the first
dimension in the output_shape is `None`, because the sequence dimension `b`
has an unknown shape.
>>> layer = EinsumDense("abc,cd->abd",
... output_shape=(None, 64),
... bias_axes="d")
>>> input_tensor = tf.keras.Input(shape=[32, 128])
>>> output_tensor = layer(input_tensor)
>>> output_tensor
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
**Applying a dense layer to a sequence using ellipses**
This example shows how to instantiate a layer that applies the same dense
operation to every element in a sequence, but uses the ellipsis notation
instead of specifying the batch and sequence dimensions.
Because we are using ellipsis notation and have specified only one axis, the
output_shape arg is a single value. When instantiated in this way, the layer
can handle any number of sequence dimensions - including the case where no
sequence dimension exists.
>>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y")
>>> input_tensor = tf.keras.Input(shape=[32, 128])
>>> output_tensor = layer(input_tensor)
>>> output_tensor
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
"""
def __init__(self,
equation,
output_shape,
activation=None,
bias_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
**kwargs):
super(EinsumDense, self).__init__(**kwargs)
self.equation = equation
if isinstance(output_shape, int):
self.partial_output_shape = [output_shape]
else:
self.partial_output_shape = list(output_shape)
self.bias_axes = bias_axes
self.activation = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.bias_initializer = initializers.get(bias_initializer)
self.kernel_regularizer = regularizers.get(kernel_regularizer)
self.bias_regularizer = regularizers.get(bias_regularizer)
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
shape_data = _analyze_einsum_string(self.equation,
self.bias_axes,
input_shape,
self.partial_output_shape)
kernel_shape, bias_shape, self.full_output_shape = shape_data
self.kernel = self.add_weight(
"kernel",
shape=kernel_shape,
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint,
dtype=self.dtype,
trainable=True)
if bias_shape is not None:
self.bias = self.add_weight(
"bias",
shape=bias_shape,
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint,
dtype=self.dtype,
trainable=True)
else:
self.bias = None
super(EinsumDense, self).build(input_shape)
def compute_output_shape(self, _):
return tensor_shape.TensorShape(self.full_output_shape)
def get_config(self):
config = {
"output_shape":
self.partial_output_shape,
"equation":
self.equation,
"activation":
activations.serialize(self.activation),
"bias_axes":
self.bias_axes,
"kernel_initializer":
initializers.serialize(self.kernel_initializer),
"bias_initializer":
initializers.serialize(self.bias_initializer),
"kernel_regularizer":
regularizers.serialize(self.kernel_regularizer),
"bias_regularizer":
regularizers.serialize(self.bias_regularizer),
"activity_regularizer":
regularizers.serialize(self.activity_regularizer),
"kernel_constraint":
constraints.serialize(self.kernel_constraint),
"bias_constraint":
constraints.serialize(self.bias_constraint),
}
base_config = super(EinsumDense, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
ret = special_math_ops.einsum(self.equation, inputs, self.kernel)
if self.bias is not None:
ret += self.bias
if self.activation is not None:
ret = self.activation(ret)
return ret
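# For the simple 2-D equation "ab,bc->ac", the einsum in `call` above reduces
# to a plain matrix multiply, so the layer computes
# activation(matmul(inputs, kernel) + bias), the same computation as Dense.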
def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
"""Analyzes an einsum string to determine the required weight shape."""
dot_replaced_string = re.sub(r"\.\.\.", "0", equation)
# This is the case where no ellipses are present in the string.
split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)",
dot_replaced_string)
if split_string:
return _analyze_split_string(split_string, bias_axes, input_shape,
output_shape)
# This is the case where ellipses are present on the left.
split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)",
dot_replaced_string)
if split_string:
return _analyze_split_string(
split_string, bias_axes, input_shape, output_shape, left_elided=True)
# This is the case where ellipses are present on the right.
split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0",
dot_replaced_string)
if split_string:
return _analyze_split_string(split_string, bias_axes, input_shape,
output_shape)
raise ValueError(
"Invalid einsum equation '%s'. Equations must be in the form "
"[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation)
def _analyze_split_string(split_string,
bias_axes,
input_shape,
output_shape,
left_elided=False):
"""Analyze an pre-split einsum string to find the weight shape."""
input_spec = split_string.group(1)
weight_spec = split_string.group(2)
output_spec = split_string.group(3)
elided = len(input_shape) - len(input_spec)
if isinstance(output_shape, int):
output_shape = [output_shape]
else:
output_shape = list(output_shape)
output_shape.insert(0, input_shape[0])
if elided > 0 and left_elided:
for i in range(1, elided):
# We already inserted the 0th input dimension at dim 0, so we need to
# start at location 1 here.
output_shape.insert(1, input_shape[i])
elif elided > 0 and not left_elided:
for i in range(len(input_shape) - elided, len(input_shape)):
output_shape.append(input_shape[i])
if left_elided:
# If we have beginning dimensions elided, we need to use negative indexing
# to determine where in the input dimension our values are.
input_dim_map = {
dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
}
# Because we've constructed the full output shape already, we don't need
# to do negative indexing.
output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
else:
input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
output_dim_map = {dim: i for i, dim in enumerate(output_spec)}
for i, dim in enumerate(input_spec):
input_shape_at_dim = input_shape[i]
if dim in output_dim_map:
output_shape_at_dim = output_shape[output_dim_map[dim]]
if (output_shape_at_dim is not None and
output_shape_at_dim != input_shape_at_dim):
raise ValueError(
"Input shape and output shape do not match at shared "
"dimension '%s'. Input shape is %s, and output shape "
"is %s." %
(dim, input_shape_at_dim, output_shape[output_dim_map[dim]]))
for dim in output_spec:
if dim not in input_spec and dim not in weight_spec:
raise ValueError("Dimension '%s' was specified in the output '%s' but "
"has no corresponding dim in the input spec '%s' or "
"weight spec '%s.'" % (dim, output_spec, input_spec,
output_spec))
weight_shape = []
for dim in weight_spec:
if dim in input_dim_map:
weight_shape.append(input_shape[input_dim_map[dim]])
elif dim in output_dim_map:
weight_shape.append(output_shape[output_dim_map[dim]])
else:
raise ValueError("Weight dimension '%s' did not have a match in either "
"the input spec '%s' or the output spec '%s'. For this "
"layer, the weight must be fully specified." %
(dim, input_spec, output_spec))
if bias_axes is not None:
num_left_elided = elided if left_elided else 0
idx_map = {
char: output_shape[i + num_left_elided]
for i, char in enumerate(output_spec)
}
for char in bias_axes:
if char not in output_spec:
raise ValueError("Bias dimension '%s' was requested, but is not a part "
"of the output specification '%s'" %
(char, output_spec))
first_bias_location = min([output_spec.find(char) for char in bias_axes])
bias_output_spec = output_spec[first_bias_location:]
bias_shape = [
idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
]
if not left_elided:
for _ in range(elided):
bias_shape.append(1)
else:
bias_shape = None
return weight_shape, bias_shape, output_shape
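To make the analysis above concrete, here is a sketch of what the helper returns for the "abc,cde->abde" case (mirroring the "_3d_3_bias" test case below; _analyze_einsum_string is private, so this is illustrative rather than public API):

from tensorflow.python.keras.layers import einsum_dense

kernel_shape, bias_shape, full_output_shape = einsum_dense._analyze_einsum_string(
    equation="abc,cde->abde",
    bias_axes="e",
    input_shape=(None, 1, 2),
    output_shape=(1, 3, 4))
# kernel_shape == [2, 3, 4]  (c resolved from the input, d and e from the output)
# bias_shape == [4]  (the bias broadcasts over the non-bias output axes)
# full_output_shape == [None, 1, 3, 4]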

View File

@@ -0,0 +1,315 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras-based einsum dense layer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import einsum_dense
from tensorflow.python.platform import test
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(
{
"testcase_name": "_1d_end_weight",
"equation": "ab,b->a",
"bias_axes": None,
"input_shape": (None, 32),
"output_shape": [],
"expected_weight_shape": [32],
"expected_bias_shape": None,
"expected_output_shape": (None,)
}, {
"testcase_name": "_2d_middle_weight",
"equation": "ab,bc->ac",
"bias_axes": None,
"input_shape": (None, 32),
"output_shape": (64),
"expected_weight_shape": [32, 64],
"expected_bias_shape": None,
"expected_output_shape": (None, 64)
}, {
"testcase_name": "_3d_bert",
"equation": "abc,cde->abde",
"bias_axes": None,
"input_shape": (None, 1, 2),
"output_shape": (1, 3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": None,
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_3d_3_bias",
"equation": "abc,cde->abde",
"bias_axes": "e",
"input_shape": (None, 1, 2),
"output_shape": (1, 3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": [4],
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_3d_2_bias",
"equation": "abc,cde->abde",
"bias_axes": "d",
"input_shape": (None, 1, 2),
"output_shape": (1, 3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": [3, 1],
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_3d_1_3_bias",
"equation": "abc,cde->abde",
"bias_axes": "be",
"input_shape": (None, 7, 2),
"output_shape": (7, 3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": [7, 1, 4],
"expected_output_shape": (None, 7, 3, 4)
}, {
"testcase_name": "_3d_bert_projection",
"equation": "BFNH,NHD->BFD",
"bias_axes": None,
"input_shape": (None, 1, 2, 3),
"output_shape": (1, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": None,
"expected_output_shape": (None, 1, 4)
}, {
"testcase_name": "_2d_bert",
"equation": "abc,cd->abd",
"bias_axes": None,
"input_shape": (None, 1, 2),
"output_shape": (1, 4),
"expected_weight_shape": [2, 4],
"expected_bias_shape": None,
"expected_output_shape": (None, 1, 4)
}, {
"testcase_name": "_embedding_1d",
"equation": "i,d->id",
"bias_axes": None,
"input_shape": (None,),
"output_shape": (2),
"expected_weight_shape": [2],
"expected_bias_shape": None,
"expected_output_shape": (None, 2)
}, {
"testcase_name": "_xlnet_lm",
"equation": "ibd,nd->ibn",
"bias_axes": None,
"input_shape": (None, None, 1),
"output_shape": (None, 2),
"expected_weight_shape": [2, 1],
"expected_bias_shape": None,
"expected_output_shape": (None, None, 2)
}, {
"testcase_name": "_2d_precast",
"equation": "...b,bc->...c",
"bias_axes": None,
"input_shape": (None, 32),
"output_shape": (64),
"expected_weight_shape": [32, 64],
"expected_bias_shape": None,
"expected_output_shape": (None, 64)
}, {
"testcase_name": "_2d_precast_multiple_elided_dims",
"equation": "...b,bc->...c",
"bias_axes": None,
"input_shape": (None, None, 32),
"output_shape": (64),
"expected_weight_shape": [32, 64],
"expected_bias_shape": None,
"expected_output_shape": (None, None, 64)
}, {
"testcase_name": "_3d_precast",
"equation": "...c,cde->...de",
"bias_axes": None,
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": None,
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_3d_precast_3_bias",
"equation": "...c,cde->...de",
"bias_axes": "e",
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": [4],
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_3d_precast_2_bias",
"equation": "...c,cde->...de",
"bias_axes": "d",
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": [3, 1],
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_3d_precast_2_3_bias",
"equation": "...c,cde->...de",
"bias_axes": "de",
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [2, 3, 4],
"expected_bias_shape": [3, 4],
"expected_output_shape": (None, 1, 3, 4)
}, {
"testcase_name": "_2d_postcast",
"equation": "bc...,cd->bd...",
"bias_axes": None,
"input_shape": (None, 1, 2, 3),
"output_shape": (4),
"expected_weight_shape": [1, 4],
"expected_bias_shape": None,
"expected_output_shape": (None, 4, 2, 3)
}, {
"testcase_name": "_3d_postcast",
"equation": "bc...,cde->bde...",
"bias_axes": None,
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [1, 3, 4],
"expected_bias_shape": None,
"expected_output_shape": (None, 3, 4, 2)
}, {
"testcase_name": "_3d_postcast_1_bias",
"equation": "bc...,cde->bde...",
"bias_axes": "d",
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [1, 3, 4],
"expected_bias_shape": [3, 1, 1],
"expected_output_shape": (None, 3, 4, 2)
}, {
"testcase_name": "_3d_postcast_2_bias",
"equation": "bc...,cde->bde...",
"bias_axes": "e",
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [1, 3, 4],
"expected_bias_shape": [4, 1],
"expected_output_shape": (None, 3, 4, 2)
}, {
"testcase_name": "_3d_postcast_1_2_bias",
"equation": "bc...,cde->bde...",
"bias_axes": "de",
"input_shape": (None, 1, 2),
"output_shape": (3, 4),
"expected_weight_shape": [1, 3, 4],
"expected_bias_shape": [3, 4, 1],
"expected_output_shape": (None, 3, 4, 2)
})
class TestEinsumDenseLayer(keras_parameterized.TestCase):
def test_weight_shapes(self, equation, bias_axes, input_shape, output_shape,
expected_weight_shape, expected_bias_shape,
expected_output_shape):
del expected_output_shape # Not used in this test.
weight_shape, bias_shape, _ = einsum_dense._analyze_einsum_string(
equation, bias_axes, input_shape, output_shape)
self.assertAllEqual(expected_weight_shape, weight_shape)
self.assertAllEqual(expected_bias_shape, bias_shape)
def test_layer_creation(self, equation, bias_axes, input_shape, output_shape,
expected_weight_shape, expected_bias_shape,
expected_output_shape):
# Keras elides the 0-dimension of the input shape when constructing inputs.
non_batch_input_shape = list(input_shape)[1:]
input_tensor = keras.Input(shape=non_batch_input_shape)
layer = einsum_dense.EinsumDense(
equation=equation, output_shape=output_shape, bias_axes=bias_axes)
output_tensor = layer(input_tensor)
self.assertAllEqual(expected_weight_shape, layer.kernel.shape.as_list())
if expected_bias_shape is None:
self.assertIsNone(layer.bias)
else:
self.assertAllEqual(expected_bias_shape, layer.bias.shape.as_list())
self.assertAllEqual(expected_output_shape, output_tensor.shape.as_list())
@keras_parameterized.run_all_keras_modes
class TestEinsumLayerAPI(keras_parameterized.TestCase):
def test_layer_api(self):
input_data = np.array([[1.0, 2.0], [3.0, 4.0]])
kwargs = {
"equation": "...b,bc->...c",
"bias_axes": "c",
"output_shape": 4,
"bias_initializer": keras.initializers.constant(0.03),
"kernel_initializer": keras.initializers.constant(0.5),
"dtype": input_data.dtype
}
expected_output = np.array([[1.53, 1.53, 1.53, 1.53],
[3.53, 3.53, 3.53, 3.53]])
output_data = testing_utils.layer_test(
einsum_dense.EinsumDense,
kwargs=kwargs,
input_shape=(None, 2),
input_data=input_data)
self.assertAllClose(expected_output, output_data)
def test_unspecified_bias_dim_fails(self):
input_tensor = keras.Input(shape=(32,))
layer = einsum_dense.EinsumDense(
equation="ab,bc->ac", output_shape=64, bias_axes="y")
with self.assertRaisesRegexp(
ValueError, ".*is not a part of the output specification.*"):
_ = layer(input_tensor)
def test_incompatible_input_output_shape_fails(self):
input_tensor = keras.Input(shape=(32, 64))
layer = einsum_dense.EinsumDense(
equation="abc,cd->abd", output_shape=(10, 96))
with self.assertRaisesRegexp(
ValueError, ".*Input shape and output shape do not match at shared "
"dimension 'b'.*"):
_ = layer(input_tensor)
def test_unspecified_output_dim_fails(self):
input_tensor = keras.Input(shape=(32,))
layer = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
with self.assertRaisesRegexp(
ValueError, ".*Dimension 'd' was specified in the output 'cd' but has "
"no corresponding dim.*"):
_ = layer(input_tensor)
def test_unspecified_weight_dim_fails(self):
input_tensor = keras.Input(shape=(32,))
layer = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
with self.assertRaisesRegexp(
ValueError, ".*Weight dimension 'z' did not have a match "):
_ = layer(input_tensor)
if __name__ == "__main__":
test.main()
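Beyond the parameterized cases, the Dense equivalence claimed in the layer docstring can be spot-checked directly (a small sketch, assuming a TF 2.x build that includes this layer; both layers default to zero-initialized biases):

import numpy as np
import tensorflow as tf

x = np.random.rand(2, 3).astype("float32")
einsum_layer = tf.keras.layers.experimental.EinsumDense(
    "ab,bc->ac", output_shape=4, bias_axes="c", kernel_initializer="ones")
dense_layer = tf.keras.layers.Dense(4, kernel_initializer="ones")
# Identical kernels and biases imply identical outputs.
np.testing.assert_allclose(einsum_layer(x), dense_layer(x), rtol=1e-6)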

View File

@@ -33,6 +33,7 @@ from tensorflow.python.keras.layers import convolutional_recurrent
from tensorflow.python.keras.layers import core
from tensorflow.python.keras.layers import cudnn_recurrent
from tensorflow.python.keras.layers import dense_attention
from tensorflow.python.keras.layers import einsum_dense
from tensorflow.python.keras.layers import embeddings
from tensorflow.python.keras.layers import local
from tensorflow.python.keras.layers import merge
@@ -52,26 +53,11 @@ from tensorflow.python.util import tf_inspect as inspect
from tensorflow.python.util.tf_export import keras_export
ALL_MODULES = (
base_layer,
input_layer,
advanced_activations,
convolutional,
convolutional_recurrent,
core,
cudnn_recurrent,
dense_attention,
embeddings,
local,
merge,
noise,
normalization,
pooling,
image_preprocessing,
preprocessing_normalization_v1,
recurrent,
wrappers
)
ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
convolutional_recurrent, core, cudnn_recurrent, dense_attention,
embeddings, einsum_dense, local, merge, noise, normalization,
pooling, image_preprocessing, preprocessing_normalization_v1,
recurrent, wrappers)
ALL_V2_MODULES = (
rnn_cell_wrapper_v2,
normalization_v2,
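Adding einsum_dense to ALL_MODULES makes the layer discoverable by name, so config round-trips through this module work. A sketch (assuming the serialize/deserialize helpers defined in this file):

from tensorflow.python.keras.layers import serialization
from tensorflow.python.keras.layers.einsum_dense import EinsumDense

layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
config = serialization.serialize(layer)
restored = serialization.deserialize(config)
assert isinstance(restored, EinsumDense)
assert restored.equation == "ab,bc->ac"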

View File

@@ -0,0 +1,218 @@
path: "tensorflow.keras.layers.experimental.EinsumDense"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.einsum_dense.EinsumDense\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'equation\', \'output_shape\', \'activation\', \'bias_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -1,5 +1,9 @@
path: "tensorflow.keras.layers.experimental"
tf_module {
member {
name: "EinsumDense"
mtype: "<type \'type\'>"
}
member {
name: "RandomFourierFeatures"
mtype: "<type \'type\'>"

View File

@@ -0,0 +1,218 @@
path: "tensorflow.keras.layers.experimental.EinsumDense"
tf_class {
is_instance: "<class \'tensorflow.python.keras.layers.einsum_dense.EinsumDense\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
is_instance: "<type \'object\'>"
member {
name: "activity_regularizer"
mtype: "<type \'property\'>"
}
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic"
mtype: "<type \'property\'>"
}
member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "input"
mtype: "<type \'property\'>"
}
member {
name: "input_mask"
mtype: "<type \'property\'>"
}
member {
name: "input_shape"
mtype: "<type \'property\'>"
}
member {
name: "input_spec"
mtype: "<type \'property\'>"
}
member {
name: "losses"
mtype: "<type \'property\'>"
}
member {
name: "metrics"
mtype: "<type \'property\'>"
}
member {
name: "name"
mtype: "<type \'property\'>"
}
member {
name: "name_scope"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "non_trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "outbound_nodes"
mtype: "<type \'property\'>"
}
member {
name: "output"
mtype: "<type \'property\'>"
}
member {
name: "output_mask"
mtype: "<type \'property\'>"
}
member {
name: "output_shape"
mtype: "<type \'property\'>"
}
member {
name: "stateful"
mtype: "<type \'property\'>"
}
member {
name: "submodules"
mtype: "<type \'property\'>"
}
member {
name: "trainable"
mtype: "<type \'property\'>"
}
member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
member {
name: "trainable_weights"
mtype: "<type \'property\'>"
}
member {
name: "updates"
mtype: "<type \'property\'>"
}
member {
name: "variables"
mtype: "<type \'property\'>"
}
member {
name: "weights"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'equation\', \'output_shape\', \'activation\', \'bias_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "add_metric"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
}
member_method {
name: "add_update"
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "add_variable"
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "add_weight"
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "build"
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_mask"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_output_shape"
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "compute_output_signature"
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "count_params"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_input_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_losses_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_mask_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_output_shape_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates_for"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_weights"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "with_name_scope"
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@@ -1,5 +1,9 @@
path: "tensorflow.keras.layers.experimental"
tf_module {
member {
name: "EinsumDense"
mtype: "<type \'type\'>"
}
member {
name: "RandomFourierFeatures"
mtype: "<type \'type\'>"