Add an einsum layer.
PiperOrigin-RevId: 309231226 Change-Id: Icbadc1f420e06ea2222d519e93ac8fe377d08527
This commit is contained in:
parent
63aa02c5f3
commit
16748305e8
@ -34,6 +34,7 @@ py_library(
|
||||
":core",
|
||||
":cudnn_recurrent",
|
||||
":dense_attention",
|
||||
":einsum_dense",
|
||||
":embeddings",
|
||||
":kernelized",
|
||||
":local",
|
||||
@ -187,6 +188,22 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "einsum_dense",
|
||||
srcs = ["einsum_dense.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:special_math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/keras:activations",
|
||||
"//tensorflow/python/keras:base_layer",
|
||||
"//tensorflow/python/keras:constraints",
|
||||
"//tensorflow/python/keras:initializers",
|
||||
"//tensorflow/python/keras:regularizers",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "embeddings",
|
||||
srcs = ["embeddings.py"],
|
||||
@ -581,6 +598,18 @@ cuda_py_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "einsum_dense_test",
|
||||
srcs = ["einsum_dense_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":einsum_dense",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "local_test",
|
||||
size = "medium",
|
||||
|
@ -119,6 +119,9 @@ from tensorflow.python.keras.layers.dense_attention import Attention
|
||||
# Embedding layers.
|
||||
from tensorflow.python.keras.layers.embeddings import Embedding
|
||||
|
||||
# Einsum-based dense layer/
|
||||
from tensorflow.python.keras.layers.einsum_dense import EinsumDense
|
||||
|
||||
# Locally-connected layers.
|
||||
from tensorflow.python.keras.layers.local import LocallyConnected1D
|
||||
from tensorflow.python.keras.layers.local import LocallyConnected2D
|
||||
|
337
tensorflow/python/keras/layers/einsum_dense.py
Normal file
337
tensorflow/python/keras/layers/einsum_dense.py
Normal file
@ -0,0 +1,337 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras-based einsum dense layer."""
|
||||
# pylint: disable=g-classes-have-attributes
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import re
|
||||
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras import activations
|
||||
from tensorflow.python.keras import constraints
|
||||
from tensorflow.python.keras import initializers
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine.base_layer import Layer
|
||||
from tensorflow.python.ops import special_math_ops
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@keras_export("keras.layers.experimental.EinsumDense")
|
||||
class EinsumDense(Layer):
|
||||
"""A layer that uses tf.einsum as the backing computation.
|
||||
|
||||
This layer can perform einsum calculations of arbitrary dimensionality.
|
||||
|
||||
Arguments:
|
||||
equation: An equation describing the einsum to perform. This equation must
|
||||
be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or
|
||||
`ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
|
||||
expression sequence.
|
||||
output_shape: The expected shape of the output tensor (excluding the batch
|
||||
dimension and any dimensions represented by ellipses). You can specify
|
||||
None for any dimension that is unknown or can be inferred from the input
|
||||
shape.
|
||||
activation: Activation function to use. If you don't specify anything, no
|
||||
activation is applied (that is, a "linear" activation: `a(x) = x`).
|
||||
bias_axes: A string containing the output dimension(s) to apply a bias to.
|
||||
Each character in the `bias_axes` string should correspond to a character
|
||||
in the output portion of the `equation` string.
|
||||
kernel_initializer: Initializer for the `kernel` weights matrix.
|
||||
bias_initializer: Initializer for the bias vector.
|
||||
kernel_regularizer: Regularizer function applied to the `kernel` weights
|
||||
matrix.
|
||||
bias_regularizer: Regularizer function applied to the bias vector.
|
||||
activity_regularizer: Regularizer function applied to the output of the
|
||||
layer (its "activation")..
|
||||
kernel_constraint: Constraint function applied to the `kernel` weights
|
||||
matrix.
|
||||
bias_constraint: Constraint function applied to the bias vector.
|
||||
|
||||
Examples:
|
||||
|
||||
**Biased dense layer with einsums**
|
||||
|
||||
This example shows how to instantiate a standard Keras dense layer using
|
||||
einsum operations. This example is equivalent to
|
||||
`tf.keras.layers.Dense(64, use_bias=True)`.
|
||||
|
||||
>>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
|
||||
>>> input_tensor = tf.keras.Input(shape=[32])
|
||||
>>> output_tensor = layer(input_tensor)
|
||||
>>> output_tensor
|
||||
<tf.Tensor '...' shape=(None, 64) dtype=...>
|
||||
|
||||
**Applying a dense layer to a sequence**
|
||||
|
||||
This example shows how to instantiate a layer that applies the same dense
|
||||
operation to every element in a sequence. Here, the 'output_shape' has two
|
||||
values (since there are two non-batch dimensions in the output); the first
|
||||
dimension in the output_shape is `None`, because the sequence dimension `b`
|
||||
has an unknown shape.
|
||||
|
||||
>>> layer = EinsumDense("abc,cd->abd",
|
||||
... output_shape=(None, 64),
|
||||
... bias_axes="d")
|
||||
>>> input_tensor = tf.keras.Input(shape=[32, 128])
|
||||
>>> output_tensor = layer(input_tensor)
|
||||
>>> output_tensor
|
||||
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
|
||||
|
||||
**Applying a dense layer to a sequence using ellipses**
|
||||
|
||||
This example shows how to instantiate a layer that applies the same dense
|
||||
operation to every element in a sequence, but uses the ellipsis notation
|
||||
instead of specifying the batch and sequence dimensions.
|
||||
|
||||
Because we are using ellipsis notation and have specified only one axis, the
|
||||
output_shape arg is a single value. When instantiated in this way, the layer
|
||||
can handle any number of sequence dimensions - including the case where no
|
||||
sequence dimension exists.
|
||||
|
||||
>>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y")
|
||||
>>> input_tensor = tf.keras.Input(shape=[32, 128])
|
||||
>>> output_tensor = layer(input_tensor)
|
||||
>>> output_tensor
|
||||
<tf.Tensor '...' shape=(None, 32, 64) dtype=...>
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
equation,
|
||||
output_shape,
|
||||
activation=None,
|
||||
bias_axes=None,
|
||||
kernel_initializer="glorot_uniform",
|
||||
bias_initializer="zeros",
|
||||
kernel_regularizer=None,
|
||||
bias_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
kernel_constraint=None,
|
||||
bias_constraint=None,
|
||||
**kwargs):
|
||||
super(EinsumDense, self).__init__(**kwargs)
|
||||
self.equation = equation
|
||||
if isinstance(output_shape, int):
|
||||
self.partial_output_shape = [output_shape]
|
||||
else:
|
||||
self.partial_output_shape = list(output_shape)
|
||||
self.bias_axes = bias_axes
|
||||
self.activation = activations.get(activation)
|
||||
self.kernel_initializer = initializers.get(kernel_initializer)
|
||||
self.bias_initializer = initializers.get(bias_initializer)
|
||||
self.kernel_regularizer = regularizers.get(kernel_regularizer)
|
||||
self.bias_regularizer = regularizers.get(bias_regularizer)
|
||||
self.kernel_constraint = constraints.get(kernel_constraint)
|
||||
self.bias_constraint = constraints.get(bias_constraint)
|
||||
|
||||
def build(self, input_shape):
|
||||
input_shape = tensor_shape.TensorShape(input_shape)
|
||||
shape_data = _analyze_einsum_string(self.equation,
|
||||
self.bias_axes,
|
||||
input_shape,
|
||||
self.partial_output_shape)
|
||||
kernel_shape, bias_shape, self.full_output_shape = shape_data
|
||||
self.kernel = self.add_weight(
|
||||
"kernel",
|
||||
shape=kernel_shape,
|
||||
initializer=self.kernel_initializer,
|
||||
regularizer=self.kernel_regularizer,
|
||||
constraint=self.kernel_constraint,
|
||||
dtype=self.dtype,
|
||||
trainable=True)
|
||||
|
||||
if bias_shape is not None:
|
||||
self.bias = self.add_weight(
|
||||
"bias",
|
||||
shape=bias_shape,
|
||||
initializer=self.bias_initializer,
|
||||
regularizer=self.bias_regularizer,
|
||||
constraint=self.bias_constraint,
|
||||
dtype=self.dtype,
|
||||
trainable=True)
|
||||
else:
|
||||
self.bias = None
|
||||
super(EinsumDense, self).build(input_shape)
|
||||
|
||||
def compute_output_shape(self, _):
|
||||
return tensor_shape.TensorShape(self.full_output_shape)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"output_shape":
|
||||
self.partial_output_shape,
|
||||
"equation":
|
||||
self.equation,
|
||||
"activation":
|
||||
activations.serialize(self.activation),
|
||||
"bias_axes":
|
||||
self.bias_axes,
|
||||
"kernel_initializer":
|
||||
initializers.serialize(self.kernel_initializer),
|
||||
"bias_initializer":
|
||||
initializers.serialize(self.bias_initializer),
|
||||
"kernel_regularizer":
|
||||
regularizers.serialize(self.kernel_regularizer),
|
||||
"bias_regularizer":
|
||||
regularizers.serialize(self.bias_regularizer),
|
||||
"activity_regularizer":
|
||||
regularizers.serialize(self.activity_regularizer),
|
||||
"kernel_constraint":
|
||||
constraints.serialize(self.kernel_constraint),
|
||||
"bias_constraint":
|
||||
constraints.serialize(self.bias_constraint),
|
||||
}
|
||||
base_config = super(EinsumDense, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, inputs):
|
||||
ret = special_math_ops.einsum(self.equation, inputs, self.kernel)
|
||||
if self.bias is not None:
|
||||
ret += self.bias
|
||||
if self.activation is not None:
|
||||
ret = self.activation(ret)
|
||||
return ret
|
||||
|
||||
|
||||
def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
|
||||
"""Analyzes an einsum string to determine the required weight shape."""
|
||||
|
||||
dot_replaced_string = re.sub(r"\.\.\.", "0", equation)
|
||||
|
||||
# This is the case where no ellipses are present in the string.
|
||||
split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)",
|
||||
dot_replaced_string)
|
||||
if split_string:
|
||||
return _analyze_split_string(split_string, bias_axes, input_shape,
|
||||
output_shape)
|
||||
|
||||
# This is the case where ellipses are present on the left.
|
||||
split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)",
|
||||
dot_replaced_string)
|
||||
if split_string:
|
||||
return _analyze_split_string(
|
||||
split_string, bias_axes, input_shape, output_shape, left_elided=True)
|
||||
|
||||
# This is the case where ellipses are present on the right.
|
||||
split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0",
|
||||
dot_replaced_string)
|
||||
if split_string:
|
||||
return _analyze_split_string(split_string, bias_axes, input_shape,
|
||||
output_shape)
|
||||
|
||||
raise ValueError(
|
||||
"Invalid einsum equation '%s'. Equations must be in the form "
|
||||
"[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation)
|
||||
|
||||
|
||||
def _analyze_split_string(split_string,
|
||||
bias_axes,
|
||||
input_shape,
|
||||
output_shape,
|
||||
left_elided=False):
|
||||
"""Analyze an pre-split einsum string to find the weight shape."""
|
||||
input_spec = split_string.group(1)
|
||||
weight_spec = split_string.group(2)
|
||||
output_spec = split_string.group(3)
|
||||
elided = len(input_shape) - len(input_spec)
|
||||
|
||||
if isinstance(output_shape, int):
|
||||
output_shape = [output_shape]
|
||||
else:
|
||||
output_shape = list(output_shape)
|
||||
|
||||
output_shape.insert(0, input_shape[0])
|
||||
|
||||
if elided > 0 and left_elided:
|
||||
for i in range(1, elided):
|
||||
# We already inserted the 0th input dimension at dim 0, so we need to
|
||||
# start at location 1 here.
|
||||
output_shape.insert(1, input_shape[i])
|
||||
elif elided > 0 and not left_elided:
|
||||
for i in range(len(input_shape) - elided, len(input_shape)):
|
||||
output_shape.append(input_shape[i])
|
||||
|
||||
if left_elided:
|
||||
# If we have beginning dimensions elided, we need to use negative indexing
|
||||
# to determine where in the input dimension our values are.
|
||||
input_dim_map = {
|
||||
dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
|
||||
}
|
||||
# Because we've constructed the full output shape already, we don't need
|
||||
# to do negative indexing.
|
||||
output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
|
||||
else:
|
||||
input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
|
||||
output_dim_map = {dim: i for i, dim in enumerate(output_spec)}
|
||||
|
||||
for i, dim in enumerate(input_spec):
|
||||
input_shape_at_dim = input_shape[i]
|
||||
if dim in output_dim_map:
|
||||
output_shape_at_dim = output_shape[output_dim_map[dim]]
|
||||
if (output_shape_at_dim is not None and
|
||||
output_shape_at_dim != input_shape_at_dim):
|
||||
raise ValueError(
|
||||
"Input shape and output shape do not match at shared "
|
||||
"dimension '%s'. Input shape is %s, and output shape "
|
||||
"is %s." %
|
||||
(dim, input_shape_at_dim, output_shape[output_dim_map[dim]]))
|
||||
|
||||
for dim in output_spec:
|
||||
if dim not in input_spec and dim not in weight_spec:
|
||||
raise ValueError("Dimension '%s' was specified in the output '%s' but "
|
||||
"has no corresponding dim in the input spec '%s' or "
|
||||
"weight spec '%s.'" % (dim, output_spec, input_spec,
|
||||
output_spec))
|
||||
|
||||
weight_shape = []
|
||||
for dim in weight_spec:
|
||||
if dim in input_dim_map:
|
||||
weight_shape.append(input_shape[input_dim_map[dim]])
|
||||
elif dim in output_dim_map:
|
||||
weight_shape.append(output_shape[output_dim_map[dim]])
|
||||
else:
|
||||
raise ValueError("Weight dimension '%s' did not have a match in either "
|
||||
"the input spec '%s' or the output spec '%s'. For this "
|
||||
"layer, the weight must be fully specified." %
|
||||
(dim, input_spec, output_spec))
|
||||
|
||||
if bias_axes is not None:
|
||||
num_left_elided = elided if left_elided else 0
|
||||
idx_map = {
|
||||
char: output_shape[i + num_left_elided]
|
||||
for i, char in enumerate(output_spec)
|
||||
}
|
||||
|
||||
for char in bias_axes:
|
||||
if char not in output_spec:
|
||||
raise ValueError("Bias dimension '%s' was requested, but is not a part "
|
||||
"of the output specification '%s'" %
|
||||
(char, output_spec))
|
||||
|
||||
first_bias_location = min([output_spec.find(char) for char in bias_axes])
|
||||
bias_output_spec = output_spec[first_bias_location:]
|
||||
|
||||
bias_shape = [
|
||||
idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
|
||||
]
|
||||
|
||||
if not left_elided:
|
||||
for _ in range(elided):
|
||||
bias_shape.append(1)
|
||||
else:
|
||||
bias_shape = None
|
||||
|
||||
return weight_shape, bias_shape, output_shape
|
315
tensorflow/python/keras/layers/einsum_dense_test.py
Normal file
315
tensorflow/python/keras/layers/einsum_dense_test.py
Normal file
@ -0,0 +1,315 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Keras-based einsum dense layer."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
|
||||
from tensorflow.python.keras import keras_parameterized # pylint: disable=g-direct-tensorflow-import
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.layers import einsum_dense
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
@parameterized.named_parameters(
|
||||
{
|
||||
"testcase_name": "_1d_end_weight",
|
||||
"equation": "ab,b->a",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 32),
|
||||
"output_shape": [],
|
||||
"expected_weight_shape": [32],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None,)
|
||||
}, {
|
||||
"testcase_name": "_2d_middle_weight",
|
||||
"equation": "ab,bc->ac",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 32),
|
||||
"output_shape": (64),
|
||||
"expected_weight_shape": [32, 64],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 64)
|
||||
}, {
|
||||
"testcase_name": "_3d_bert",
|
||||
"equation": "abc,cde->abde",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (1, 3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_3_bias",
|
||||
"equation": "abc,cde->abde",
|
||||
"bias_axes": "e",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (1, 3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": [4],
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_2_bias",
|
||||
"equation": "abc,cde->abde",
|
||||
"bias_axes": "d",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (1, 3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": [3, 1],
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_1_3_bias",
|
||||
"equation": "abc,cde->abde",
|
||||
"bias_axes": "be",
|
||||
"input_shape": (None, 7, 2),
|
||||
"output_shape": (7, 3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": [7, 1, 4],
|
||||
"expected_output_shape": (None, 7, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_bert_projection",
|
||||
"equation": "BFNH,NHD->BFD",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 1, 2, 3),
|
||||
"output_shape": (1, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 1, 4)
|
||||
}, {
|
||||
"testcase_name": "_2d_bert",
|
||||
"equation": "abc,cd->abd",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (1, 4),
|
||||
"expected_weight_shape": [2, 4],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 1, 4)
|
||||
}, {
|
||||
"testcase_name": "_embedding_1d",
|
||||
"equation": "i,d->id",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None,),
|
||||
"output_shape": (2),
|
||||
"expected_weight_shape": [2],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 2)
|
||||
}, {
|
||||
"testcase_name": "_xlnet_lm",
|
||||
"equation": "ibd,nd->ibn",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, None, 1),
|
||||
"output_shape": (None, 2),
|
||||
"expected_weight_shape": [2, 1],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, None, 2)
|
||||
}, {
|
||||
"testcase_name": "_2d_precast",
|
||||
"equation": "...b,bc->...c",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 32),
|
||||
"output_shape": (64),
|
||||
"expected_weight_shape": [32, 64],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 64)
|
||||
}, {
|
||||
"testcase_name": "_2d_precast_multiple_elided_dims",
|
||||
"equation": "...b,bc->...c",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, None, 32),
|
||||
"output_shape": (64),
|
||||
"expected_weight_shape": [32, 64],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, None, 64)
|
||||
}, {
|
||||
"testcase_name": "_3d_precast",
|
||||
"equation": "...c,cde->...de",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_precast_3_bias",
|
||||
"equation": "...c,cde->...de",
|
||||
"bias_axes": "e",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": [4],
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_precast_2_bias",
|
||||
"equation": "...c,cde->...de",
|
||||
"bias_axes": "d",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": [3, 1],
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_3d_precast_2_3_bias",
|
||||
"equation": "...c,cde->...de",
|
||||
"bias_axes": "de",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [2, 3, 4],
|
||||
"expected_bias_shape": [3, 4],
|
||||
"expected_output_shape": (None, 1, 3, 4)
|
||||
}, {
|
||||
"testcase_name": "_2d_postcast",
|
||||
"equation": "bc...,cd->bd...",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 1, 2, 3),
|
||||
"output_shape": (4),
|
||||
"expected_weight_shape": [1, 4],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 4, 2, 3)
|
||||
}, {
|
||||
"testcase_name": "_3d_postcast",
|
||||
"equation": "bc...,cde->bde...",
|
||||
"bias_axes": None,
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [1, 3, 4],
|
||||
"expected_bias_shape": None,
|
||||
"expected_output_shape": (None, 3, 4, 2)
|
||||
}, {
|
||||
"testcase_name": "_3d_postcast_1_bias",
|
||||
"equation": "bc...,cde->bde...",
|
||||
"bias_axes": "d",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [1, 3, 4],
|
||||
"expected_bias_shape": [3, 1, 1],
|
||||
"expected_output_shape": (None, 3, 4, 2)
|
||||
}, {
|
||||
"testcase_name": "_3d_postcast_2_bias",
|
||||
"equation": "bc...,cde->bde...",
|
||||
"bias_axes": "e",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [1, 3, 4],
|
||||
"expected_bias_shape": [4, 1],
|
||||
"expected_output_shape": (None, 3, 4, 2)
|
||||
}, {
|
||||
"testcase_name": "_3d_postcast_1_2_bias",
|
||||
"equation": "bc...,cde->bde...",
|
||||
"bias_axes": "de",
|
||||
"input_shape": (None, 1, 2),
|
||||
"output_shape": (3, 4),
|
||||
"expected_weight_shape": [1, 3, 4],
|
||||
"expected_bias_shape": [3, 4, 1],
|
||||
"expected_output_shape": (None, 3, 4, 2)
|
||||
})
|
||||
class TestEinsumDenseLayer(keras_parameterized.TestCase):
|
||||
|
||||
def test_weight_shapes(self, equation, bias_axes, input_shape, output_shape,
|
||||
expected_weight_shape, expected_bias_shape,
|
||||
expected_output_shape):
|
||||
del expected_output_shape # Not used in this test.
|
||||
|
||||
weight_shape, bias_shape, _ = einsum_dense._analyze_einsum_string(
|
||||
equation, bias_axes, input_shape, output_shape)
|
||||
|
||||
self.assertAllEqual(expected_weight_shape, weight_shape)
|
||||
self.assertAllEqual(expected_bias_shape, bias_shape)
|
||||
|
||||
def test_layer_creation(self, equation, bias_axes, input_shape, output_shape,
|
||||
expected_weight_shape, expected_bias_shape,
|
||||
expected_output_shape):
|
||||
# Keras elides the 0-dimension of the input shape when constructing inputs.
|
||||
non_batch_input_shape = list(input_shape)[1:]
|
||||
|
||||
input_tensor = keras.Input(shape=non_batch_input_shape)
|
||||
layer = einsum_dense.EinsumDense(
|
||||
equation=equation, output_shape=output_shape, bias_axes=bias_axes)
|
||||
output_tensor = layer(input_tensor)
|
||||
|
||||
self.assertAllEqual(expected_weight_shape, layer.kernel.shape.as_list())
|
||||
if expected_bias_shape is None:
|
||||
self.assertIsNone(layer.bias)
|
||||
else:
|
||||
self.assertAllEqual(expected_bias_shape, layer.bias.shape.as_list())
|
||||
self.assertAllEqual(expected_output_shape, output_tensor.shape.as_list())
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class TestEinsumLayerAPI(keras_parameterized.TestCase):
|
||||
|
||||
def test_layer_api(self):
|
||||
input_data = np.array([[1.0, 2.0], [3.0, 4.0]])
|
||||
kwargs = {
|
||||
"equation": "...b,bc->...c",
|
||||
"bias_axes": "c",
|
||||
"output_shape": 4,
|
||||
"bias_initializer": keras.initializers.constant(0.03),
|
||||
"kernel_initializer": keras.initializers.constant(0.5),
|
||||
"dtype": input_data.dtype
|
||||
}
|
||||
expected_output = np.array([[1.53, 1.53, 1.53, 1.53],
|
||||
[3.53, 3.53, 3.53, 3.53]])
|
||||
|
||||
output_data = testing_utils.layer_test(
|
||||
einsum_dense.EinsumDense,
|
||||
kwargs=kwargs,
|
||||
input_shape=(None, 2),
|
||||
input_data=input_data)
|
||||
|
||||
self.assertAllClose(expected_output, output_data)
|
||||
|
||||
def test_unspecified_bias_dim_fails(self):
|
||||
input_tensor = keras.Input(shape=(32,))
|
||||
layer = einsum_dense.EinsumDense(
|
||||
equation="ab,bc->ac", output_shape=64, bias_axes="y")
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, ".*is not a part of the output specification.*"):
|
||||
_ = layer(input_tensor)
|
||||
|
||||
def test_incompatible_input_output_shape_fails(self):
|
||||
input_tensor = keras.Input(shape=(32, 64))
|
||||
layer = einsum_dense.EinsumDense(
|
||||
equation="abc,cd->abd", output_shape=(10, 96))
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, ".*Input shape and output shape do not match at shared "
|
||||
"dimension 'b'.*"):
|
||||
_ = layer(input_tensor)
|
||||
|
||||
def test_unspecified_output_dim_fails(self):
|
||||
input_tensor = keras.Input(shape=(32,))
|
||||
layer = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, ".*Dimension 'd' was specified in the output 'cd' but has "
|
||||
"no corresponding dim.*"):
|
||||
_ = layer(input_tensor)
|
||||
|
||||
def test_unspecified_weight_dim_fails(self):
|
||||
input_tensor = keras.Input(shape=(32,))
|
||||
layer = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, ".*Weight dimension 'z' did not have a match "):
|
||||
_ = layer(input_tensor)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -33,6 +33,7 @@ from tensorflow.python.keras.layers import convolutional_recurrent
|
||||
from tensorflow.python.keras.layers import core
|
||||
from tensorflow.python.keras.layers import cudnn_recurrent
|
||||
from tensorflow.python.keras.layers import dense_attention
|
||||
from tensorflow.python.keras.layers import einsum_dense
|
||||
from tensorflow.python.keras.layers import embeddings
|
||||
from tensorflow.python.keras.layers import local
|
||||
from tensorflow.python.keras.layers import merge
|
||||
@ -52,26 +53,11 @@ from tensorflow.python.util import tf_inspect as inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
ALL_MODULES = (
|
||||
base_layer,
|
||||
input_layer,
|
||||
advanced_activations,
|
||||
convolutional,
|
||||
convolutional_recurrent,
|
||||
core,
|
||||
cudnn_recurrent,
|
||||
dense_attention,
|
||||
embeddings,
|
||||
local,
|
||||
merge,
|
||||
noise,
|
||||
normalization,
|
||||
pooling,
|
||||
image_preprocessing,
|
||||
preprocessing_normalization_v1,
|
||||
recurrent,
|
||||
wrappers
|
||||
)
|
||||
ALL_MODULES = (base_layer, input_layer, advanced_activations, convolutional,
|
||||
convolutional_recurrent, core, cudnn_recurrent, dense_attention,
|
||||
embeddings, einsum_dense, local, merge, noise, normalization,
|
||||
pooling, image_preprocessing, preprocessing_normalization_v1,
|
||||
recurrent, wrappers)
|
||||
ALL_V2_MODULES = (
|
||||
rnn_cell_wrapper_v2,
|
||||
normalization_v2,
|
||||
|
@ -0,0 +1,218 @@
|
||||
path: "tensorflow.keras.layers.experimental.EinsumDense"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.einsum_dense.EinsumDense\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "activity_regularizer"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dynamic"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "inbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "metrics"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "outbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "submodules"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'equation\', \'output_shape\', \'activation\', \'bias_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_metric"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_shape"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_signature"
|
||||
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_weights"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_name_scope"
|
||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.keras.layers.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "EinsumDense"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RandomFourierFeatures"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -0,0 +1,218 @@
|
||||
path: "tensorflow.keras.layers.experimental.EinsumDense"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.keras.layers.einsum_dense.EinsumDense\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
|
||||
is_instance: "<class \'tensorflow.python.module.module.Module\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.tracking.AutoTrackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
|
||||
is_instance: "<class \'tensorflow.python.keras.utils.version_utils.LayerVersionSelector\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "activity_regularizer"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "dynamic"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "inbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "input_spec"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "losses"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "metrics"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "name_scope"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "non_trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "outbound_nodes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_mask"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "output_shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "stateful"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "submodules"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "trainable_weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "updates"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "variables"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "weights"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'equation\', \'output_shape\', \'activation\', \'bias_axes\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_loss"
|
||||
argspec: "args=[\'self\', \'losses\'], varargs=None, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_metric"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_update"
|
||||
argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "add_variable"
|
||||
argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "add_weight"
|
||||
argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "apply"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "build"
|
||||
argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_shape"
|
||||
argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "compute_output_signature"
|
||||
argspec: "args=[\'self\', \'input_signature\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "count_params"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_input_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_losses_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_mask_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_output_shape_at"
|
||||
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates_for"
|
||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_weights"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "set_weights"
|
||||
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "with_name_scope"
|
||||
argspec: "args=[\'cls\', \'method\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
path: "tensorflow.keras.layers.experimental"
|
||||
tf_module {
|
||||
member {
|
||||
name: "EinsumDense"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "RandomFourierFeatures"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
x
Reference in New Issue
Block a user