Merge pull request #39516 from lgeiger:resubmit-keras-grouped-conv

PiperOrigin-RevId: 313770233
Change-Id: I8842ffe4c2e8a2293e6b708c915b7fbb98a923de
Commit: e0dc15cec6
@@ -78,6 +78,11 @@ class Conv(Layer):
       the dilation rate to use for dilated convolution.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any `strides` value != 1.
+    groups: A positive integer specifying the number of groups in which the
+      input is split along the channel axis. Each group is convolved
+      separately with `filters / groups` filters. The output is the
+      concatenation of all the `groups` results along the channel axis.
+      Input channels and `filters` must both be divisible by `groups`.
     activation: Activation function to use.
       If you don't specify anything, no activation is applied.
     use_bias: Boolean, whether the layer uses a bias.
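A minimal usage sketch of the new `groups` argument through the public `tf.keras` API (not part of this change; shapes and values are illustrative). Both divisibility rules from the docstring are enforced, the first at construction time and the second at build time:

    import tensorflow as tf

    # 12 input channels split into 3 groups of 4; each group is convolved with
    # filters / groups = 2 of the 6 filters and the results are concatenated.
    layer = tf.keras.layers.Conv2D(filters=6, kernel_size=3, groups=3)
    layer.build((None, 28, 28, 12))

    # tf.keras.layers.Conv2D(filters=7, kernel_size=3, groups=3)    # ValueError: filters % groups != 0
    # tf.keras.layers.Conv2D(filters=10, kernel_size=3, groups=5).build(
    #     (None, 28, 28, 12))                                        # ValueError: channels % groups != 0

Per the test comments further down, actually running a grouped convolution currently requires a CUDA GPU.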
@@ -100,13 +105,15 @@ class Conv(Layer):
     name: A string, the name of the layer.
   """

-  def __init__(self, rank,
+  def __init__(self,
+               rank,
                filters,
                kernel_size,
                strides=1,
                padding='valid',
                data_format=None,
                dilation_rate=1,
+               groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
@@ -128,6 +135,11 @@ class Conv(Layer):
     if filters is not None and not isinstance(filters, int):
       filters = int(filters)
     self.filters = filters
+    self.groups = groups or 1
+    if filters is not None and filters % self.groups != 0:
+      raise ValueError(
+          'The number of filters must be evenly divisible by the number of '
+          'groups. Received: groups={}, filters={}'.format(groups, filters))
     self.kernel_size = conv_utils.normalize_tuple(
         kernel_size, rank, 'kernel_size')
     if not all(self.kernel_size):
@@ -155,7 +167,14 @@ class Conv(Layer):
   def build(self, input_shape):
     input_shape = tensor_shape.TensorShape(input_shape)
     input_channel = self._get_input_channel(input_shape)
-    kernel_shape = self.kernel_size + (input_channel, self.filters)
+    if input_channel % self.groups != 0:
+      raise ValueError(
+          'The number of input channels must be evenly divisible by the number '
+          'of groups. Received groups={}, but the input has {} channels '
+          '(full input shape is {}).'.format(self.groups, input_channel,
+                                             input_shape))
+    kernel_shape = self.kernel_size + (input_channel // self.groups,
+                                       self.filters)

     self.kernel = self.add_weight(
         name='kernel',
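The `kernel_shape` change above is where the parameter savings of grouped convolutions come from: each filter only spans `input_channel // groups` channels. A quick check of the arithmetic with an assumed 3x3 kernel, 12 input channels and 6 filters (illustrative, not from the diff):

    import tensorflow as tf

    # Ungrouped: 3 * 3 * 12 * 6 = 648 kernel weights.
    # groups=3:  3 * 3 * (12 // 3) * 6 = 216 kernel weights.
    layer = tf.keras.layers.Conv2D(6, 3, groups=3, use_bias=False)
    layer.build((None, 28, 28, 12))
    print(layer.kernel.shape)    # (3, 3, 4, 6)
    print(layer.count_params())  # 216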
@@ -250,22 +269,38 @@ class Conv(Layer):

   def get_config(self):
     config = {
-        'filters': self.filters,
-        'kernel_size': self.kernel_size,
-        'strides': self.strides,
-        'padding': self.padding,
-        'data_format': self.data_format,
-        'dilation_rate': self.dilation_rate,
-        'activation': activations.serialize(self.activation),
-        'use_bias': self.use_bias,
-        'kernel_initializer': initializers.serialize(self.kernel_initializer),
-        'bias_initializer': initializers.serialize(self.bias_initializer),
-        'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
-        'bias_regularizer': regularizers.serialize(self.bias_regularizer),
+        'filters':
+            self.filters,
+        'kernel_size':
+            self.kernel_size,
+        'strides':
+            self.strides,
+        'padding':
+            self.padding,
+        'data_format':
+            self.data_format,
+        'dilation_rate':
+            self.dilation_rate,
+        'groups':
+            self.groups,
+        'activation':
+            activations.serialize(self.activation),
+        'use_bias':
+            self.use_bias,
+        'kernel_initializer':
+            initializers.serialize(self.kernel_initializer),
+        'bias_initializer':
+            initializers.serialize(self.bias_initializer),
+        'kernel_regularizer':
+            regularizers.serialize(self.kernel_regularizer),
+        'bias_regularizer':
+            regularizers.serialize(self.bias_regularizer),
         'activity_regularizer':
             regularizers.serialize(self.activity_regularizer),
-        'kernel_constraint': constraints.serialize(self.kernel_constraint),
-        'bias_constraint': constraints.serialize(self.bias_constraint)
+        'kernel_constraint':
+            constraints.serialize(self.kernel_constraint),
+        'bias_constraint':
+            constraints.serialize(self.bias_constraint)
     }
     base_config = super(Conv, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))
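Because `groups` is now included in `get_config`, the setting survives a config round trip (and hence model serialization). A minimal sketch, assuming the public `tf.keras` API:

    import tensorflow as tf

    layer = tf.keras.layers.Conv2D(6, 3, groups=3)
    config = layer.get_config()
    print(config['groups'])  # 3
    clone = tf.keras.layers.Conv2D.from_config(config)
    print(clone.groups)      # 3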
@@ -371,6 +406,11 @@ class Conv1D(Conv):
       the dilation rate to use for dilated convolution.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any `strides` value != 1.
+    groups: A positive integer specifying the number of groups in which the
+      input is split along the channel axis. Each group is convolved
+      separately with `filters / groups` filters. The output is the
+      concatenation of all the `groups` results along the channel axis.
+      Input channels and `filters` must both be divisible by `groups`.
     activation: Activation function to use.
       If you don't specify anything, no activation is applied (
       see `keras.activations`).
@@ -413,6 +453,7 @@ class Conv1D(Conv):
                padding='valid',
                data_format='channels_last',
                dilation_rate=1,
+               groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
@@ -431,6 +472,7 @@ class Conv1D(Conv):
         padding=padding,
         data_format=data_format,
         dilation_rate=dilation_rate,
+        groups=groups,
         activation=activations.get(activation),
         use_bias=use_bias,
         kernel_initializer=initializers.get(kernel_initializer),
@@ -517,6 +559,11 @@ class Conv2D(Conv):
       all spatial dimensions.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any stride value != 1.
+    groups: A positive integer specifying the number of groups in which the
+      input is split along the channel axis. Each group is convolved
+      separately with `filters / groups` filters. The output is the
+      concatenation of all the `groups` results along the channel axis.
+      Input channels and `filters` must both be divisible by `groups`.
     activation: Activation function to use.
       If you don't specify anything, no activation is applied (
       see `keras.activations`).
@@ -566,6 +613,7 @@ class Conv2D(Conv):
                padding='valid',
                data_format=None,
                dilation_rate=(1, 1),
+               groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
@@ -584,6 +632,7 @@ class Conv2D(Conv):
         padding=padding,
         data_format=data_format,
         dilation_rate=dilation_rate,
+        groups=groups,
         activation=activations.get(activation),
         use_bias=use_bias,
         kernel_initializer=initializers.get(kernel_initializer),
@@ -655,6 +704,11 @@ class Conv3D(Conv):
       all spatial dimensions.
       Currently, specifying any `dilation_rate` value != 1 is
       incompatible with specifying any stride value != 1.
+    groups: A positive integer specifying the number of groups in which the
+      input is split along the channel axis. Each group is convolved
+      separately with `filters / groups` filters. The output is the
+      concatenation of all the `groups` results along the channel axis.
+      Input channels and `filters` must both be divisible by `groups`.
     activation: Activation function to use.
       If you don't specify anything, no activation is applied (
       see `keras.activations`).
@@ -710,6 +764,7 @@ class Conv3D(Conv):
                padding='valid',
                data_format=None,
                dilation_rate=(1, 1, 1),
+               groups=1,
                activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
@@ -728,6 +783,7 @@ class Conv3D(Conv):
         padding=padding,
         data_format=data_format,
         dilation_rate=dilation_rate,
+        groups=groups,
         activation=activations.get(activation),
         use_bias=use_bias,
         kernel_initializer=initializers.get(kernel_initializer),
@@ -26,6 +26,9 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
 from tensorflow.python.keras import keras_parameterized
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.platform import test
@@ -47,20 +50,41 @@ class Conv1DTest(keras_parameterized.TestCase):
           expected_output_shape=expected_output_shape)

   @parameterized.named_parameters(
-      ('padding_valid', {'padding': 'valid'}, (None, 5, 2)),
-      ('padding_same', {'padding': 'same'}, (None, 7, 2)),
-      ('padding_same_dilation_2', {'padding': 'same', 'dilation_rate': 2},
-       (None, 7, 2)),
-      ('padding_same_dilation_3', {'padding': 'same', 'dilation_rate': 3},
-       (None, 7, 2)),
-      ('padding_causal', {'padding': 'causal'}, (None, 7, 2)),
-      ('strides', {'strides': 2}, (None, 3, 2)),
-      ('dilation_rate', {'dilation_rate': 2}, (None, 3, 2)),
+      ('padding_valid', {
+          'padding': 'valid'
+      }, (None, 5, 2)),
+      ('padding_same', {
+          'padding': 'same'
+      }, (None, 7, 2)),
+      ('padding_same_dilation_2', {
+          'padding': 'same',
+          'dilation_rate': 2
+      }, (None, 7, 2)),
+      ('padding_same_dilation_3', {
+          'padding': 'same',
+          'dilation_rate': 3
+      }, (None, 7, 2)),
+      ('padding_causal', {
+          'padding': 'causal'
+      }, (None, 7, 2)),
+      ('strides', {
+          'strides': 2
+      }, (None, 3, 2)),
+      ('dilation_rate', {
+          'dilation_rate': 2
+      }, (None, 3, 2)),
+      # Only runs on GPU with CUDA, groups are not supported on CPU.
+      # https://github.com/tensorflow/tensorflow/issues/29005
+      ('group', {
+          'groups': 3,
+          'filters': 6
+      }, (None, 5, 6), True),
   )
-  def test_conv1d(self, kwargs, expected_output_shape):
-    kwargs['filters'] = 2
+  def test_conv1d(self, kwargs, expected_output_shape, requires_gpu=False):
+    kwargs['filters'] = kwargs.get('filters', 2)
     kwargs['kernel_size'] = 3
-    self._run_test(kwargs, expected_output_shape)
+    if not requires_gpu or test.is_gpu_available(cuda_only=True):
+      self._run_test(kwargs, expected_output_shape)

   def test_conv1d_regularizers(self):
     kwargs = {
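The reworked test cases above rely on `parameterized.named_parameters` binding tuple elements positionally, so the optional fourth element feeds the new `requires_gpu=False` argument. A stripped-down sketch of that pattern (names are illustrative, not from the diff):

    from absl.testing import parameterized

    class DemoTest(parameterized.TestCase):

      @parameterized.named_parameters(
          ('cpu_case', {'padding': 'valid'}, (None, 5, 2)),  # requires_gpu defaults to False
          ('gpu_case', {'groups': 3, 'filters': 6}, (None, 5, 6), True),
      )
      def test_case(self, kwargs, expected_shape, requires_gpu=False):
        if requires_gpu:
          self.skipTest('needs a CUDA GPU')  # the real tests check test.is_gpu_available instead
        self.assertLen(expected_shape, 3)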
@@ -148,20 +172,38 @@ class Conv2DTest(keras_parameterized.TestCase):
           expected_output_shape=expected_output_shape)

   @parameterized.named_parameters(
-      ('padding_valid', {'padding': 'valid'}, (None, 5, 4, 2)),
-      ('padding_same', {'padding': 'same'}, (None, 7, 6, 2)),
-      ('padding_same_dilation_2', {'padding': 'same', 'dilation_rate': 2},
-       (None, 7, 6, 2)),
-      ('strides', {'strides': (2, 2)}, (None, 3, 2, 2)),
-      ('dilation_rate', {'dilation_rate': (2, 2)}, (None, 3, 2, 2)),
+      ('padding_valid', {
+          'padding': 'valid'
+      }, (None, 5, 4, 2)),
+      ('padding_same', {
+          'padding': 'same'
+      }, (None, 7, 6, 2)),
+      ('padding_same_dilation_2', {
+          'padding': 'same',
+          'dilation_rate': 2
+      }, (None, 7, 6, 2)),
+      ('strides', {
+          'strides': (2, 2)
+      }, (None, 3, 2, 2)),
+      ('dilation_rate', {
+          'dilation_rate': (2, 2)
+      }, (None, 3, 2, 2)),
       # Only runs on GPU with CUDA, channels_first is not supported on CPU.
       # TODO(b/62340061): Support channels_first on CPU.
-      ('data_format', {'data_format': 'channels_first'}),
+      ('data_format', {
+          'data_format': 'channels_first'
+      }, None, True),
+      # Only runs on GPU with CUDA, groups are not supported on CPU.
+      # https://github.com/tensorflow/tensorflow/issues/29005
+      ('group', {
+          'groups': 3,
+          'filters': 6
+      }, (None, 5, 4, 6), True),
   )
-  def test_conv2d(self, kwargs, expected_output_shape=None):
-    kwargs['filters'] = 2
+  def test_conv2d(self, kwargs, expected_output_shape=None, requires_gpu=False):
+    kwargs['filters'] = kwargs.get('filters', 2)
     kwargs['kernel_size'] = (3, 3)
-    if 'data_format' not in kwargs or test.is_gpu_available(cuda_only=True):
+    if not requires_gpu or test.is_gpu_available(cuda_only=True):
       self._run_test(kwargs, expected_output_shape)

   def test_conv2d_regularizers(self):
@@ -208,7 +250,7 @@ class Conv2DTest(keras_parameterized.TestCase):
 @keras_parameterized.run_all_keras_modes
 class Conv3DTest(keras_parameterized.TestCase):

-  def _run_test(self, kwargs, expected_output_shape):
+  def _run_test(self, kwargs, expected_output_shape, validate_training=True):
     num_samples = 2
     stack_size = 3
     num_row = 7
@@ -220,22 +262,41 @@ class Conv3DTest(keras_parameterized.TestCase):
           keras.layers.Conv3D,
           kwargs=kwargs,
           input_shape=(num_samples, depth, num_row, num_col, stack_size),
-          expected_output_shape=expected_output_shape)
+          expected_output_shape=expected_output_shape,
+          validate_training=validate_training)

   @parameterized.named_parameters(
-      ('padding_valid', {'padding': 'valid'}, (None, 3, 5, 4, 2)),
-      ('padding_same', {'padding': 'same'}, (None, 5, 7, 6, 2)),
-      ('strides', {'strides': (2, 2, 2)}, (None, 2, 3, 2, 2)),
-      ('dilation_rate', {'dilation_rate': (2, 2, 2)}, (None, 1, 3, 2, 2)),
+      ('padding_valid', {
+          'padding': 'valid'
+      }, (None, 3, 5, 4, 2)),
+      ('padding_same', {
+          'padding': 'same'
+      }, (None, 5, 7, 6, 2)),
+      ('strides', {
+          'strides': (2, 2, 2)
+      }, (None, 2, 3, 2, 2)),
+      ('dilation_rate', {
+          'dilation_rate': (2, 2, 2)
+      }, (None, 1, 3, 2, 2)),
       # Only runs on GPU with CUDA, channels_first is not supported on CPU.
       # TODO(b/62340061): Support channels_first on CPU.
-      ('data_format', {'data_format': 'channels_first'}),
+      ('data_format', {
+          'data_format': 'channels_first'
+      }, None, True),
+      # Only runs on GPU with CUDA, groups are not supported on CPU.
+      # https://github.com/tensorflow/tensorflow/issues/29005
+      ('group', {
+          'groups': 3,
+          'filters': 6
+      }, (None, 3, 5, 4, 6), True),
   )
-  def test_conv3d(self, kwargs, expected_output_shape=None):
-    kwargs['filters'] = 2
+  def test_conv3d(self, kwargs, expected_output_shape=None, requires_gpu=False):
+    kwargs['filters'] = kwargs.get('filters', 2)
     kwargs['kernel_size'] = (3, 3, 3)
-    if 'data_format' not in kwargs or test.is_gpu_available(cuda_only=True):
-      self._run_test(kwargs, expected_output_shape)
+    # train_on_batch currently fails with XLA enabled on GPUs
+    test_training = 'groups' not in kwargs or not test_util.is_xla_enabled()
+    if not requires_gpu or test.is_gpu_available(cuda_only=True):
+      self._run_test(kwargs, expected_output_shape, test_training)

   def test_conv3d_regularizers(self):
     kwargs = {
@@ -298,6 +359,57 @@ class Conv3DTest(keras_parameterized.TestCase):
         input_data=input_data)


+class GroupedConvTest(keras_parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ('Conv1D', keras.layers.Conv1D),
+      ('Conv2D', keras.layers.Conv2D),
+      ('Conv3D', keras.layers.Conv3D),
+  )
+  def test_group_conv_incorrect_use(self, layer):
+    with self.assertRaisesRegexp(ValueError, 'The number of filters'):
+      layer(16, 3, groups=3)
+    with self.assertRaisesRegexp(ValueError, 'The number of input channels'):
+      layer(16, 3, groups=4).build((32, 12, 12, 3))
+
+  @parameterized.named_parameters(
+      ('Conv1D', keras.layers.Conv1D, (32, 12, 32)),
+      ('Conv2D', keras.layers.Conv2D, (32, 12, 12, 32)),
+      ('Conv3D', keras.layers.Conv3D, (32, 12, 12, 12, 32)),
+  )
+  def test_group_conv(self, layer_cls, input_shape):
+    if test.is_gpu_available(cuda_only=True):
+      with test_util.use_gpu():
+        inputs = random_ops.random_uniform(shape=input_shape)
+
+        layer = layer_cls(16, 3, groups=4, use_bias=False)
+        layer.build(input_shape)
+
+        input_slices = array_ops.split(inputs, 4, axis=-1)
+        weight_slices = array_ops.split(layer.kernel, 4, axis=-1)
+        expected_outputs = array_ops.concat([
+            nn.convolution_v2(inputs, weights)
+            for inputs, weights in zip(input_slices, weight_slices)
+        ],
+                                            axis=-1)
+
+        self.assertAllClose(layer(inputs), expected_outputs, rtol=1e-5)
+
+  def test_group_conv_depthwise(self):
+    if test.is_gpu_available(cuda_only=True):
+      with test_util.use_gpu():
+        inputs = random_ops.random_uniform(shape=(3, 27, 27, 32))
+
+        layer = keras.layers.Conv2D(32, 3, groups=32, use_bias=False)
+        layer.build((3, 27, 27, 32))
+
+        weights_dw = array_ops.reshape(layer.kernel, [3, 3, 32, 1])
+        expected_outputs = nn.depthwise_conv2d(
+            inputs, weights_dw, strides=[1, 1, 1, 1], padding='VALID')
+
+        self.assertAllClose(layer(inputs), expected_outputs, rtol=1e-5)
+
+
 @keras_parameterized.run_all_keras_modes
 class Conv1DTransposeTest(keras_parameterized.TestCase):

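test_group_conv_depthwise above checks the special case `groups == input channels` against the raw depthwise op. The same relationship can be stated with public layers; a hedged sketch (the grouped layer needs a CUDA GPU at this point, and the two kernels only differ by a transpose of the last two axes):

    import numpy as np
    import tensorflow as tf

    x = tf.random.uniform((1, 8, 8, 4))
    grouped = tf.keras.layers.Conv2D(4, 3, groups=4, use_bias=False)
    grouped.build(x.shape)                    # kernel shape (3, 3, 1, 4)
    depthwise = tf.keras.layers.DepthwiseConv2D(3, use_bias=False)
    depthwise.build(x.shape)                  # kernel shape (3, 3, 4, 1)
    depthwise.set_weights([tf.transpose(grouped.kernel, [0, 1, 3, 2]).numpy()])
    np.testing.assert_allclose(grouped(x).numpy(), depthwise(x).numpy(), rtol=1e-5)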
@@ -1242,12 +1242,12 @@ class Convolution(object):
       spatial_dims = range(
           num_batch_dims + 1, num_spatial_dims + num_batch_dims + 1)

-    if not input_channels_dim.is_compatible_with(
-        filter_shape[num_spatial_dims]):
-      raise ValueError(
-          "Number of input channels does not match corresponding dimension of "
-          "filter, {} != {}".format(input_channels_dim,
-                                    filter_shape[num_spatial_dims]))
+    filter_dim = tensor_shape.dimension_at_index(filter_shape, num_spatial_dims)
+    if not (input_channels_dim % filter_dim).is_compatible_with(0):
+      raise ValueError("The number of input channels is not divisible by the "
+                       "corresponding number of output filters. Received: "
+                       "input channels={}, output filters={}".format(
+                           input_channels_dim, filter_dim))

     strides, dilation_rate = _get_strides_and_dilation_rate(
         num_spatial_dims, strides, dilation_rate)
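The relaxed check in `Convolution` above is what lets the op level express grouping: a filter whose channel dimension is any divisor of the input channels implies `groups = input_channels // filter_channels`. A small sketch against the public `tf.nn.convolution` (illustrative shapes; grouped execution currently needs a CUDA GPU):

    import tensorflow as tf

    x = tf.random.uniform((1, 8, 8, 12))  # 12 input channels
    w = tf.random.uniform((3, 3, 4, 6))   # 4-channel filters -> 12 // 4 = 3 groups, 6 outputs
    y = tf.nn.convolution(x, w, padding='VALID')
    print(y.shape)                        # (1, 6, 6, 6)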
@@ -2051,9 +2051,9 @@ def conv2d_v2(input,  # pylint: disable=redefined-builtin
  Must have `strides[0] = strides[3] = 1`. For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.

  Usage Example:

  >>> x_in = np.array([[
  ...   [[2], [1], [2], [0], [1]],
  ...   [[1], [3], [2], [2], [3]],
@@ -3551,7 +3551,7 @@ def softmax(logits, axis=None, name=None, dim=None):
      Tensor.
    RuntimeError: If a registered conversion function returns an invalid
      value.

  """
  axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
  if axis is None:
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'valid\', \'channels_last\', \'1\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
@@ -114,7 +114,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\', \'groups\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'(1, 1, 1)\', \'1\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
   }
   member_method {
     name: "add_loss"
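The golden argspec updates above all insert `groups` between `dilation_rate` and `activation` with a default of `1`, so keyword-based call sites are unaffected, while callers that pass `activation` or later arguments positionally need to account for the extra slot. A hypothetical illustration, not from the diff:

    import tensorflow as tf

    # Keyword arguments are unaffected by the new slot:
    tf.keras.layers.Conv2D(16, 3, strides=1, padding='same',
                           dilation_rate=(1, 1), groups=1, activation='relu')
    # Positionally, 'relu' now comes after the `groups` slot:
    tf.keras.layers.Conv2D(16, 3, (1, 1), 'same', None, (1, 1), 1, 'relu')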